bzoj 2599

时间:2023-03-08 17:28:15

还是点对之间的问题,果断上点分治

同样,把一条路径拆分成经过根节点的两条路径,对不经过根节点的路径递归处理

然后,我们逐个枚举根节点的子树,计算出子树中某一点到根节点的距离,然后在之前已经处理过的点中找,看有没有距离之和等于k的,如果有就取最小值(这里用桶维护即可)

然后再把这个子树内的信息扔进桶里,计算下一棵子树即可

但是注意,在递归处理之前需要把桶清空!

然后就没啥了

不合法的情况就是无法更新出答案,输出-1即可

#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <queue>
#include <stack>
using namespace std;
const int inf=0x3f3f3f3f;
struct Edge
{
int next;
int to;
int val;
}edge[];
int head[];
int cnt=;
int n,k;
int s,rt;
int siz[];
int maxp[];
bool vis[];
int dep[];
int dis[];
int has[];
int ans=0x3f3f3f3f;
void init()
{
memset(head,-,sizeof(head));
cnt=;
}
void add(int l,int r,int w)
{
edge[cnt].next=head[l];
edge[cnt].to=r;
edge[cnt].val=w;
head[l]=cnt++;
}
void get_rt(int x,int fa)
{
siz[x]=,maxp[x]=;
for(int i=head[x];i!=-;i=edge[i].next)
{
int to=edge[i].to;
if(to==fa||vis[to])continue;
get_rt(to,x);
siz[x]+=siz[to];
maxp[x]=max(maxp[x],siz[to]);
}
maxp[x]=max(maxp[x],s-siz[x]);
if(maxp[x]<maxp[rt])rt=x;
}
void calc(int x,int fa)
{
if(dis[x]<=k)ans=min(ans,dep[x]+has[k-dis[x]]);
for(int i=head[x];i!=-;i=edge[i].next)
{
int to=edge[i].to;
if(to==fa||vis[to])continue;
dep[to]=dep[x]+,dis[to]=dis[x]+edge[i].val;
calc(to,x);
}
}
void update(int x,int fa)
{
if(dis[x]<=k)has[dis[x]]=min(has[dis[x]],dep[x]);
for(int i=head[x];i!=-;i=edge[i].next)
{
int to=edge[i].to;
if(to==fa||vis[to])continue;
update(to,x);
}
}
void erase(int x,int fa)
{
if(dis[x]<=k)has[dis[x]]=inf;
for(int i=head[x];i!=-;i=edge[i].next)
{
int to=edge[i].to;
if(to==fa||vis[to])continue;
erase(to,x);
}
}
void solve(int x)
{
vis[x]=,has[]=;
for(int i=head[x];i!=-;i=edge[i].next)
{
int to=edge[i].to;
if(vis[to])continue;
dep[to]=,dis[to]=edge[i].val;
calc(to,);
update(to,);
}
for(int i=head[x];i!=-;i=edge[i].next)
{
int to=edge[i].to;
if(vis[to])continue;
erase(to,);
}
for(int i=head[x];i!=-;i=edge[i].next)
{
int to=edge[i].to;
if(vis[to])continue;
s=siz[to],rt=,maxp[rt]=inf;
get_rt(to,);
solve(rt);
}
}
int main()
{
scanf("%d%d",&n,&k);
init();
for(int i=;i<=k;i++)has[i]=n;
for(int i=;i<n;i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
x++,y++;
add(x,y,z),add(y,x,z);
}
ans=maxp[rt]=s=n;
get_rt(,);
solve(rt);
ans=(ans==n)?-:ans;
printf("%d\n",ans);
return ;
}