Power Station POJ 4045

时间:2024-09-04 10:05:08

题意:给你一棵树,让你求一点,使该点到其余各点的距离之和最小。如果这样的点有多个,则按升序依次输出。

树型dp

#include <cstdio>
#include <cstring>
#include <vector>
#include <set>
using namespace std;
const int maxn=50010;
typedef __int64 LL;
vector<int>tree[maxn];// to save the relation
LL f[maxn],g[maxn],dp[maxn];//f[u] u as root to all his sons' distance g[u]the number of u'sons
set<int> myqueue;
void dfs(int u,int pa){
/* 以1为根,所有的子树到他们的子节点的和 */
    if(tree[u].size() == 1 && u != 1){
        g[u]=1;
        f[u]=0;
        return;
    }
    for(int v=0;v < tree[u].size();v++){
        if(tree[u][v]!=pa){
            dfs(tree[u][v],u);
            g[u]+=g[tree[u][v]];//the son's sons' number
            f[u]+=f[tree[u][v]]+g[tree[u][v]];
        }
    }
    g[u]++;// himself
}
void dfs2(int u,int pa){// to sum the way from his father
//再考虑从父节点来的
    if(tree[u].size()==1 && u!=1){
        dp[u]=dp[pa]+g[1]-(g[u]<<1);
        return;
    }
    for(int v=0;v<tree[u].size();v++){
        if(tree[u][v]!= pa){
//到某节点的距离的和=该节点子树的距离和(在dfs1中获得)+从父亲那一支子树获得的和(此时,父节点那一支看成子树)。dp[儿子]=dp[父节点]-dp[儿子]-g[儿子](儿子到父节点这条路被减了子树的子节点数的次数)    + dp[儿子]+(g[1]-g[儿子])(父树上的所有节点)
            dp[tree[u][v]]=dp[u]+g[1]-(g[tree[u][v]]<<1);
            dfs2(tree[u][v],u);//先算了之后再跑子树
        }
    }
}
int main(){
    int t;
    int n,I,R,a,b;
    scanf("%d",&t);
    LL mmin;
    while(t--){
        scanf("%d%d%d",&n,&I,&R);

        for(int i=0;i<=n;i++){
            tree[i].clear();
        }
        for(int i=2;i<=n;i++){
            scanf("%d%d",&a,&b);
            tree[a].push_back(b);
            tree[b].push_back(a);
        }
//        for(int i=0;i<=n;i++){
//            for(int j=0;j<tree[i].size();j++){
//                printf("%d ",tree[i][j]);
//            }
//            printf("\n");
//        }
        memset(f,0,sizeof(f));
        memset(g,0,sizeof(g));
        dfs(1,-1);
        dp[1]=f[1];
        dfs2(1,-1);
        mmin=dp[1];
        myqueue.clear();
        myqueue.insert(1);
        for(int i=2;i<=n;i++){
            if(dp[i]<mmin){
                mmin=dp[i];
                myqueue.clear();
                myqueue.insert(i);
            }
            if(dp[i]==mmin) myqueue.insert(i);
        }
        //warning : output long long should be I64d
        printf("%I64d\n",I*I*R*mmin);//use set not num but the op
            for(set<int>::iterator it=myqueue.begin();it!=myqueue.end();++it){
                printf("%d ",*it);
            }
        printf("\n\n");
    }
}