HDU 2255 奔小康赚大钱 (二分图:KM算法)

时间:2022-08-03 06:27:23

题意:

中文题不解释

要点:

KM算法是求完备匹配下的最大权匹配: 在一个二分图内,左顶点为X,右顶点为Y,现对于每组左右连接X[i]Y[j]有权w[i][j],求一种匹配使得所有w[i][j]的和最大。注意完备匹配定义:|X|=|Y|=匹配数。这算法还是比较难的,证明我还是半懂不懂的,具体流程还是可以的。基本上就是利用增广路,不断修改点标,找可行边什么的。

参考博客:点击打开链接

这题就是个裸题,不过我一开始一直TLE,换了复杂度为O(N^3)的还是很多问题,后来发现模板都看错了。真挺难的这算法,什么时候证明一下,感觉没完全搞懂。


17132269 2016-05-12 15:41:17 Accepted 2255 780MS 2108K 2008 B C++ seasonal
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N = 310;
const int inf = 0x3f3f3f3f;

int visx[N], visy[N];
int lx[N], ly[N], w[N][N];
int slack[N],girl[N];
int nx,ny;

int find(int x)
{
	visx[x] = 1;
	for (int y = 1; y <= ny; y++)
	{
		if (visy[y])
			continue;
		int t = lx[x] + ly[y] - w[x][y];
		if (t==0)
		{
			visy[y] = 1;
			if (girl[y] == -1 || find(girl[y]))
			{
				girl[y] = x;
				return 1;
			}		
		}
		else if (slack[y] > t)//如果不在相等子图中就取最小的
			slack[y] = t;	
	}
	return 0;
}

int km()
{
	int i,j;
	memset(ly, 0, sizeof(ly));
	memset(girl, -1, sizeof(girl));
	for (i = 1; i <= nx; i++)
	{
		lx[i] = -inf;
		for (j = 1; j <= ny; j++)
			if (lx[i] < w[i][j])
				lx[i] = w[i][j];//lx先取连接中权值最大的
	}
	for (i = 1; i <= nx; i++)
	{
		for (j = 1; j <= ny; j++)
			slack[j] = inf;
		while (1)
		{
			memset(visx, 0, sizeof(visx));
			memset(visy, 0, sizeof(visy));
			if (find(i))				//如果成功说明增广成功,进入下一个点的增广
				break;					//若失败(没有找到增广轨),则需要改变一些点的标号,使得图中可行边的数量增加。        
			int d = inf;				
            for (j = 1; j <= ny; j++)	
				if (!visy[j] && d>slack[j])
					d = slack[j];		//取没有遍历到的Y点的(也就是连接但是相加不为权值)的slack中的最小值作为d
			for (j = 1; j <= nx; j++)
				if (visx[j])
					lx[j] -= d;			//将所有在增广轨中(就是在增广过程中遍历到)的X方点的标号全部减去一个常数d,
			for (j = 1; j <= ny; j++)
				if (visy[j])
					ly[j] += d;			//所有在增广轨中的Y方点的标号全部加上一个常数d
				else
					slack[j] -= d;		//不在交错树中的Y对应的slack也要-d
		}
	}
	int sum = 0;
	for (i = 1; i <= ny; i++)
		if (girl[i] != -1)
			sum += w[girl[i]][i];
	return sum;
}

int main()
{
	int i,j,n;
	while (~scanf("%d", &n))
	{
		nx = ny = n;
		for (i = 1; i <= n; i++)
			for (j = 1; j <= n; j++)
				scanf("%d", &w[i][j]);
		int ans = km();
		printf("%d\n", ans);
	}
	return 0;
}