【UOJ347】【WC2018】通道 边分治 虚树 DP

时间:2023-01-20 05:09:21

题目大意

  给你三棵树,点数都是\(n\)。求

\[\max_{i,j}d_1(i,j)+d_2(i,j)+d_3(i,j)
\]

  其中\(d_k(i,j)\)是在第\(k\)棵数中\(i,j\)两点之间的距离。

  \(n\leq 100000\)

题解

  设\(d(i,j)=d_1(i,j)+d_2(i,j)+d_3(i,j),h_k(i)\)为\(i\)号点在第\(k\)棵树上的深度

一棵树

  树形DP。

  时间复杂度:\(O(n)\)

两棵树

  这是一道集训队自选题。

  点分治+动态点分治

  设这两个点在第一棵树中的LCA是\(p\),那么\(d(i,j)=h_1(i)+h_1(j)-2h_1(p)+d_2(i,j)\)

  在第二棵树中,对于每个点\(i\),建立一个新点\(i'\),在\(i\)和\(i'\)之间连一条边权为\(h_1(i)\)的边。

  这样\(d(i,j)=d_2(i,j)-2h_1(p)\)

  我们从下往上枚举\(p\),每次查询这棵子树的点在第二棵树中的直径。

  合并直径可以直接合并两个端点。

  时间复杂度:\(O(n\log n)\)

两棵树+一条链

  考虑对链分治。

  每次只求经过当前链中间那个点(或者那条边)的答案。

  \(d(i,j)=h_1(i)+h_1(j)-2h_1(p)+d_2(i,j)+|l_i-l_j|\)

三棵树

  考虑对第三棵树进行边分治。

  先把第三棵树转成二叉树

  然后直接边分治就行了。

  因为每个点的度数\(\leq 3\),所以边分治的复杂度是对的。

  求LCA可以用dfs序+ST表。

  还要维护当前部分在第一棵树的dfs序。

  时间复杂度:\(O(n\log n)\)

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef pair<int,ll> pil;
typedef pair<ll,int> pli;
void sort(int &a,int &b)
{
if(a>b)
swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
char str[100];
sprintf(str,"%s.in",s);
freopen(str,"r",stdin);
sprintf(str,"%s.out",s);
freopen(str,"w",stdout);
#endif
}
int rd()
{
int s=0,c;
while((c=getchar())<'0'||c>'9');
do
{
s=s*10+c-'0';
}
while((c=getchar())>='0'&&c<='9');
return s;
}
void put(int x)
{
if(!x)
{
putchar('0');
return;
}
static int c[20];
int t=0;
while(x)
{
c[++t]=x%10;
x/=10;
}
while(t)
putchar(c[t--]+'0');
}
int upmin(int &a,int b)
{
if(b<a)
{
a=b;
return 1;
}
return 0;
}
int upmax(int &a,int b)
{
if(b>a)
{
a=b;
return 1;
}
return 0;
}
int n;
vector<pil> g1[400010],g2[400010],g3[400010],g4[400010];
int lastson[400010];
int f1[400010];
int f2[400010];
int f3[400010];
ll d1[400010];
ll d2[400010];
ll d3[400010];
int dep1[400010];
ll w3[400010];
int st[400010];
int ed[400010];
int st1[400010];
int ed1[400010];
pli fs[21][400010];
pii fs1[21][200010];
int lo[400010];
int ti;
int ti1;
void dfs1(int x,int fa,ll dep,int dep2)
{
f1[x]=fa;
d1[x]=dep;
dep1[x]=dep2;
fs1[0][++ti1]=pii(dep2,x);
st1[x]=ti1;
for(auto v:g1[x])
if(v.first!=fa)
{
dfs1(v.first,x,dep+v.second,dep2+1);
fs1[0][++ti1]=pii(dep2,x);
}
ed1[x]=ti1;
}
void dfs2(int x,int fa,ll dep)
{
f2[x]=fa;
d2[x]=dep;
fs[0][++ti]=pli(dep,x);
st[x]=ti;
for(auto v:g2[x])
if(v.first!=fa)
{
dfs2(v.first,x,dep+v.second);
fs[0][++ti]=pli(dep,x);
}
ed[x]=ti;
}
void dfs3(int x,int fa,ll dep)
{
f3[x]=fa;
d3[x]=dep;
for(auto v:g3[x])
if(v.first!=fa)
{
w3[v.first]=v.second;
dfs3(v.first,x,dep+v.second);
}
}
void buildst()
{
int i,j;
for(i=1;i<=20;i++)
for(j=1;j+(1<<i)-1<=ti;j++)
fs[i][j]=min(fs[i-1][j],fs[i-1][j+(1<<(i-1))]);
for(i=1;i<=20;i++)
for(j=1;j+(1<<i)-1<=ti1;j++)
fs1[i][j]=min(fs1[i-1][j],fs1[i-1][j+(1<<(i-1))]);
lo[1]=0;
for(i=2;i<=ti;i++)
lo[i]=lo[i>>1]+1;
}
int queryst1(int x,int y)
{
int t=lo[y-x+1];
return min(fs1[t][x],fs1[t][y-(1<<t)+1]).second;
}
int querylca1(int x,int y)
{
if(st1[x]>st1[y])
swap(x,y);
return queryst1(st1[x],ed1[y]);
}
int queryst(int x,int y)
{
int t=lo[y-x+1];
return min(fs[t][x],fs[t][y-(1<<t)+1]).second;
}
int querylca(int x,int y)
{
if(st[x]>st[y])
swap(x,y);
return queryst(st[x],ed[y]);
}
ll c[400010];
ll querydist(int x,int y,ll z=0)
{
if(!x&&!y)
return -1;
if(!x||!y)
return 0;
return d2[x]+d2[y]-2*d2[querylca(x,y)]+c[x-n]+c[y-n]-2*z;
}
struct graph
{
int v[400010];
int t[400010];
int b[400010];
ll w[400010];
int h[200010];
int n;
graph()
{
n=0;
}
void add(int x,int y,ll z)
{
n++;
v[n]=y;
w[n]=z;
t[n]=h[x];
h[x]=n;
}
};
graph g;
void init()
{
int i;
int x,y;
ll z;
for(i=1;i<n;i++)
{
scanf("%d%d%lld",&x,&y,&z);
g1[x].push_back(pil(y,z));
g1[y].push_back(pil(x,z));
}
for(i=1;i<n;i++)
{
scanf("%d%d%lld",&x,&y,&z);
g2[x].push_back(pil(y,z));
g2[y].push_back(pil(x,z));
}
for(i=1;i<n;i++)
{
scanf("%d%d%lld",&x,&y,&z);
g3[x].push_back(pil(y,z));
g3[y].push_back(pil(x,z));
}
dfs1(1,0,0,0);
for(i=1;i<=n;i++)
{
g2[i].push_back(pil(i+n,d1[i]));
g2[i+n].push_back(pil(i,d1[i]));
}
dfs2(1,0,0);
buildst();
dfs3(1,0,0);
for(i=1;i<=n;i++)
{
g.add(i,i+n,w3[i]);
g.add(i+n,i,w3[i]);
if(f3[i])
{
if(lastson[f3[i]])
{
g.add(i+n,lastson[f3[i]],0);
g.add(lastson[f3[i]],i+n,0);
}
else
{
g.add(i+n,f3[i],0);
g.add(f3[i],i+n,0);
}
lastson[f3[i]]=i+n;
}
}
}
int cmp1(int x,int y)
{
return st1[x]<st1[y];
}
int b[400010];
ll ans=0;
struct pp
{
int x,y;
ll v;
pp(int a=0,int b=0,ll c=-1)
{
x=a;
y=b;
v=c;
}
};
int operator >(pp a,pp b){return a.v>b.v;}
int operator <(pp a,pp b){return a.v<b.v;}
typedef pair<pp,pp> ppp;
ppp f[400010];
int x1,x2,num,sz;
ll xv;
int s[400010];
int tag[400010];
void dfs11(int x,int fa)
{
int i;
s[x]=1;
for(i=g.h[x];i;i=g.t[i])
if(!g.b[i]&&g.v[i]!=fa)
{
dfs11(g.v[i],x);
s[x]+=s[g.v[i]];
}
}
void dfs12(int x,int fa)
{
int i;
for(i=g.h[x];i;i=g.t[i])
if(!g.b[i]&&g.v[i]!=fa)
{
int mx=max(s[g.v[i]],num-s[g.v[i]]);
if(mx<sz)
{
sz=mx;
x1=x;
x2=g.v[i];
xv=g.w[i];
}
dfs12(g.v[i],x);
}
}
int op(int x)
{
return ((x-1)^1)+1;
}
void dfs13(int x,int fa,int b=1)
{
tag[x]=b;
int i;
for(i=g.h[x];i;i=g.t[i])
if(!g.b[i]&&g.v[i]!=fa)
{
int t=b;
if(g.v[i]==x2)
{
t=2;
g.b[i]=1;
g.b[op(i)]=1;
}
dfs13(g.v[i],x,t);
}
}
void dfs14(int x,int fa,ll dep)
{
c[x]=dep;
int i;
for(i=g.h[x];i;i=g.t[i])
if(!g.b[i]&&g.v[i]!=fa)
dfs14(g.v[i],x,dep+g.w[i]);
}
int sta[400010];
int top;
void updateans(pp a,pp b,ll z)
{
ans=max(ans,querydist(a.x,b.x,z));
ans=max(ans,querydist(a.x,b.y,z));
ans=max(ans,querydist(a.y,b.x,z));
ans=max(ans,querydist(a.y,b.y,z));
}
pp getmax(pp a,pp b,ll z)
{
return max(max(max(pp(a.x,a.y,querydist(a.x,a.y,z)),pp(a.x,b.x,querydist(a.x,b.x,z))),pp(a.x,b.y,querydist(a.x,b.y,z))),max(max(pp(b.x,b.y,querydist(b.x,b.y,z)),pp(a.y,b.x,querydist(a.y,b.x,z))),pp(a.y,b.y,querydist(a.y,b.y,z))));
}
void update(int x,int y)
{
updateans(f[x].first,f[y].second,d1[y]);
updateans(f[x].second,f[y].first,d1[y]);
f[y].first=getmax(f[x].first,f[y].first,d1[y]);
f[y].second=getmax(f[x].second,f[y].second,d1[y]);
}
void solve(int x,vector<int> &q)
{
if(q.empty())
return;
dfs11(x,0);
if(s[x]<=1)
return;
num=s[x];
sz=0x7fffffff;
dfs12(x,0);
dfs13(x,0);
dfs14(x1,0,0);
dfs14(x2,0,xv);
int last=0;
top=0;
int v1=q.front();
int v2=q.back();
int vlca=querylca1(v1,v2);
if(vlca!=v1)
{
sta[++top]=vlca;
f[vlca]=ppp();
}
for(auto v:q)
{
if(tag[v]==1)
f[v]=ppp(pp(v+n,0,0),pp());
else
f[v]=ppp(pp(),pp(v+n,0,0));
if(last)
{
int lca=querylca1(last,v);
while(dep1[lca]<dep1[sta[top]])
{
if(dep1[lca]<=dep1[sta[top-1]])
{
update(sta[top],sta[top-1]);
top--;
}
else
{
f[lca]=ppp();
update(sta[top],lca);
top--;
sta[++top]=lca;
}
}
}
sta[++top]=v;
last=v;
}
while(top>=2)
{
update(sta[top],sta[top-1]);
top--;
}
vector<int> q1,q2;
for(auto v:q)
if(tag[v]==1)
q1.push_back(v);
else
q2.push_back(v);
v1=x1;
v2=x2;
solve(v1,q1);
solve(v2,q2);
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("uoj347.in","r",stdin);
freopen("uoj347.out","w",stdout);
#endif
scanf("%d",&n);
init();
vector<int> ss;
int i;
for(i=1;i<=n;i++)
ss.push_back(i);
sort(ss.begin(),ss.end(),cmp1);
solve(1,ss);
printf("%lld\n",ans);
return 0;
}