Luogu4491 [HAOI2018]染色 【容斥原理】【NTT】

时间:2024-07-28 21:06:50

题目分析:

一开始以为是直接用指数型生成函数,后来发现复杂度不对,想了一下容斥的方法。

对于有$i$种颜色恰好出现$s$次的情况,利用容斥原理得到方案数为

$$\binom{m}{i}\frac{P_{is}^{n}}{(s!)^i}(\sum_{j=0}^{m-i}(-1)^j\binom{m-i}{j}\frac{P_{js}^{n-is}}{(s!)^j}(m-i-j)^{n-is-js})$$

值得注意的是$n-is-js<0$的时候,后面的式子直接等于$0$,特判一下就行了。

那么答案就等于

$$\sum_{i=0}^{m}w_i\binom{m}{i}\frac{P_{is}^{n}}{(s!)^i}(\sum_{j=0}^{m-i}(-1)^j\binom{m-i}{j}\frac{P_{js}^{n-is}}{(s!)^j}(m-i-j)^{n-is-js})$$

式子看着很长,但其实没啥味道,把组合数和排列数展开,常数项提出来,约分,可以得到上面的式子等价于

$$(n!)*(m!)*\sum_{i=0}^{m}\frac{w_i}{(i!)(s!)^i}(\sum_{j=0}^{m-i}\frac{(-1)^j(m-i-j)^{n-is-js}}{(m-i-j)!j!(n-is-js)!(s!)^j})$$

对于后面的那个求和,使用肉眼观察法,会发现是个关于$j$和$m-i-j$的卷积。因为$m-i-j$的值确定了就意味着$n-is-js$的值也确定了。所以NTT搞出来

时间复杂度$O(nlogn)$

代码:

 #include<bits/stdc++.h>
using namespace std; const int maxn = ;
const int mod = ;
const int gg = ; int n,m,s;
int w[maxn]; int fac[],A[maxn*],B[maxn*]; int fast_pow(int now,int pw){
int ans = ,dt = now,bit = ;
while(bit <= pw){
if(bit &pw){ans = 1ll*ans*dt%mod;}
dt = 1ll*dt*dt%mod; bit<<=;
}
return ans;
} void buildfunc(){
fac[] = ;
for(int i=;i<=max(n,m);i++) fac[i] = 1ll*fac[i-]*i%mod;
for(int i=;i<=m;i++){
A[i] = 1ll*fast_pow(fac[s],i)*fac[i]%mod;
A[i] = fast_pow(A[i],mod-);
if(i&) A[i] = 1ll*(mod-)*A[i]%mod;
}
for(int i=;i<=m;i++){
int z = m-i;
if(n-z*s < ) {B[i] = ;continue;}
int rem = n-z*s;
B[i] = 1ll*fac[i]*fac[rem]%mod;
B[i] = fast_pow(B[i],mod-);
B[i] = 1ll*B[i]*fast_pow(i,rem)%mod;
}
} int ord[maxn*]; void NTT(int *d,int len,int dr){
for(int i=;i<len;i++) if(ord[i] < i) swap(d[i],d[ord[i]]);
for(int i=;i<len;i<<=){
int w = fast_pow(gg,(mod-)/(*i));
if(dr == -) w = fast_pow(w,mod-);
for(int j=;j<len;j+=(i<<)){
for(int k=,wn=;k<i;k++,wn = 1ll*wn*w%mod){
int x = d[j+k],y = 1ll*wn*d[j+k+i]%mod;
d[j+k] = (x+y)%mod;
d[j+k+i] = (x-y+mod)%mod;
}
}
}
if(dr == -){
int iv = fast_pow(len,mod-);
for(int i=;i<len;i++){d[i] = 1ll*d[i]*iv%mod;}
}
} void work(){
buildfunc();
/*int reans = 0;
for(int i=0;i<=m;i++){
int z = 1ll*w[i]*fast_pow(1ll*fac[i]*fast_pow(fac[s],i)%mod,mod-2)%mod;
int np = 0,kp = 0;
for(int j=0;j<m-i;j++){
if(n-i*s-j*s < 0) continue;
int mp = 0;
mp = 1ll*fac[m-i-j]*fac[j]%mod*fac[n-i*s-j*s]%mod*fast_pow(fac[s],j)%mod;
mp = 1ll*fast_pow(mp,mod-2)*fast_pow(m-i-j,n-i*s-j*s)%mod;
if(j & 1) mp = 1ll*(mod-1)*mp%mod;
kp += 1ll*A[j]*B[m-i-j]%mod;
kp %= mod;
np += mp;
np %= mod;
}
reans += 1ll*z*np%mod;
reans %= mod;
}
reans = 1ll*reans*fac[n]%mod*fac[m]%mod;
printf("%d\n",reans);return;*/ int hk = ,pi = ; while(hk <= m+m) hk*=,pi++;
for(int i=;i<hk;i++) ord[i] = (ord[i>>]>>) + ((i&)<<(pi-));
NTT(A,hk,); NTT(B,hk,);
for(int i=;i<hk;i++) A[i] = 1ll*A[i]*B[i]%mod;
NTT(A,hk,-);
int ans = ;
for(int i=;i<=m;i++){
int z = 1ll*fac[m-i]*fast_pow(fac[s],m-i)%mod;
z = 1ll*fast_pow(z,mod-)*w[m-i]%mod;
ans += 1ll*z*A[i]%mod;
ans %= mod;
}
ans = 1ll*ans*fac[n]%mod*fac[m]%mod;
printf("%d\n",ans);
} int main(){
scanf("%d%d%d",&n,&m,&s);
for(int i=;i<=m;i++) scanf("%d",&w[i]);
work();
return ;
}