bzoj 5252: [2018多省省队联测]林克卡特树

时间:2024-12-25 19:07:26

Description

小L 最近沉迷于塞尔达传说:荒野之息(The Legend of Zelda: Breath of The Wild)无法自拔,他尤其喜欢游戏中的迷你挑战。

游戏中有一个叫做“LCT” 的挑战,它的规则是这样子的:现在有一个N 个点的 树(Tree),每条边有一个整数边权vi ,若vi >= 0,表示走这条边会获得vi 的收益;若vi < 0 ,则表示走这条边需要支付- vi 的过路费。小L 需要控制主角Link 切掉(Cut)树上的 恰好K 条边,然后再连接 K 条边权为 0 的边,得到一棵新的树。接着,他会选择树上的两个点p; q ,并沿着树上连接这两点的简单路径从p 走到q ,并为经过的每条边支付过路费/ 获取相应收益。

海拉鲁大陆之神TemporaryDO 想考验一下Link。他告诉Link,如果Link 能切掉 合适的边、选择合适的路径从而使 总收益 - 总过路费最大化的话,就把传说中的大师之剑送给他。

小 L 想得到大师之剑,于是他找到了你来帮忙,请你告诉他,Link 能得到的 总收益 - 总过路费最大是多少。

Solution

凸优化

发现斜率是递增的,虽然后面可能会变成负的,二分这个斜率 \(mid\)

然后相当于选择若干条点不相交的路径,每一次选择的代价为 \(mid\),求 \((K,f(K))\) 的值

我们二分到 \((K,f(K))\) 所在直线的斜率就可以推出这个点的值

求函数的最大值和 \(60\) 分 \(DP\) 一样,设 \(f[x][0/1/2]\) 表示这个点的度数为 \(0/1/2\) 时的最大值

转移时顺便维护选择的路径数量,具体转移见代码

#include<bits/stdc++.h>
using namespace std;
template<class T>void gi(T &x){
int f;char c;
for(f=1,c=getchar();c<'0'||c>'9';c=getchar())if(c=='-')f=-1;
for(x=0;c<='9'&&c>='0';c=getchar())x=x*10+(c&15);x*=f;
}
typedef long long ll;
const int N=3e5+10;const ll inf=1e15;
int n,K,head[N],nxt[N*2],to[N*2],num=0,dis[N*2];ll reb,mid;
inline void link(int x,int y,int z){
nxt[++num]=head[x];to[num]=y;head[x]=num;dis[num]=z;
nxt[++num]=head[y];to[num]=x;head[y]=num;dis[num]=z;
}
struct data{int x;ll y;}f[N][3];
inline bool operator <(const data &p,const data &q){
return p.y!=q.y?p.y<q.y:p.x>q.x;
}
inline data operator +(const data &p,const data &q){
return (data){p.x+q.x,p.y+q.y};
}
inline data operator +(const data &p,const ll q){return (data){p.x,p.y+q};}
inline void upd(data &a,data b,int c){b.x+=c;a=max(a,b);}
inline void dfs(int x,int last){
f[x][0]=(data){0,0};f[x][1]=(data){1,-mid};f[x][2]=(data){0,-inf};
for(int i=head[x],u;i;i=nxt[i]){
if((u=to[i])==last)continue;
dfs(u,x);
data w=max(f[u][0],max(f[u][1],f[u][2]));
upd(f[x][2],f[x][2]+w,0);
upd(f[x][2],f[x][1]+f[u][1]+mid+dis[i],-1); upd(f[x][1],f[x][1]+w,0);
upd(f[x][1],f[x][0]+f[u][1]+dis[i],0); upd(f[x][0],f[x][0]+w,0);
}
}
inline int solve(){
dfs(1,1);
data ans=(data){0,-inf};
for(int i=1;i<=n;i++)ans=max(ans,max(f[i][0],max(f[i][1],f[i][2])));
if(ans.x<=K)reb=ans.y;
return ans.x;
}
int main(){
freopen("pp.in","r",stdin);
freopen("pp.out","w",stdout);
int x,y,z;
ll l=0,r=0,ret=0;
cin>>n>>K;K++;
for(int i=1;i<n;i++){
gi(x);gi(y);gi(z);
link(x,y,z);l-=abs(z);r+=abs(z);
}
while(l<=r){
mid=(l+r)>>1;
if(solve()<=K)ret=mid,r=mid-1;
else l=mid+1;
}
reb+=ret*K;
cout<<reb<<endl;
return 0;
}