UOJ 449 【集训队作业2018】喂鸽子 【生成函数,min-max容斥】

时间:2023-03-09 13:28:27
UOJ 449 【集训队作业2018】喂鸽子 【生成函数,min-max容斥】

这是第100篇博客,所以肯定是要水过去的。

首先看到这种形式的东西首先min-max容斥一波,设\(f_{c,s}\)表示在\(c\)只咕咕中,经过\(s\)秒之后并没有喂饱任何一只的概率。

\[\begin{aligned}
Ans&=\sum_{i=1}^n(-1)^{i-1}\binom{n}{i}ans_i \\
ans_c&=\sum_{i\ge 1}\sum_{s=0}^i\binom{i}{s}(\frac{n-c}{n})^{i-s}(\frac{c}{n})^sf_{c,s} \\
&=\sum_{s=0}^{c(k-1)}f_{c,s}(\frac{c}{n})^s\sum_{i\ge 0}\binom{i+s}{s}(\frac{n-c}{n})^i \\
&=\frac{n}{c}\sum_{s=0}^{c(k-1)}f_{c,s}
\end{aligned}
\]

现在要求\(f_{c,s}\),转换为方案数之后求它的EGF\(f_c(x)\)。则\(f_c(x)=(\sum_{i=0}^k\frac{x^i}{i!})^c\).用NTT直接做就可以做到\(O(n^2k\log k)\),但是还有更快的方式。

\[\begin{aligned}
f(x)&=1+x+\frac{x^2}{2!}+\ldots+\frac{x^k}{k!} \\
f'(x)&=f(x)-\frac{x^k}{k!} \\
f'_c(x)&=cf_{c-1}(x)(f(x)-\frac{x^k}{k!}) \\
&=c(f_c(x)-\frac{x^k}{k!}f_{c-1}(x)) \\
f_{c,s+1}&=\frac{c}{s+1}(f_{c,s}-\frac{1}{k!}f_{c-1,s-k})
\end{aligned}
\]

(按照对应系数列出递推式)

然后把\(f_{c,s}\)乘上\(\frac{1}{c^s}\)就可以得到概率了。

#include<bits/stdc++.h>
#define Rint register int
using namespace std;
const int N = 55555, mod = 998244353;
namespace MY {
#define pii pair<int,int>
#define fir first
#define sec second
#define MP make_pair
#define For(i,x,y) for (int i=(x);i<=(y);i++)
#define Rof(i,x,y) for (int i=(x);i>=(y);i--)
#define go(x) for (int i=head[x];i;i=edge[i].nxt)
#define templ template<typename T>
typedef long long LL;
inline int rand(){
static int seed = 20050915;
return (((seed * 19260817ll % 2147483647) + 1000000007) % 2147483647) ^ 998244353;
}
templ inline bool chkmin(T &a, T b){return a < b ? a = b, 1 : 0;}
templ inline bool chkmax(T &a, T b){return a > b ? a = b, 1 : 0;}
templ inline void read(T &x){
int ch = getchar();
bool flag = false;
double d = 1;
while((ch < '0' || ch > '9') && ch != '-') ch = getchar();
if(ch == '-'){
flag = true;
ch = getchar();
}
while(ch >= '0' && ch <= '9'){
x = x * 10 + ch - '0';
ch = getchar();
}
if(ch == '.'){
ch = getchar();
while(ch >= '0' && ch <= '9'){
d *= 0.1;
x += d * (ch - '0');
ch = getchar();
}
}
if(flag) x = -x;
}
inline void upd(int &a, int b, int p = mod){a += b; if(a >= p) a -= p;}
inline int add(int a, int b, int p = mod){int res = a + b; if(res >= p) res -= p; return res;}
inline int dec(int a, int b, int p = mod){int res = a - b; if(res < 0) res += p; return res;}
inline int kasumi(int a, int b, int p = mod){
int res = 1;
while(b){
if(b & 1) res = (LL) res * a % mod;
a = (LL) a * a % mod;
b >>= 1;
}
return res;
}
}
using namespace MY;
int n, k, f[53][N], fac[N], inv[N], res;
inline void init(int m){
fac[0] = 1;
for(Rint i = 1;i <= m;i ++) fac[i] = (LL) fac[i - 1] * i % mod;
inv[m] = kasumi(fac[m], mod - 2);
for(Rint i = m;i;i --) inv[i - 1] = (LL) inv[i] * i % mod;
}
inline int C(int n, int m){return (LL) fac[n] * inv[m] % mod * inv[n - m] % mod;}
int main(){
scanf("%d%d", &n, &k); init(n * k + 5);
f[0][0] = 1;
For(i, 0, k - 1) f[1][i] = inv[i];
For(i, 2, n){
f[i][0] = 1;
For(j, 0, i * (k - 1) - 1)
f[i][j + 1] = (LL) i * inv[j + 1] % mod * fac[j] % mod *
dec(f[i][j], (j >= k - 1) ? ((LL) inv[k - 1] * f[i - 1][j - k + 1] % mod) : 0) % mod;
}
For(i, 1, n){
int tmp = 0, ttt = 1;
For(j, 0, i * (k - 1)){
upd(tmp, (LL) f[i][j] * fac[j] % mod * ttt % mod);
ttt = (LL) ttt * inv[i] % mod * fac[i - 1] % mod;
}
tmp = (LL) tmp * inv[i] % mod * fac[i - 1] % mod * C(n, i) % mod;
if(i & 1) upd(res, tmp); else upd(res, mod - tmp);
}
printf("%d", (LL) res * n % mod);
}

(实际上是为了熟悉一下头文件来写的代码)