二分图最大权匹配 KM算法 板子

时间:2022-05-05 06:20:21

例题:uoj #80. 二分图最大权匹配

板子:

#include<set>
#include<map>
#include<deque>
#include<queue>
#include<stack>
#include<cmath>
#include<ctime>
#include<bitset>
#include<string>
#include<vector>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<climits>
#include<complex>
#include<iostream>
#include<algorithm>
#define ll long long
#define inf 1e12
using namespace std;

int read()
{
    char c; int x;
    while(!((c=getchar())>='0'&&c<='9'));
    x=c-'0';
    while((c=getchar())>='0'&&c<='9') (x*=10)+=c-'0';
    return x;
}
void up(ll &x,const ll &y){if(x<y)x=y;}
void down(ll &x,const ll &y){if(x>y)x=y;}

const int maxn = 510;

int fx[maxn],fy[maxn],pre[maxn];
int n1,n2,n,m;
void aug(const int &x)
{
    if(!x) return ;
    int b=pre[x],c=fx[b];
    fy[x]=b; fx[b]=x;
    aug(c);
}
ll w[maxn][maxn];
bool vx[maxn],vy[maxn];
ll dx[maxn],dy[maxn],slk[maxn];
int q[maxn],tail;
void KM(const int &s)
{
    for(int i=1;i<=n;i++) slk[i]=inf,vx[i]=vy[i]=false,pre[i]=0;
    q[tail=1]=s; int head=1;
    for(;;)
    {
        while(head<=tail)
        {
            const int &x=q[head++]; vx[x]=true;
            for(int y=1;y<=n;y++) if(!vy[y])
            {
                if(dx[x]+dy[y]==w[x][y])
                {
                    pre[y]=x; vy[y]=true;
                    if(!fy[y]) { aug(y); return ; }
                    q[++tail]=fy[y];
                }
                else if(slk[y]>dx[x]+dy[y]-w[x][y])
                {
                    slk[y]=dx[x]+dy[y]-w[x][y];
                    pre[y]=x;
                }
            }
        }
        ll d=inf;
        for(int i=1;i<=n;i++) if(!vy[i]) down(d,slk[i]);
        for(int i=1;i<=n;i++)
        {
            if(vx[i]) dx[i]-=d;
            if(vy[i]) dy[i]+=d;
            else slk[i]-=d;
        }
        for(int i=1;i<=n;i++)if(!slk[i]&&!vy[i])
        {
            vy[i]=true;
            if(!fy[i]) { aug(i); return ; }
            q[++tail]=fy[i];
        }
    }
}

int main()
{
    n1=read(); n2=read(); m=read();
    n=n1>n2?n1:n2;
    for(int i=1;i<=m;i++)
    {
        int x=read(),y=read(); ll c; scanf("%lld",&c);
        w[x][y]=c;
        up(dx[x],c);
    }
    for(int i=1;i<=n;i++) 
        if(!fx[i]) KM(i);

    ll re=0;
    for(int i=1;i<=n;i++) re+=dx[i]+dy[i];
    printf("%lld\n",re);
    for(int i=1;i<=n1;i++)
        printf("%d ",w[i][fx[i]]?fx[i]:0);

    return 0;
}