poj 3422 Kaka's Matrix Travels 费用流

时间:2021-10-11 05:35:44

题目链接

给一个n*n的矩阵, 从左上角出发, 走到右下角, 然后在返回左上角,这样算两次。 一共重复k次, 每个格子有值, 问能够取得的最大值是多少, 一个格子的值只能取一次, 取完后变为0。

费用流第一题, 将每个格子拆为两个点, u向u'连一条容量为1, 费用为格子的值的边, u向u'再连一条容量为k-1, 费用为0的边。u'向他右边和下边的格子连一条容量为k, 费用为0的边, 跑一遍费用流就可以。

 #include <iostream>
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <map>
#include <set>
#include <string>
#include <queue>
using namespace std;
#define pb(x) push_back(x)
#define ll long long
#define mk(x, y) make_pair(x, y)
#define lson l, m, rt<<1
#define mem(a) memset(a, 0, sizeof(a))
#define rson m+1, r, rt<<1|1
#define mem1(a) memset(a, -1, sizeof(a))
#define mem2(a) memset(a, 0x3f, sizeof(a))
#define rep(i, a, n) for(int i = a; i<n; i++)
#define ull unsigned long long
typedef pair<int, int> pll;
const double PI = acos(-1.0);
const double eps = 1e-;
const int mod = 1e9+;
const int inf = ;
const int dir[][] = { {, }, {, }, {, -}, {, } };
const int maxn = 2e5+;
int num, head[maxn*], s, t, n, k, nn, dis[maxn], flow, cost, cnt, cap[maxn], q[maxn], cur[maxn], vis[maxn];
struct node
{
int to, nextt, c, w;
node(){}
node(int to, int nextt, int c, int w):to(to), nextt(nextt), c(c), w(w) {}
}e[maxn*];
int spfa() {
int st, ed;
st = ed = ;
mem2(dis);
++cnt;
dis[s] = ;
cap[s] = inf;
cur[s] = -;
q[ed++] = s;
while(st<ed) {
int u = q[st++];
vis[u] = cnt-;
for(int i = head[u]; ~i; i = e[i].nextt) {
int v = e[i].to, c = e[i].c, w = e[i].w;
if(c && dis[v]>dis[u]+w) {
dis[v] = dis[u]+w;
cap[v] = min(c, cap[u]);
cur[v] = i;
if(vis[v] != cnt) {
vis[v] = cnt;
q[ed++] = v;
}
}
}
}
if(dis[t] == inf)
return ;
cost += dis[t]*cap[t];
flow += cap[t];
for(int i = cur[t]; ~i; i = cur[e[i^].to]) {
e[i].c -= cap[t];
e[i^].c += cap[t];
}
return ;
}
int mcmf() {
flow = cost = ;
while(spfa())
;
return cost;
}
void add(int u, int v, int c, int val) {
e[num] = node(v, head[u], c, -val); head[u] = num++;
e[num] = node(u, head[v], , val); head[v] = num++;
}
void input() {
int x;
for(int i = ; i<n; i++) {
for(int j = ; j<n; j++) {
scanf("%d", &x);
add(i*n+j, i*n+j+nn, , x);
add(i*n+j, i*n+j+nn, k-, );
}
}
add(s, , k, );
add(*nn-, t, k, );
for(int i = ; i<n; i++) {
for(int j = ; j<n; j++) {
for(int k1 = ; k1<; k1++) {
int x = dir[k1][]+i;
int y = dir[k1][]+j;
if(x>=&&x<n&&y>=&&y<n) {
add(i*n+j+nn, x*n+y, k, );
}
}
}
}
}
void init() {
mem1(head);
num = cnt = ;
mem(vis);
}
int main()
{
while(~scanf("%d%d", &n, &k)) {
init();
nn = n*n;
s = *nn, t = s+;
input();
int ans = mcmf();
printf("%d\n", -ans);
}
return ;
}