D - The Bakery
这个题目好难啊,我理解了好久,都没有怎么理解好,
这种线段树优化dp,感觉还是很难的。
直接说思路吧,说不清楚就看代码吧。
这个题目转移方程还是很好写的,
dp[i][j]表示前面 i 个蛋糕 分成了 j 个数字的最大价值。
dp[i][j]=max(dp[k][j-1]+val[k+1~i])
显而易见的是,这个肯定不可以直接暴力求,所以就要用到线段树优化。
线段树怎么优化呢,
先看这个问题,给你一个点 x ,问你以这个点为右端点的所有区间有多少种数字,
这个很简单是不是,那继续问你 从x 到 x+1 这个点怎么转移?
是不是找到 last[a[x+1]] 上一次出现a[x+1] 这个数字的位置,从这个位置+1到 x+1 这个位置,所有的区间都+1
这个是不是就是线段树的更新,那么线段树的每一个位置是不是随着我们对 i 的枚举,每一个叶子节点 就是l==r==k 是不是 val[k~i]
知道这个了,回到之前的问题,我们要求val[k+1~i]+dp[k][j-1]的最大值
因为这个dp[k][j-1]上一次已经求出来了,对这一次不产生任何影响了,是一个定值。
我们就只需要求val[k+1~j]
所以可以把这两个东西一起放到线段树里面,但是一个是l==r==k这个位置,一个是k+1这个位置,所以需要val往前面挪一下,或者dp[k]往后挪一下。
我选择第一种,那么就是每次更新,就更新 last[a[x+1]] 到 x 这个位置。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <algorithm>
#include <cstdlib>
#include <vector>
#include <stack>
#include <queue>
#include <map>
#include <string>
#define inf 0x3f3f3f3f
#define inf64 0x3f3f3f3f3f3f3f3f
using namespace std;
typedef long long ll;
const int maxn = 4e4+ 10;
int maxs[maxn * 4], lazy[maxn * 4];
int dp[maxn];
void push_up(int id)
{
maxs[id] = max(maxs[id << 1], maxs[id << 1 | 1]);
} void build(int id,int l,int r)
{
lazy[id] = 0;
maxs[id] = 0;
if(l==r)
{
maxs[id] = dp[l];
return;
}
int mid = (l + r) >> 1;
build(id << 1, l, mid);
build(id << 1 | 1, mid + 1, r);
push_up(id);
} void push_down(int id)
{
if (lazy[id] == 0) return;
maxs[id << 1] += lazy[id];
maxs[id << 1 | 1] += lazy[id]; lazy[id << 1] += lazy[id];
lazy[id << 1 | 1] += lazy[id]; lazy[id] = 0;
} void update(int id,int l,int r,int x,int y,int val)
{
// printf("id=%d l=%d r=%d x=%d y=%d val=%d\n", id, l, r, x, y, val);
if(x<=l&&y>=r)
{
maxs[id] += val;
lazy[id] += val;
return;
}
push_down(id);
int mid = (l + r) >> 1;
if (x <= mid) update(id << 1, l, mid, x, y, val);
if (y > mid) update(id << 1 | 1, mid + 1, r, x, y, val);
push_up(id);
} int query(int id,int l,int r,int x,int y)
{
if (x <= l && y >= r) return maxs[id];
push_down(id);
int ans = 0, mid = (l + r) >> 1;
if (x <= mid) ans = max(ans, query(id << 1, l, mid, x, y));
if (y > mid) ans = max(ans, query(id << 1 | 1, mid + 1, r, x, y));
return ans;
}
int last[maxn];
int a[maxn];
int main()
{
int n, k;
scanf("%d%d", &n, &k);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
for(int j=1;j<=k;j++)
{
memset(last, 0, sizeof(last));
build(1, 0, n);
for(int i=1;i<=n;i++)
{
update(1, 0, n, last[a[i]], i - 1, 1);
last[a[i]] = i;
dp[i] = query(1, 0, n, 0, i - 1);
}
}
printf("%d\n", dp[n]);
return 0;
}