LOJ #2538. 「PKUWC 2018」Slay the Spire (期望dp)

时间:2024-11-25 15:35:14

Update on 1.5

学了 zhou888 的写法,真是又短又快。

并且空间是 \(O(n)\) 的,速度十分优秀。

题意

LOJ #2538. 「PKUWC 2018」Slay the Spire

题解

首先我们考虑拿到一副牌如何打是最优的,不难发现是将强化牌从大到小能打就打,最后再从大到小打攻击牌 。

为什么呢 ?

证明(简单说明) : 如果不是这样 , 那么我们就是有强化牌没有用 , 且攻击牌超过两张 .

我们考虑把最小的那张攻击牌拿出来 , 然后放入一张强化牌 .

\(\because~w_i \ge 2\) 且 最小那张攻击牌的攻击力 \(a_{\min}\) 不大于所有攻击牌的总和 \(a_{sum}\) 的一半

\(\therefore\) 修改后造成的伤害绝对不比前面少 . 得证.

我们只要枚举上下分别用了多少张牌 , 假设 加强 用了 \(a\) 张 , 攻击 用了 \(b\) 张 . \((a + b = m)\)

那么我们只要分两种情况考虑了 :

  1. \(a < k:\) 那么我们加强可以全用完 , 攻击用前 \(k - a\) 大的 ;
  2. \(a \ge k:\) 这个加强用前 \(k - 1\) 大的 , 攻击用一张最大的 .

令 \(f_i\) 为选 \(i\) 张强化牌能得到的最优倍率之和,显然强化牌我们从大到小取是最优的。

假设当前取到第 \(j\) 张牌。

那么有如下转移:

\[f_i =
\begin{cases}
(f_i + f_{i - 1}) \times a[j] &i < k\\
f_i + f_{i - 1} &i \ge k
\end{cases}
\]

上面那种情况是还能用强化牌,下面已经不能加新的强化牌了,所以不乘上倍率。(注意 \(f_0 = 1\) )

同样我们设 \(g_i\) 为选 \(i\) 张攻击牌能得到的最优攻击之和,此处我们需要从小到大排序。

有如下转移:

\[g_i = g_i + \displaystyle {j - 1 \choose i - 1} \times a[j] +
\begin{cases}
0 &\le m - (k - 1)\\
g_{i - 1} & >m - (k - 1)
\end{cases}
\]

考虑这张牌我们先放进来,不难发现对于所有 \(i \le m - (k - 1)\) 也就是只能打一张的,我们只统计了这张打的贡献。

如果能打很多张,这样转移的话就能保证我们尽量取的是靠后的那些元素。

最后答案直接就是 \(\displaystyle \sum_{i = 0}^{m} f_{i} g_{m-i}\) 。

总结

需要啥就设啥,想清楚情况再 \(dp\) 。

代码

#include <bits/stdc++.h>

#define For(i, l, r) for (register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for (register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Rep(i, r) for (register int i = (0), i##end = (int)(r); i < i##end; ++i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << (x) << endl using namespace std; template<typename T> inline bool chkmin(T &a, T b) { return b < a ? a = b, 1 : 0; }
template<typename T> inline bool chkmax(T &a, T b) { return b > a ? a = b, 1 : 0; } inline int read() {
int x(0), sgn(1); char ch(getchar());
for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
for (; isdigit(ch); ch = getchar()) x = (x * 10) + (ch ^ 48);
return x * sgn;
} void File() {
#ifdef zjp_shadow
freopen ("2538.in", "r", stdin);
freopen ("2538.out", "w", stdout);
#endif
} const int N = 3e3 + 1e2, Mod = 998244353; inline int fpm(int x, int power) {
int res = 1;
for (; power; power >>= 1, x = 1ll * x * x % Mod)
if (power & 1) res = 1ll * res * x % Mod;
return res;
} int fac[N], ifac[N]; void Math_Init(int maxn) {
fac[0] = ifac[0] = 1;
For (i, 1, maxn) fac[i] = 1ll * fac[i - 1] * i % Mod;
ifac[maxn] = fpm(fac[maxn], Mod - 2);
Fordown (i, maxn - 1, 1) ifac[i] = ifac[i + 1] * (i + 1ll) % Mod;
} inline int Comb(int n, int m) {
if (n < 0 || m < 0 || n < m) return 0;
return 1ll * fac[n] * ifac[m] % Mod * ifac[n - m] % Mod;
} int n, m, k, a[N], f[N], g[N]; int main () { File(); Math_Init(3000); for (int cases = read(); cases; -- cases) { n = read(); m = read(); k = read();
For (i, 1, n) a[i] = read(); For (i, 1, max(n, m)) f[i] = g[i] = 0; sort(a + 1, a + n + 1, greater<int>()); f[0] = 1;
For (i, 1, n) Fordown (j, min(i, m), 0)
if (j <= k - 1) f[j] = (f[j] + 1ll * f[j - 1] * a[i]) % Mod;
else f[j] = (f[j] + f[j - 1]) % Mod; For (i, 1, n) a[i] = read();
sort(a + 1, a + n + 1);
For (i, 1, n) Fordown (j, min(i, m), 0)
if (j <= m - (k - 1))
g[j] = (g[j] + 1ll * Comb(i - 1, j - 1) * a[i]) % Mod;
else
g[j] = (g[j] + g[j - 1] + 1ll * Comb(i - 1, j - 1) * a[i]) % Mod; int ans = 0;
For (i, 0, m)
ans = (ans + 1ll * f[i] * g[m - i]) % Mod;
printf ("%d\n", ans); } return 0; }