治疗之雨
题目背景:
5.11 模拟 BJOI2018D2T3
分析:期望DP + 高斯消元优化
我对这道题真的是有一千句喵喵喵,因为一句特判没写,100变10分,内心崩溃。直接说正解吧,定义dp[i]表示还剩i点血的期望步数,f[i]表示在k次攻击中被打中i次的概率。显然:
显然,对于每一个dp[i]的式子(i < n)都包含了dp[i + 1]所以直接递推是行不通的,要直接做的话,只能考虑高斯消元,但是复杂度是T * n3只有最多不到70分,考虑如何优化,我们发现,对于关于1 ~ n - 1的等式,对于第i个等式,我们可以更改一下,变成关于dp[i + 1],这样就可以变成一个递推式的形式,那么对于dp[i + 1] (1 <=i < n)都可以表示成一个由dp[0 ~ i]表示的式子,经过一些化简,我们可以把dp[2 ~ n]全部表示成关于dp[1]的一个一次函数,然后我们只需要用最后一个关于dp[n]的式子,将dp[n]表示出来,将两个dp[n]联立求解dp[1],然后获得dp[p]就可以了。对于第1 ~ n - 1个式子:
总复杂度O(T * n2),可以通过,dp没有什么非常需要注意的细节,但是一定要打齐特判······
Source:
/* created by scarlyw */ #include <cstdio> #include <string> #include <algorithm> #include <cstring> #include <iostream> #include <cmath> #include <cctype> #include <vector> #include <set> #include <queue> #include <ctime> #include <bitset> inline char read() { static const int IN_LEN = 1024 * 1024; static char buf[IN_LEN], *s, *t; if (s == t) { t = (s = buf) + fread(buf, 1, IN_LEN, stdin); if (s == t) return -1; } return *s++; } // /* template<class T> inline void R(T &x) { static char c; static bool iosig; for (c = read(), iosig = false; !isdigit(c); c = read()) { if (c == -1) return ; if (c == '-') iosig = true; } for (x = 0; isdigit(c); c = read()) x = ((x << 2) + x << 1) + (c ^ '0'); if (iosig) x = -x; } //*/ const int OUT_LEN = 1024 * 1024; char obuf[OUT_LEN]; char *oh = obuf; inline void write_char(char c) { if (oh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), oh = obuf; *oh++ = c; } template<class T> inline void W(T x) { static int buf[30], cnt; if (x == 0) write_char('0'); else { if (x < 0) write_char('-'), x = -x; for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48; while (cnt) write_char(buf[cnt--]); } } inline void flush() { fwrite(obuf, 1, oh - obuf, stdout), oh = obuf; } /* template<class T> inline void R(T &x) { static char c; static bool iosig; for (c = getchar(), iosig = false; !isdigit(c); c = getchar()) if (c == '-') iosig = true; for (x = 0; isdigit(c); c = getchar()) x = ((x << 2) + x << 1) + (c ^ '0'); if (iosig) x = -x; } //*/ const int MAXN = 1500 + 10; const int mod = 1000000000 + 7; int n, p, m, k, t; long long f[MAXN]; inline int add(int a, int b) { a += b; return (a >= mod) ? (a -= mod) : (a); } struct data { long long a, b; data() {} data(long long a, long long b) : a(a), b(b) {} inline data operator + (const data &c) const { return data(add(a, c.a), add(b, c.b)); } inline data operator - (const data &c) const { return data(add(a, mod - c.a), add(b, mod - c.b)); } inline data operator * (const long long &c) const { return data(a * c % mod, b * c % mod); } inline data operator + (const long long &c) const { return data(a, add(b, c)); } inline data operator - (const long long &c) const { return data(a, add(b, mod - c)); } } dp[MAXN]; inline long long mod_pow(long long a, long long b) { long long ans = 1; for (; b; b >>= 1, a = a * a % mod) if (b & 1) ans = ans * a % mod; return ans; } inline void solve() { /* 2 1 0 2 */ R(n), R(p), R(m), R(k); if (p == 0) { puts("0"); return ; } if (k == 0) { puts("-1"); return ; } if (m == 0) { if (k == 0) { puts("-1"); return ; } if (k == 1) { if (n != 1) puts("-1"); else puts("1"); return ; } if (p == n) std::cout << std::ceil((double)(p - k) / (k - 1)) + 1 << '\n'; else std::cout << std::ceil((double)(p - k + 1) / (k - 1)) + 1 << '\n'; return ; } long long c = 1, x = 1, inv = mod_pow(m + 1, mod - 2); long long inv_m = mod_pow(inv * m % mod, mod - 2); long long y = mod_pow(inv * m % mod, k); for (int i = 0, end = std::min(n, k); i <= end; ++i) { f[i] = c * x % mod * y % mod; c = c * (k - i) % mod * mod_pow(i + 1, mod - 2) % mod; x = x * inv % mod, y = y * inv_m % mod; } long long inv0 = mod_pow(f[0], mod - 2); dp[0] = data(0, 0), dp[1] = data(1, 0); for (int i = 1; i < n; ++i) { dp[i + 1] = dp[i] * (m + 1) - m - 1; data t = data(0, 0); for (int j = 0, end = std::min(k, i); j <= end; ++j) t = t + dp[i - j] * f[j]; dp[i + 1] = dp[i + 1] - t * m, t = data(0, 0); for (int j = 1, end = std::min(k, i + 1); j <= end; ++j) t = t + dp[i - j + 1] * f[j]; dp[i + 1] = dp[i + 1] - t, dp[i + 1] = dp[i + 1] * inv0; } data ret = data(0, 1); for (int j = 1, end = std::min(k, n); j <= end; ++j) ret = ret + dp[n - j] * f[j]; ret = ret * mod_pow((1 - f[0] + mod) % mod, mod - 2); ret = ret - dp[n]; if (ret.a == 0 && ret.b != 0) puts("-1"); else { long long t = (long long)(mod - ret.b) * mod_pow(ret.a, mod - 2) % mod; std::cout << ((dp[p].a * t % mod + dp[p].b) % mod + mod) % mod << '\n'; } } int main() { // freopen("heal.in", "r", stdin); // freopen("heal.out", "w", stdout); R(t); while (t--) solve(); return 0; }