NOI模拟(5.11) BJOID2T3 治疗之雨 (bzoj5292)

时间:2022-12-16 17:55:03

治疗之雨

题目背景:

5.11 模拟 BJOI2018D2T3

分析:期望DP + 高斯消元优化

 

我对这道题真的是有一千句喵喵喵,因为一句特判没写,10010分,内心崩溃。直接说正解吧,定义dp[i]表示还剩i点血的期望步数,f[i]表示在k次攻击中被打中i次的概率。显然:

NOI模拟(5.11) BJOID2T3 治疗之雨 (bzoj5292)

显然,对于每一个dp[i]的式子(i < n)都包含了dp[i + 1]所以直接递推是行不通的,要直接做的话,只能考虑高斯消元,但是复杂度是T * n3只有最多不到70分,考虑如何优化,我们发现,对于关于1 ~ n - 1的等式,对于第i个等式,我们可以更改一下,变成关于dp[i + 1],这样就可以变成一个递推式的形式,那么对于dp[i + 1] (1 <=i < n)都可以表示成一个由dp[0 ~ i]表示的式子,经过一些化简,我们可以把dp[2 ~ n]全部表示成关于dp[1]的一个一次函数,然后我们只需要用最后一个关于dp[n]的式子,将dp[n]表示出来,将两个dp[n]联立求解dp[1],然后获得dp[p]就可以了。对于第1 ~ n - 1个式子:

NOI模拟(5.11) BJOID2T3 治疗之雨 (bzoj5292)

总复杂度O(T * n2),可以通过,dp没有什么非常需要注意的细节,但是一定要打齐特判······

 

Source:


/*
    created by scarlyw
*/
#include <cstdio>
#include <string>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <cmath>
#include <cctype>
#include <vector>
#include <set>
#include <queue>
#include <ctime>
#include <bitset>
 
inline char read() {
    static const int IN_LEN = 1024 * 1024;
    static char buf[IN_LEN], *s, *t;
    if (s == t) {
        t = (s = buf) + fread(buf, 1, IN_LEN, stdin);
        if (s == t) return -1;
    }
    return *s++;
}
 
// /*
template<class T>
inline void R(T &x) {
    static char c;
    static bool iosig;
    for (c = read(), iosig = false; !isdigit(c); c = read()) {
        if (c == -1) return ;
        if (c == '-') iosig = true; 
    }
    for (x = 0; isdigit(c); c = read()) 
        x = ((x << 2) + x << 1) + (c ^ '0');
    if (iosig) x = -x;
}
//*/

const int OUT_LEN = 1024 * 1024;
char obuf[OUT_LEN];
char *oh = obuf;
inline void write_char(char c) {
	if (oh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), oh = obuf;
	*oh++ = c;
}


template<class T>
inline void W(T x) {
	static int buf[30], cnt;
	if (x == 0) write_char('0');
	else {
		if (x < 0) write_char('-'), x = -x;
		for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48;
		while (cnt) write_char(buf[cnt--]);
	}
}

inline void flush() {
	fwrite(obuf, 1, oh - obuf, stdout), oh = obuf;
}
 
/*
template<class T>
inline void R(T &x) {
    static char c;
    static bool iosig;
    for (c = getchar(), iosig = false; !isdigit(c); c = getchar())
        if (c == '-') iosig = true; 
    for (x = 0; isdigit(c); c = getchar()) 
        x = ((x << 2) + x << 1) + (c ^ '0');
    if (iosig) x = -x;
}
//*/

const int MAXN = 1500 + 10;
const int mod = 1000000000 + 7;

int n, p, m, k, t;
long long f[MAXN];

inline int add(int a, int b) {
    a += b;
    return (a >= mod) ? (a -= mod) : (a);
}

struct data {
    long long a, b;
    data() {}
    data(long long a, long long b) : a(a), b(b) {}

    inline data operator + (const data &c) const {
        return data(add(a, c.a), add(b, c.b));
    }

    inline data operator - (const data &c) const {
        return data(add(a, mod - c.a), add(b, mod - c.b));
    }

    inline data operator * (const long long &c) const {
        return data(a * c % mod, b * c % mod);
    }

    inline data operator + (const long long &c) const {
        return data(a, add(b, c));
    }

    inline data operator - (const long long &c) const {
        return data(a, add(b, mod - c));
    }
} dp[MAXN];

inline long long mod_pow(long long a, long long b) {
    long long ans = 1;
    for (; b; b >>= 1, a = a * a % mod)
        if (b & 1) ans = ans * a % mod;
    return ans;
}

inline void solve() {
    /*
    2 1 0 2
    */
    R(n), R(p), R(m), R(k);
    if (p == 0) {
        puts("0");
        return ;
    }
    if (k == 0) {
        puts("-1");
        return ;
    }
    if (m == 0) {
        if (k == 0) {
            puts("-1");
            return ;
        }
        if (k == 1) {
            if (n != 1) puts("-1");
            else puts("1");
            return ;
        }
        if (p == n) std::cout << std::ceil((double)(p - k) / (k - 1)) + 1 << '\n';
        else std::cout << std::ceil((double)(p - k + 1) / (k - 1)) + 1 << '\n';
        return ;
    }
    long long c = 1, x = 1, inv = mod_pow(m + 1, mod - 2);
    long long inv_m = mod_pow(inv * m % mod, mod - 2);
    long long y = mod_pow(inv * m % mod, k);
    for (int i = 0, end = std::min(n, k); i <= end; ++i) {
        f[i] = c * x % mod * y % mod;
        c = c * (k - i) % mod * mod_pow(i + 1, mod - 2) % mod;
        x = x * inv % mod, y = y * inv_m % mod;
    }
    long long inv0 = mod_pow(f[0], mod - 2);
    dp[0] = data(0, 0), dp[1] = data(1, 0);
    for (int i = 1; i < n; ++i) {
        dp[i + 1] = dp[i] * (m + 1) - m - 1;
        data t = data(0, 0);
        for (int j = 0, end = std::min(k, i); j <= end; ++j) 
            t = t + dp[i - j] * f[j];
        dp[i + 1] = dp[i + 1] - t * m, t = data(0, 0); 
        for (int j = 1, end = std::min(k, i + 1); j <= end; ++j)
            t = t + dp[i - j + 1] * f[j];
        dp[i + 1] = dp[i + 1] - t, dp[i + 1] = dp[i + 1] * inv0;
    }
    data ret = data(0, 1);
    for (int j = 1, end = std::min(k, n); j <= end; ++j) 
        ret = ret + dp[n - j] * f[j];
    ret = ret * mod_pow((1 - f[0] + mod) % mod, mod - 2);
    ret = ret - dp[n];
    if (ret.a == 0 && ret.b != 0) puts("-1");
    else {
        long long t = (long long)(mod - ret.b) * mod_pow(ret.a, mod - 2) % mod;
        std::cout << ((dp[p].a * t % mod + dp[p].b) % mod + mod) % mod << '\n';
    }
}

int main() {
    // freopen("heal.in", "r", stdin);
    // freopen("heal.out", "w", stdout);
    R(t);
    while (t--) solve();
    return 0;
}