CCF CSP 201703-5 引水入城 (100分)

时间:2021-05-29 18:01:24

DP 复杂度是O(N^2)的。最后一组数据T掉了,改天再改改吧。
UPD:过了最后一组,开了编译器优化,并且将数组的两个维度换了一下,利用了局部防存。

// 90 points
#pragma GCC optimize("Ofast")  
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")  
#include <cstdio>  
const int maxn=5000;  
const long long inf=0x7fffffffffff;  
int n,m,A,B,Q,X0,col[maxn][maxn],row[maxn][maxn];  
long long dp[maxn],v;  
int main()  
{  
    scanf("%d%d%d%d%d%d",&n,&m,&A,&B,&Q,&X0);  
    for (int i=0;i<n-1;++i)  
        for (int j=0;j<m;++j)  
            col[i][j]=X0=(1LL*A*X0+B)%Q;  
    for (int i=1;i<=n-2;++i)  
        for (int j=0;j<m-1;++j)  
            row[i][j]=X0=(1LL*A*X0+B)%Q;  
    for (int i=0;i<n-1;++i)  
        dp[i]=col[i][0];  
    for (int i=1;i<n-1;++i) {  
        v=dp[i-1]+row[i][0];  
        dp[i]=dp[i]<v?dp[i]:v;  
    }  
    for (int i=n-3;i>=0;--i) {  
        v=dp[i+1]+row[i+1][0];  
        dp[i]=dp[i]<v?dp[i]:v;  
    }  
    for (int i=1;i<m;++i) {  
        for (int j=0;j<n-1;++j)  
            dp[j]+=col[j][i];  
        for (int j=1;j<n-1;++j) {  
            v=dp[j-1]+row[j][i];  
            dp[j]=dp[j]<v?dp[j]:v;  
        }  
        for (int j=n-3;j>=0;--j) {  
            v=dp[j+1]+row[j+1][i];  
            dp[j]=dp[j]<v?dp[j]:v;  
        }  
    }  
    long long res=inf;  
    for (int i=0;i<n-1;++i)  
        res=res<dp[i]?res:dp[i];  
    return 0*printf("%I64d\n",res);  
}  


// 下面是 100 分的程序


// 100 points
#pragma GCC optimize("Ofast")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#include <cstdio>
const int maxn=5000;
const long long inf=0x3f3f3f3f3f3f3f3fLL;
int n,m,A,B,Q,X0,col[maxn][maxn],row[maxn][maxn];
long long dp[maxn],v;
int main()
{
    scanf("%d%d%d%d%d%d",&n,&m,&A,&B,&Q,&X0);
    for (int i=0;i<n-1;++i)
        for (int j=0;j<m;++j)
            col[j][i]=X0=((long long)A*X0+B)%Q;
    for (int i=1;i<=n-2;++i)
        for (int j=0;j<m-1;++j)
            row[j][i]=X0=((long long)A*X0+B)%Q;
    for (int i=0;i<n-1;++i)
        dp[i]=col[0][i];
    for (int i=1;i<n-1;++i) {
        v=dp[i-1]+row[0][i];
        dp[i]=dp[i]<v?dp[i]:v;
    }
    for (int i=n-3;i>=0;--i) {
        v=dp[i+1]+row[0][i+1];
        dp[i]=dp[i]<v?dp[i]:v;
    }
    for (int i=1;i<m;++i) {
        for (int j=0;j<n-1;++j)
            dp[j]+=col[i][j];
        for (int j=1;j<n-1;++j) {
            v=dp[j-1]+row[i][j];
            dp[j]=dp[j]<v?dp[j]:v;
        }
        for (int j=n-3;j>=0;--j) {
            v=dp[j+1]+row[i][j+1];
            dp[j]=dp[j]<v?dp[j]:v;
        }
    }
    long long res=inf;
    for (int i=0;i<n-1;++i)
        res=res<dp[i]?res:dp[i];
    return 0*printf("%I64d\n",res);
}