斜率dp+cdq分治

时间:2023-03-10 05:35:14
斜率dp+cdq分治

写在前面

这个东西应该是一个非常重要的套路......所以我觉得必须写点什么记录一下,免得自己忘掉了

一直以来我的斜率dp都掌握的不算很好......也很少主动地在比赛里想到

写这个的契机是noi.ac在今天的考试中考了一道用这玩意儿的原题,被我搞出来了,于是决定总结一下(毕竟见得越来越多)

斜率dp

考虑一个常见的二次复杂度的dp:

$dp[i]=min(dp[j]+c(i)+g(j)+k(i)*f(j))$

其中$c,g,k,f$都是只和括号里的$i,j$有关的一元函数

一个很重要的思想是:看到n方dp的时候先想想能不能搞成这个样子的式子

如果搞出来了,这个东西一定可以在$O(n\log n)$的时间里面做出来——用cdq分治

怎么cdq

我们先给这四个函数名字:

$c(i)$是额外附加的只和$i$有关的常数

$f(i)=x(i)$作为横坐标

$g(i)=y(i)$作为纵坐标

$k(i)$是$i$这一点上的转移斜率

首先把所有点按照斜率排序

对于过程solve(l,r),这样操作:

首先,按照输入编号,把(l,r)分成两半,然后递归处理solve(l,mid)

返回的是一个按照横坐标排好序的原数组(dp值都知道了的)

我们把这一批东西做一个上凸包(或者下凸包,依照要求max还是min变化)

然后对于后面那一半点我们用前面这个凸包更新答案,一个指针遍历右边一半,另一个指针遍历左边的凸包,每次跳到最优位置为止

这之后,我们递归处理右半部分

最后我们再对这两半归并排序,按照横坐标

什么意义?

实际上这一波操作中,有三个中间被我们排了序的元素:输入编号,斜率,横坐标

实际上就是一个三维偏序:因为不像普通的斜率dp那样横纵坐标或者斜率有单调性,所以我们强行cdq

这样,在每一次更新后一半的时候,前一半都是做完的,而且已经横坐标单调了

例题:2019.3.16 problemB

朴素n方dp很好看出来,然后发现可以直接套到上面式子里面

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cassert>
#define head DEEP_DARK_FANTASY
#define ll long long
using namespace std;
inline int read(){
int re=0,flag=1;char ch=getchar();
while(!isdigit(ch)){
if(ch=='-') flag=-1;
ch=getchar();
}
while(isdigit(ch)) re=(re<<1)+(re<<3)+ch-'0',ch=getchar();
return re*flag;
}
int n;;
struct node{
ll w,h,c,x,y,k,dp,num;
}a[100010],tmp[100010],q[100010];
inline bool cmp1(node l,node r){
return l.k<r.k;
}
void solve(int l,int r){
if(l==r){
a[l].x=a[l].h;
a[l].y=a[l].dp-a[l].w+a[l].h*a[l].h;
return;
}
int mid=(l+r)>>1,tl,tr,head,tail,i; tl=tr=0;
for(i=l;i<=r;i++){
if(a[i].num<=mid) tmp[++tl]=a[i];
else q[++tr]=a[i];
}
for(i=l;i<=mid;i++) a[i]=tmp[i-l+1];
for(i=mid+1;i<=r;i++) a[i]=q[i-mid]; solve(l,mid); head=1,tail=0;
for(i=l;i<=mid;i++){
while(tail>head&&(q[tail].y-q[tail-1].y)*(a[i].x-q[tail].x)>=(q[tail].x-q[tail-1].x)*(a[i].y-q[tail].y)) tail--;
q[++tail]=a[i];
} tl=1;
for(i=mid+1;i<=r;i++){
while(tl<tail&&a[i].k*(q[tl+1].x-q[tl].x)>=(q[tl+1].y-q[tl].y)) tl++;
a[i].dp=min(a[i].dp,-q[tl].x*a[i].k+q[tl].y+a[i].c);
} solve(mid+1,r); tl=l;tr=mid+1;head=0;
while(tl<=mid&&tr<=r){
if(a[tl].x==a[tr].x) tmp[++head]=((a[tl].y>a[tr].y)?a[tr++]:a[tl++]);
else tmp[++head]=((a[tl].x>a[tr].x)?a[tr++]:a[tl++]);;
}
while(tl<=mid) tmp[++head]=a[tl++];
while(tr<=r) tmp[++head]=a[tr++];
for(i=l;i<=r;i++) a[i]=tmp[i-l+1];
}
int main(){
n=read();int i;
for(i=1;i<=n;i++){
a[i].h=read();
a[i].dp=1e18;
a[i].num=i;
}
for(i=1;i<=n;i++){
a[i].w=read();
a[i].w+=a[i-1].w;
}
for(i=1;i<=n;i++){
a[i].c=a[i].h*a[i].h+a[i-1].w;
a[i].k=2ll*a[i].h;
}
a[1].dp=0;
sort(a+1,a+n+1,cmp1);
solve(1,n);
for(i=1;i<=n;i++)
if(a[i].num==n) printf("%lld\n",a[i].dp);
}