[CF833B] The Bakery

时间:2023-03-09 18:36:34
[CF833B] The Bakery

Description

将一个长度为n的序列分为k段

使得总价值最大一段区间的价值表示为区间内不同数字的个数

\(n\leq 35000,k\leq 50,1\leq a_i\leq n\)

Solution

定义 \(dp[i][j]\) 表示前 i 个里面分 j 段的最大收益

一个显然的 dp 方程是 \(dp[i][j]=\max \limits_{1\leq p<i} dp[p][j-1]+w(p+1,i)\)。复杂度 \(O(n^2k)\),GG。

考虑优化此方程,因为是取 max,容易想到放在线段树上实现。

同时定义 \(pre[a[i]]\) 表示当前 \(a[i]\) 这个元素上一次出现的位置是哪里,如果没有出现则是 0 。

难点在于 \(w\) 数组如何动态快速的求出来,我们外层循环一个 \(j\) 表示分的段数,发现如果当前扫到 i 这个位置那么 a[i] 的贡献实际上是让 \([pre[a[i]],i]\) 这段区间整体加一。可以这么理解,就是当前扫到 i,那么对于所有到 i 截至的区间 \([p,i]\),a[i] 这个元素对这些区间有贡献的部分是左端点\(\in [pre[a[i]],i]\) 里的这一段。线段树区间加就好了。也就是说,当前扫到了 i ,那么线段树的叶子节点 p 表示的就是 \(w[p,i]\) 的值,这也是我们用线段树的意义所在。这样就可以 \(O(nlogn)\) 求出 w 数组了。同时 dp 数组实时更新即可。

还有一点要注意的是方程是 \(dp[p][j-1]+w(p+1,i)\) ,也就是说能用来更新答案的是 节点 p 的 dp 值和 p+1 的累加值,有点麻烦,干脆把所有的 dp 值都往左挪一个就行了,也就是叶子节点 p 表示的实际上是 p+1 的值。感觉有点绕。。。

Code

#include<cstdio>
#include<cctype>
#include<cstring>
#define K 55
#define N 35005
#define min(A,B) ((A)<(B)?(A):(B))
#define max(A,B) ((A)>(B)?(A):(B))
#define swap(A,B) ((A)^=(B)^=(A)^=(B)) int n,k;
int f[N];
int val[N];
int pre[N];
int mx[N<<2];
int lazy[N<<2]; int getint(){
int x=0,f=0;char ch=getchar();
while(!isdigit(ch)) f|=ch=='-',ch=getchar();
while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return f?-x:x;
} void build(int cur,int l,int r){
if(l==r){
mx[cur]=f[l-1];
return;
}
int mid=l+r>>1;
build(cur<<1,l,mid);
build(cur<<1|1,mid+1,r);
mx[cur]=max(mx[cur<<1],mx[cur<<1|1]);
} void pushdown(int cur){
if(!lazy[cur]) return;
lazy[cur<<1]+=lazy[cur];
lazy[cur<<1|1]+=lazy[cur];
mx[cur<<1]+=lazy[cur];
mx[cur<<1|1]+=lazy[cur];
lazy[cur]=0;
} void modify(int cur,int l,int r,int ql,int qr){
if(!ql or !qr or ql>qr) return;
if(ql<=l and r<=qr){
mx[cur]++;
lazy[cur]++;
return;
}
pushdown(cur);
int mid=l+r>>1;
if(ql<=mid)
modify(cur<<1,l,mid,ql,qr);
if(mid<qr)
modify(cur<<1|1,mid+1,r,ql,qr);
mx[cur]=max(mx[cur<<1],mx[cur<<1|1]);
} int query(int cur,int l,int r,int ql,int qr){
if(ql<=l and r<=qr)
return mx[cur];
pushdown(cur);
int mid=l+r>>1,ans=0;
if(ql<=mid){
int p=query(cur<<1,l,mid,ql,qr);
ans=max(ans,p);
}
if(mid<qr){
int p=query(cur<<1|1,mid+1,r,ql,qr);
ans=max(ans,p);
}
return ans;
} signed main(){
n=getint(),k=getint();
for(int i=1;i<=n;i++)
val[i]=getint();
for(int j=1;j<=k;j++){
memset(mx,0,sizeof mx);
memset(pre,0,sizeof pre);
memset(lazy,0,sizeof lazy);
build(1,1,n);
for(int i=1;i<=n;i++){
modify(1,1,n,pre[val[i]]+1,i);
pre[val[i]]=i;
//if(i<j) continue;
f[i]=query(1,1,n,1,i);
//printf("j=%d,i=%d,f=%d\n",j,i,f[i]);
}
}
printf("%d\n",f[n]);
return 0;
}