51nod1812树的双直径(换根树DP)

时间:2021-02-12 18:22:04

传送门:http://www.51nod.com/Challenge/Problem.html#!#problemId=1812

题解:头一次写换根树DP。

求两条不相交的直径乘积最大,所以可以这样考虑:把一条边割掉,然后分别求两棵子树内的最长链乘起来就行了。由于负负得正,所以要再求一次最短链,就是把边权全部取负求一下就行了。然后就能通过dfs维护子树i内的答案dn[i]和不含以i为根的子树的答案up[i],dn[i]很好维护,重点是维护up[i],共5种可能:(1)从父亲的up继承过来(2)前后缀中的最大值f+出边+入边(3)父亲的g+兄弟节点中最大的f+出边(4)前驱/后继中的最大和次大(5)前驱/后继中的子树中的直径。然后转移状态就行了。

细节太多……还要__int128。为了方便,计算时答案用long long维护,乘起来再转long long……

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=4e5+;
int n,tot,hd[N],v[N<<],w[N<<],nxt[N<<],p[N<<],len[N<<];
ll f[N],g[N],pre[N],suf[N],dn[N],up[N];
__int128 ans;
void print(__int128 x){if(x>)print(x/);putchar(''+x%);}
void add(int x,int y,int z){v[++tot]=y,nxt[tot]=hd[x],hd[x]=tot,w[tot]=z;}
void dfs1(int u, int fa)
{
f[u]=dn[u]=;
for(int i=hd[u];i;i=nxt[i])
if(v[i]!=fa)
{
dfs1(v[i],u);
dn[u]=max(dn[u],f[u]+f[v[i]]+w[i]);
f[u]=max(f[u],f[v[i]]+w[i]);
dn[u]=max(dn[u],dn[v[i]]);
}
}
void dfs2(int u,int fa)
{
int cnt=;
for(int i=hd[u];i;i=nxt[i])if(v[i]!=fa)p[++cnt]=v[i],len[cnt]=w[i];
pre[]=suf[cnt+]=;
for(int i=;i<=cnt;i++)pre[i]=max(pre[i-],f[p[i]]+len[i]);
for(int i=cnt;i;i--)suf[i]=max(suf[i+],f[p[i]]+len[i]);
/*一个点向上的直径:
(1)从父亲的up继承过来
(2)前后缀中的最大值f+出边+入边
(3)父亲的g+兄弟节点中最大的f+出边
(4)前驱/后继中的最大和次大
(5)前驱/后继中的子树中的直径*/
for(int i=;i<=cnt;i++)
{
g[p[i]]=max(g[p[i]],g[u]+len[i]);
g[p[i]]=max(g[p[i]],max(pre[i-],suf[i+])+len[i]);
up[p[i]]=max(up[p[i]],up[u]);
up[p[i]]=max(up[p[i]],pre[i-]+suf[i+]);
up[p[i]]=max(up[p[i]],g[u]+max(pre[i-],suf[i+]));
}
ll mx1=-1e18,mx2=-1e18,mx=-1e18,tmp;
for(int i=;i<=cnt;i++)
{
up[p[i]]=max(up[p[i]],max(mx1+mx2,mx));
tmp=f[p[i]]+len[i];
if(tmp>mx1)mx2=mx1,mx1=tmp;else if(tmp>mx2)mx2=tmp;
mx=max(mx,dn[p[i]]);
}
mx1=mx2=mx=-1e18;
for(int i=cnt;i;i--)
{
up[p[i]]=max(up[p[i]],max(mx1+mx2,mx));
tmp=f[p[i]]+len[i];
if(tmp>mx1)mx2=mx1,mx1=tmp;else if(tmp>mx2)mx2=tmp;
mx=max(mx,dn[p[i]]);
}
for(int i=hd[u];i;i=nxt[i])if(v[i]!=fa)dfs2(v[i],u);
}
int main()
{
scanf("%d",&n);
for(int i=,x,y,z;i<n;i++)scanf("%d%d%d",&x,&y,&z),add(x,y,z),add(y,x,z);
dfs1(,),dfs2(,);
for(int i=;i<=n;i++)ans=max(ans,(__int128)dn[i]*up[i]);
for(int i=;i<=tot;i++)w[i]=-w[i];
memset(up,,sizeof up);
memset(g,,sizeof g);
dfs1(,),dfs2(,);
for(int i=;i<=n;i++)ans=max(ans,(__int128)dn[i]*up[i]);
print(ans);
}