题目描述:
给出一个 \(n\) 个点 \(m\) 条边的 \(DAG\) 和参数 \(k\)。
定义一条经过 \(l\) 条边的路径的权值为 \(l^k\).
对于 \(i = 1…n\), 求出所有 \(1\) 到 \(i\) 的路径的权值之和, 对 \(998244353\) 取模.
对于前 \(20\)% 的数据, \(n ≤ 2000\),\(m ≤ 5000\);
对于另 \(10\)% 的数据, \(k = 1\);
对于另 \(20\)% 的数据, \(k ≤ 30\);
对于 \(100\)% 的数据, \(1 ≤ n ≤ 100000\),\(1 ≤ m ≤ 200000\),\(0 ≤ k ≤ 500\),
保证从 \(1\) 出发可以到达每个点, 可能会有重边。
题目解法:
前\(50\)%直接二项式定理,时间复杂度为\(O(nk^2)\)。
对于每一个点 \(i\),题目需要求的是:
\(Ans_i = \sum len(i)^k\)
根据组合数学,有:
\[n^m = \sum_{j=1}^{min(n,m)}S_m^j*\prod_{n-j+1}^n =\sum_{j=1}^nS_m^j*{j!}C_n^j\]
其中\(S_n^m\)为斯特林数,具体含义见这里。
其实就是先选择\(j\)个盒子,然后把\(m\)个小球放到这\(j\)个盒子中。
那么带入本题中:
\(Ans_i = \sum len(i)^k = \sum\sum_{j=1}^{min(k,len(i))}S_k^j*{j!}*C_{len(i)}^j\)
然后是关键的一步:
\[\sum \sum_{j=1}^{k}S_k^j*{j!}*C_{len(i)}^j = \sum_{j=1}^kS^j_k*{j!} \sum C_{len(i)}^j\]
注意到这里把组合数\(C_{len(i)}^j\)提了出来,那么其实就是分离了变量\(len(i)\)
这个式子的前面部分中的\(S_k^j*{j!}\)可以直接算,所以只要处理组合数即可。
令\(F_{u,j} = \sum C_{len(u)}^j\)。
那么考虑由\(u\)转移到\(v\),那么即为:\(C_{len(i)}^j\) 变为 \(C_{len(i)+1}^j\)
由组合数公式可以知道:\(C_n^m = C_{n-1}^{m-1}+C_{n-1}^{m}\)
所以\(F_{i,j}\)的转移为:$ F_{v,j} = F_{u,j-1} + F_{u,j}。$
在\(TopSort\)时跑这个\(DP\),做出\(F_{i,j}\),然后带入上面的公式中即可计算\(Ans_i\)了。
总的时间复杂度为\(O(nk)\),可以跑过所有数据点。
实现代码
#include<bits/stdc++.h>
#define IL inline
#define ll long long
#define RG register
#define mod 998244353
using namespace std;
IL int gi(){
RG int date = 0, m = 1; RG char ch = 0;
while(ch!='-'&&(ch<'0'||ch>'9'))ch = getchar();
if(ch == '-'){m = -1; ch = getchar();}
while(ch>='0'&&ch<='9'){date=date*10+ch-'0';ch=getchar();}
return date*m;
}
int n,m,k,init[100005];
ll f[100005][505] , g[505][505] , fac , ans[100005];
struct Road{int to,next;}t[300005]; int head[100005],cnt;
queue<int>Q;
int main(){
n = gi(); m = gi(); k = gi();
for(int i = 1; i <= m; i++){
int u = gi() , v = gi();
t[++cnt] = (Road){v,head[u]}; head[u] = cnt;
init[v] ++;
}
while(!Q.empty())Q.pop();
for(int i = 1; i <= n; i ++)if(!init[i]){f[i][0] = 1; Q.push(i);}
g[0][0] = 1; //g[i][j] 把i个有异球 放到 j个无异盒 里。
for(int i = 1; i <= k; i ++)
for(int j = 0; j <= i; j ++){
g[i][j] = g[i][j] + 1ll*g[i-1][j]*j%mod;
if(j)g[i][j] = (g[i][j] + g[i-1][j-1])%mod;
}
while(!Q.empty()){
RG int u = Q.front(); Q.pop();
for(int i = head[u]; i; i = t[i].next){
int v = t[i].to;
init[v] --; if(!init[v])Q.push(v);
for(int j = 0; j <= k; j ++){
if(j)f[v][j] = (f[v][j] + f[u][j-1])%mod;
f[v][j] = (f[v][j] + f[u][j]) %mod;
}
}
}
for(int i = 1; i <= n; i ++){
fac = 1; ans[i] = 0;
for(ll j = 1; j <= k; j ++){
fac = 1ll*fac*j%mod; if(fac>=mod)fac%=mod;
ans[i] = (ans[i] + 1ll*fac*g[k][j]%mod*f[i][j]) %mod;
}
}
for(int i = 1; i <= n; i ++)printf("%lld\n",ans[i]);
return 0;
}