
首先枚举出所有可能成为区间最小差值的点对$(j,i)$。
枚举每个位置作为右端点$i$,假设$a[j]>a[i]$。
找到第一个这样的$j$,那么可以将下一个$a[j]$的范围缩小到$(a[i],\frac{a[i]+a[j]}{2})$。这是因为在这之外的数要么没有$j$优,要么会被$j$考虑到。
利用可持久化线段树可以很容易地找到下一个$j$的位置,最多$O(n\log n)$个点对,时间复杂度$O(n\log^2n)$。
接下来的问题等价于选择$k$条不相交线段,使得价值和最小。
将线段按左端点从小到大排序,设$f[i][j]$表示考虑前$i$条线段,选择了$j$条线段的最优价值,可以通过双指针优化到$O(kn\log n)$。
注意到$f[all][j]$是个凸函数,故可以二分斜率$mid$来切它,具体体现为每选一条线段,价值就多加$mid$。
那么随着$mid$的增大,最优解中选择的线段数目会越来越少。
二分找到最优解中线段数目最接近$k$的$mid$即可。
时间复杂度$O(n\log^2n)$。
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=50010,M=N*18;
int n,m,K,i,a[N],tot,T[N],l[M],r[M],v[M],tmp,ans,s[N],g[M*2];double L,R,MID,f[M*2];
struct E{int l,r,v;E(){}E(int _l,int _r,int _v){l=_l,r=_r,v=_v;}}e[M*2];
inline bool cmp(const E&a,const E&b){return a.l<b.l;}
inline void read(int&a){char c;while(!(((c=getchar())>='0')&&(c<='9')));a=c-'0';while(((c=getchar())>='0')&&(c<='9'))(a*=10)+=c-'0';}
int ins(int x,int a,int b,int c,int p){
int y=++tot;
v[y]=p;
if(a==b)return y;
int mid=(a+b)>>1;
if(c<=mid)l[y]=ins(l[x],a,mid,c,p),r[y]=r[x];
else l[y]=l[x],r[y]=ins(r[x],mid+1,b,c,p);
return y;
}
void ask(int x,int a,int b,int c,int d){
if(!x)return;
if(c<=a&&b<=d){
if(v[x]>tmp)tmp=v[x];
return;
}
int mid=(a+b)>>1;
if(c<=mid)ask(l[x],a,mid,c,d);
if(d>mid)ask(r[x],mid+1,b,c,d);
}
inline void findbigger(int x){
int l=a[x]+1,r=n,t=x-1;
while(l<=r&&t){
tmp=0;
ask(T[t],1,n,l,r);
if(!tmp)return;
t=tmp;
e[++m]=E(t,x,a[t]-a[x]);
r=(a[x]+a[t--]-1)>>1;
}
}
inline void findsmaller(int x){
int l=1,r=a[x]-1,t=x-1;
while(l<=r&&t){
tmp=0;
ask(T[t],1,n,l,r);
if(!tmp)return;
t=tmp;
e[++m]=E(t,x,a[x]-a[t]);
l=(a[x]+a[t--]+2)>>1;
}
}
inline void up(int&x,int y){if(f[x]>f[y])x=y;}
inline void cal(){
int i,j;
for(i=1;i<=n;i++)s[i]=0;
for(i=1,j=ans=0;i<=m;i++){
while(j+1<e[i].l){
j++;
up(s[j],s[j-1]);
}
f[i]=f[s[j]]+e[i].v+MID;
g[i]=g[s[j]]+1;
up(s[e[i].r],i);
up(ans,i);
}
}
int main(){
read(n),read(K);
for(i=1;i<=n;i++)read(a[i]),T[i]=ins(T[i-1],1,n,a[i],i);
for(i=1;i<=n;i++)findbigger(i),findsmaller(i);
sort(e+1,e+m+1,cmp);
L=-1e9,R=1e9;
for(int _=80;_;_--){
MID=(L+R)/2;
cal();
if(g[ans]==K)break;
if(g[ans]<K)R=MID;else L=MID;
}
return printf("%.0f",f[ans]-MID*K),0;
}