题意:给你一棵n 个节点的树,定义1到n的代价是1到 n节点间的最短路径的长度。
现在给你 m 组询问,让你添加一条边权为 w 的边(不与原图重复),求代价的最大值。询问之间相互独立。
1≤n,m≤3×1e5,1<=c[i]<=1e9,1<=w<=1e9
思路:网上dalao们的写法好像都和我不太一样……
考虑将1-n路径上所有的点取出,则原树变成了一条链和若干条子树
首先判断以链上某一点为根的子树size是否>=3,若是则可以在其内部连边,对最短路没有影响
若没有则考虑在链上的点或者其延伸出的一个点(size<=2)中取某两个点上连边
预处理出mx[u]代表u除链上儿子的子树最大深度
则对于x,y(x在上y在下)两个点来说相对于原方案,新的方案增加了mx[x]+mx[y]-2*dis[y]的长度
对于固定的y只需要维护mx[x]的前缀最大值
需要注意的是不能连原树中有的边,即mx[x]=dis[x]和mx[y]=dis[y]不能同时成立,否则相当于同时取到链上相邻的两点
#include<cstdio>
#include<cstring>
#include<string>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<map>
#include<set>
#include<queue>
#include<vector>
using namespace std;
typedef long long ll;
typedef unsigned int uint;
typedef unsigned long long ull;
typedef pair<int,int> PII;
typedef vector<int> VI;
#define fi first
#define se second
#define MP make_pair
#define N 310000
#define M 51
#define MOD 1000000007
#define eps 1e-8
#define pi acos(-1)
#define oo 3e14 struct node
{
int x,cost;
node(int a,int b)
{
x=a;
cost=b;
}
}; ll d[N],mx[N],dis[N];
int flag[N],size[N],fa[N],b[N],q[N];
vector<node>c[N]; void dfs(int u)
{
flag[u]=;
for(int i=;i<=(int)c[u].size()-;i++)
{
int v=c[u][i].x;
if(!flag[v])
{
fa[v]=u;
dis[v]=dis[u]+c[u][i].cost;
dfs(v);
}
}
} void dfs2(int u)
{
flag[u]=size[u]=;
mx[u]=dis[u];
for(int i=;i<=(int)c[u].size()-;i++)
{
int v=c[u][i].x;
if(flag[v]==&&b[v]==)
{
dfs2(v);
size[u]+=size[v];
mx[u]=max(mx[u],mx[v]);
}
}
} int main()
{
int n,m;
scanf("%d%d",&n,&m);
for(int i=;i<=n;i++) c[i].clear();
for(int i=;i<=n-;i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
c[x].push_back(node(y,z));
c[y].push_back(node(x,z));
}
memset(flag,,sizeof(flag));
dis[]=;
dfs();
memset(b,,sizeof(b));
int num=;
int k=n;
while(k!=)
{
q[++num]=k;
b[k]=;
k=fa[k];
}
q[++num]=; b[]=;
memset(flag,,sizeof(flag));
for(int i=;i<=num;i++) dfs2(q[i]); int p=;
for(int i=;i<=num;i++)
{
int u=q[i];
if(size[u]>=){p=; break;}
} for(int i=;i<=num/;i++) swap(q[i],q[num-i+]);
ll len=-oo;
for(int i=;i<=num;i++)
{
int u=q[i];
if(i>=)
{
int fa=q[i-];
if(mx[u]>dis[u]) len=max(len,mx[u]-dis[u]*+d[i-]);
else
{
if(mx[fa]>dis[fa]) len=max(len,mx[u]-dis[u]*+d[i-]);
else if(i>=)
{
int x=q[i-];
len=max(len,mx[u]-dis[u]*+d[i-]);
}
}
}
if(i==) d[i]=mx[u];
else d[i]=max(d[i-],mx[u]);
} for(int i=;i<=m;i++)
{
int x;
scanf("%d",&x);
if(p){printf("%lld\n",dis[n]); continue;}
printf("%lld\n",min(dis[n],dis[n]+len+x));
}
return ;
}