poj 4045 (树形DP)

时间:2022-01-01 18:24:20

先选一点为根节点找出所有父节点i到下面所有点距离和dp[i],该父节点下面有多少个点Node[i]。

然后求出所有节点的所有非子节点到该点的距离dp1[v]+=(dp1[u]+(dp[u]-dp[v]-Node[v]-1)+n-Node[v]-1)

dp[u]-dp[v]-Node[v]-1:u的子节点中除了v这一部分子节点到u的距离

n-Node[v]-1:非v的字节点的个数

#include<stdio.h>
#include<string.h>
#define N 50002
#define inf 0x3fffffff
int head[N],num,vis[N],dp[N],Node[N],dp1[N],n,I,R;
struct edge
{
int st,ed,next;
}E[N*2];
void addedge(int x,int y)
{
E[num].st=x;
E[num].ed=y;
E[num].next=head[x];
head[x]=num++;
}
void dfs(int u)
{
vis[u]=1;
int i,v;
for(i=head[u];i!=-1;i=E[i].next)
{
v=E[i].ed;
if(vis[v]==1)continue;
dfs(v);
dp[u]+=(dp[v]+Node[v]+1);//所有子节点到到父节点的距离
Node[u]+=(Node[v]+1);//子节点个数
}
}
long long mm;
void dfs1(int u)
{
int i,v;
vis[u]=1;
for(i=head[u];i!=-1;i=E[i].next)
{
v=E[i].ed;
if(vis[v]==1)continue;
dp1[v]+=(dp1[u]+(dp[u]-dp[v]-Node[v]-1)+n-Node[v]-1);//除了子节点外所有节点到该点的距离
dfs1(v);
}
if(mm>dp[u]+dp1[u])
mm=dp[u]+dp1[u];
}
int main()
{
int i,x,y,t;
scanf("%d",&t);
while(t--)
{
scanf("%d%d%d",&n,&I,&R);
memset(head,-1,sizeof(head));
num=0;
for(i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
addedge(x,y);
addedge(y,x);
}
memset(dp,0,sizeof(dp));
memset(dp1,0,sizeof(dp1));
memset(Node,0,sizeof(Node));
memset(vis,0,sizeof(vis));
mm=inf;
dfs(1);
memset(vis,0,sizeof(vis));
dfs1(1);
printf("%lld\n",I*I*R*mm);
for(i=1;i<=n;i++)
{
if(dp[i]+dp1[i]==mm)
printf("%d ",i);
}
printf("\n\n");
}
return 0;
}