LOJ #2541. 「PKUWC 2018」猎人杀(容斥 , 期望dp , NTT优化)

时间:2023-03-08 17:10:57

题意

LOJ #2541. 「PKUWC 2018」猎人杀

题解

一道及其巧妙的题 , 参考了一下这位大佬的博客 ...

令 \(\displaystyle A = \sum_{i=1}^{n} w_i\) , \(B\) 是已死猎人的 \(w_i\) 的总和 , \(P_i\) 是 \(i\) 当前要被杀死的概率 ... (抄博客咯)

不难有 \(\displaystyle P_i = \frac{w_i}{A-B} \tag{1}\)

如果 不考虑猎人死没死 , 都能被当做目标 qwq (鞭尸) 也就是算进去概率 ...

那么也会有 \(\displaystyle P_i = \frac{B}{A} P_i + \frac{w_i}{A} \tag{2}\)

这个是为什么呢 ... 因为假设打到它 那么就是死了 , 如果打到了死的目标 , 那么这把又会重新来过 , 但他死的概率还是没变 ... (类似于有一些期望题中的高斯消元)

发现 \((2)\) 移项后就得到了 \((1) ~ !!!!\) 这个就很巧妙啦 ~ 也就是说 按第二种来算也可以算出正确的结果 .

然后我们考虑容斥 , 枚举在 \(1\) 号后面死的人 , 然后令 \(S\) 为枚举到的人的 \(w_i\) 之和 , 人数为 \(t\) . 那么意味着 \(1\) 号的位置至多是 \(n - t\) 位 .

不难发现答案就是 $$\displaystyle ans = (-1)^t \sum_{i=0}^{\infty} (1-\frac{S+w_1}{A})^i\frac{w_1}{A}$$ .

这个代入前面的式子就可以得到了 . 这里虽然算的是无限概率 , 但是一个收敛的无限等比数列 , 我们用常规的作差法就可算出来了 .

令\(\displaystyle T = \sum_{i=0}^{\infty} (1-\frac{S+w_1}{A})^i \tag{3}\)

则又有\(\displaystyle (1-\frac{S+w_1}{A})T=\sum_{i=1}^{\infty} (1-\frac{S+w_1}{A})^i \tag{4}\)

让 \((3) - (4)\) 就可以得到

\[\displaystyle \frac{S+w_1}{A} T = (1-\frac{S+w_1}{A})^0=1
\]

\[\displaystyle \therefore T=\frac{A}{S+w_1}
\]

\[\displaystyle \therefore ans=(-1)^t\frac{w_1}{S+w_1}
\]

啊 多么美妙的一个式子啊 qwq

然后直接算每个 \(S\) 前面的系数就行了 .

朴素 dp 就是 令 \(f_{i,j}\) 为到第 \(i\) 个总和为 \(j\) 的系数和 . 每次有两种决策 , 一种是选第 \(i\) 个 另一种是不选 .

那么很容易发现第一种加上去的时候奇偶性改变 , 乘上 \(-1\) 就行了 .

\[\displaystyle f_{i,j} \to f_{i+1,j} ; -f_{i,j} \to f_{i+1,j+w_i}
\]

不难发现这个是个多项式乘法 ...

又由于多项式乘法有交换律和结合律 , 直接每次挑 \(size\) 最小的两个合并就行了 .

令 \(\displaystyle q=\sum_{i=1}^{n} w_i\) 时间复杂度就是 \(O(q \log q \log n)\) 可以通过此题...

代码

#include <bits/stdc++.h>
#define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Set(a, v) memset(a, v, sizeof(a))
using namespace std; inline bool chkmin(int &a, int b) {return b < a ? a = b, 1 : 0;}
inline bool chkmax(int &a, int b) {return b > a ? a = b, 1 : 0;} inline int read() {
int x = 0, fh = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') fh = -1;
for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
return x * fh;
} void File() {
#ifdef zjp_shadow
freopen ("2541.in", "r", stdin);
freopen ("2541.out", "w", stdout);
#endif
} typedef long long ll;
const int Mod = 998244353; ll fpm(ll x, int power) {
ll res = 1;
for (; power; power >>= 1, (x *= x) %= Mod)
if (power & 1) (res *= x) %= Mod;
return res;
} const int Maxn = (1 << 20) + 5;
struct Number_Theoretical_Transform {
int n, n1, n2, m;
ll pow3[Maxn], invpow3[Maxn]; inline void Init(int maxn) {
for (int i = 1; i <= maxn; i <<= 1) {
pow3[i] = fpm(3, (Mod - 1) / i);
invpow3[i] = fpm(pow3[i], Mod - 2);
}
} int rev[Maxn]; void NTT(ll P[], int opt) {
For (i, 0, n - 1) if (i < rev[i]) swap(P[i], P[rev[i]]);
for (int i = 2; i <= n; i <<= 1) {
int p = i / 2;
ll Wi = opt == 1 ? pow3[i] : invpow3[i];
for (int j = 0; j < n; j += i) {
ll x = 1;
for (int k = 0; k < p; ++ k, (x *= Wi) %= Mod) {
ll u = P[j + k], v = x * P[j + k + p] % Mod;
P[j + k] = (u + v) % Mod;
P[j + k + p] = (u - v + Mod) % Mod;
}
}
}
if (opt == -1) {
ll invn = fpm(n, Mod - 2);
For (i, 0, n - 1) (P[i] *= invn) %= Mod;
}
} ll A[Maxn], B[Maxn];
inline vector<int> Mult(vector<int> a, vector<int> b) {
n1 = (int)a.size() - 1; n2 = (int)b.size() - 1; m = n1 + n2;
For (i, 0, n1) A[i] = a[i]; For (i, 0, n2) B[i] = b[i]; int cnt = 0; for (n = 1; n <= m; n <<= 1) ++ cnt;
For (i, 1, n) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (cnt - 1)); For (i, n1 + 1, n) A[i] = 0; For (i, n2 + 1, n) B[i] = 0; NTT(A, 1); NTT(B, 1);
For (i, 0, n - 1) (A[i] *= B[i]) %= Mod;
NTT(A, - 1); vector<int> res; res.resize(m + 1); For (i, 0, m) res[i] = A[i];
return res;
}
} T; const int N = 100010;
int n, dp[2][N], w[N], cur, w1; struct Seq {
vector<int> V;
inline bool operator < (const Seq &rhs) const {
return (int)V.size() > (int)rhs.V.size();
}
} ; priority_queue<Seq, vector<Seq> > P; void Out(vector<int> A) {
For (i, 0, A.size() - 1)
printf ("%d%c", A[i], i == iend ? '\n' : ' ');
} int main () {
File(); T.Init(1 << 20); n = read() - 1; w1 = read();
int tot = 0;
For (i, 1, n) w[i] = read(), tot += w[i]; For (i, 1, n) {
vector<int> now;
now.resize(w[i] + 1);
now[0] = 1; now[w[i]] = Mod - 1;
P.push((Seq) {now});
} while (P.size() > 1) {
Seq a = P.top(); P.pop();
Seq b = P.top(); P.pop();
P.push((Seq) {T.Mult(a.V, b.V)});
} Seq res = P.top(); int ans = 0;
For (i, 0, res.V.size() - 1) {
ans += 1ll * res.V[i] * w1 % Mod * fpm(i + w1, Mod - 2) % Mod;
ans = (ans % Mod + Mod) % Mod;
} printf ("%d\n", ans); return 0;
}