文章标题 coderforces 609E : Minimum spanning tree for each edge (MST+LCA)

时间:2021-11-24 20:24:41

Minimum spanning tree for each edge

题意:就是有一个无向图n个点,m条边,然后对于第i条边,然后要们求出包含第i条边是最小生成树的权值是多少,然后输出m个值。
分析:首先先求出最小生成树,然后通过这个,然后对于每一条边(u,v),只需要求出u-> v这条路径上的最大边,用这条边替换(u,v)这条边即可,问题就转化成求u到v路径上的最大边,而求最大边可以用LCA中的倍增方式求出来,在倍增的过程中保存u->lca的最大边,和v->lca的最大边。
代码:

#include<iostream>
#include<string>
#include<cstdio>
#include<cstring>
#include<vector>
#include<math.h>
#include<map>
#include<queue> 
#include<algorithm>
using namespace std;
const int inf = 0x3f3f3f3f;
typedef pair<int,int> pii;
const int maxn=2e5+10;
typedef long long ll;

int n,m;
ll ans;//最小生成树的结果

//lca
struct node {
    int to,nex,w;
}E[maxn*2];
int tot,head[maxn];
void init(){
    tot=0;
    memset (head,-1,sizeof (head));
}
void addedge(int u,int v,int w){
    E[tot]=node{v,head[u],w};
    head[u]=tot++; 
}
const int LOG=20;
int par[maxn][LOG],Maxedge[maxn][LOG],dep[maxn];
void dfs(int u,int pre,int depth){
    dep[u]=depth;
    par[u][0]=pre;
    for (int i=head[u];i!=-1;i=E[i].nex){
        int v=E[i].to;
        if (v==pre){
            Maxedge[u][0]=E[i].w;
            continue;
        }
        dfs(v,u,depth+1);
    }
}
void work(){
    dfs(1,-1,1);
    for (int i=1;i<LOG;i++){
        for (int j=1;j<=n;j++){
            par[j][i]=par[par[j][i-1]][i-1];    
            Maxedge[j][i]=max(Maxedge[j][i-1],Maxedge[par[j][i-1]][i-1]);
        }
    }
}

int lca(int u,int v){
    if (dep[u]<dep[v])swap(u,v);
    int d=dep[u]-dep[v];
    int res=0;
    for (int i=0;i<LOG;i++){
        if (d&(1<<i)){
            res=max(res,Maxedge[u][i]);
            u=par[u][i];
        }
    }
    if (u==v)return res;
    for (int i=LOG-1;i>=0;i--){
        if (par[u][i]!=par[v][i]){
            res=max(res,Maxedge[u][i]);
            res=max(res,Maxedge[v][i]);
            u=par[u][i];v=par[v][i];
        }
    }
    return max(res,max(Maxedge[u][0],Maxedge[v][0]));
}

//MST
struct Edge{
    int u,v;
    ll w;
    int id;
    bool operator <(const Edge &t)const {
        return w<t.w;
    } 
}edge[maxn]; 
int fa[maxn];
int find(int x){
    return fa[x]==x?x:fa[x]=find(fa[x]);
}
void MST(){
    ans=0;
    for (int i=1;i<=n;i++)fa[i]=i;
    sort(edge+1,edge+1+m);
    int cnt=0;
    for (int i=1;i<=m;i++){
        int u=edge[i].u,v=edge[i].v;
        ll w=edge[i].w;
        int fu=find(u),fv=find(v);
        if (fu!=fv){
            fa[fu]=fv;
            addedge(u,v,w);//加边 
            addedge(v,u,w);
            ans+=w;
            cnt++;
        }
        if (cnt==n-1)break;
    }
}

ll a[maxn];
int main ()
{
    while (scanf ("%d%d",&n,&m)!=EOF){
        init();
        for (int i=1;i<=m;i++){
            scanf ("%d%d%lld",&edge[i].u,&edge[i].v,&edge[i].w);
            edge[i].id=i;
        } 
        MST();
        work();
        for (int i=1;i<=m;i++){
            a[edge[i].id]=ans-lca(edge[i].u,edge[i].v)+edge[i].w;
        }
        for (int i=1;i<=m;i++){
            printf ("%lld\n",a[i]);
        }
    }
    return 0;
}