P1005 矩阵取数游戏[区间dp]

时间:2022-08-27 16:45:01

题目描述

帅帅经常跟同学玩一个矩阵取数游戏:对于一个给定的\(m*n\)的矩阵,矩阵中的每个元素\(a_{i,j}\)均为非负整数。游戏规则如下:

  1. 每次取数时须从每行各取走一个元素,共n个。经过m次后取完矩阵内所有元素;
  2. 每次取走的各个元素只能是该元素所在行的行首或行尾;
  3. 每次取数都有一个得分值,为每行取数的得分之和,每行取数的得分 = 被取走的元素值\(\times 2^i\)*,其中i表示第i次取数(从1开始编号);
  4. 游戏结束总得分为m次取数得分之和。

帅帅想请你帮忙写一个程序,对于任意矩阵,可以求出取数后的最大得分。

解析

除了脑残高精度(反正窝用__int128硬生生水了过去,但是考场上不能用啊),是道还行的dp题。

窝的做法比起其它题解的做法low了很多,时间和空间效率都不是十分优秀,而且也似乎有人用了,还比我快(哭。


观察题目,容易发现我们只能对每行分开进行\(dp\),而对每行的\(dp\)实际上就是一个区间\(dp\),从大区间缩小到小区间。

设\(dp[i][l][r][j]\)表示第\(i\)次取数时,第\(j\)行左边界取到\(l\),右边界取到\(r\)时的最优解。

得到状态转移方程:

\[dp[i][l][r][j]=\max\limits_{i \in [1,m],j \in [1,n]} \{dp[i-1][l-1][r][j]+num[j][l-1]*2^i,dp[i-1][l][r+1][j]+num[j][r+1]*2^i\}
\]

如果直接这么写会炸空间。

观察状态转移方程发现一个状态只与它上一个状态有关,于是考虑一个滚动数组优化。

参考代码

#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstring>
#include<ctime>
#include<cstdlib>
#include<algorithm>
#include<queue>
#include<set>
#include<map>
#define N 101
#define ll __int128
#define INF 0x7ffffffff
using namespace std;
ll dp[2][N][N][N],n,m,a[N][N];//dp[i][l][r][j]表示第i次取数第j行的最大得分,左端点l,右端点r
inline ll read()
{
int f=1,x=0;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
return x*f;
}
inline ll qp(ll a,ll b)//快速幂
{
ll ans=1;
for(;b;b>>=1){if(b&1)ans*=a;a*=a;}
return ans;
}
void print(ll x)//暴躁老哥,在线__int128
{
if(!x) return;
if(x) print(x/10);
putchar(x%10+'0');
}
int main()
{
n=read(),m=read();
for(int i=1;i<=n;++i)
for(int j=1;j<=m;++j) a[i][j]=read();
memset(dp,~0x3f,sizeof(dp));
int now=0;
for(int i=1;i<=n;++i) dp[0][1][m][i]=0;//初始化,不细讲
for(int k=1;k<=m;++k){
now^=1;
for(int i=1;i<=n;++i){
for(int l=1;l<=k+1;++l){
ll r=l+m-k-1;
dp[now][l][r][i]=max(dp[now][l][r][i],max(dp[now^1][l-1][r][i]+a[i][l-1]*qp(2,k),dp[now^1][l][r+1][i]+a[i][r+1]*qp(2,k)));
}
}
}
ll ans=0;
for(int i=1;i<=n;++i){
ll maxx=-INF;
for(int l=1;l<=m;++l)、
//寻找每一行的最优解
maxx=max(maxx,max(dp[now][l][l+1][i],dp[now][l+1][l][i]));
//最后一步会出现两种状态,都要统计
ans+=maxx;
}
if(!ans) printf("0\n");
else print(ans);
return 0;
}