JZOJ5602. 【NOI2018模拟3.26】Cti

时间:2022-12-16 15:03:49

Description

有一个 n × m 的地图, 地图上的每一个位置可以是空地, 炮塔或是敌人. 你需要操纵炮塔消灭敌人.
对于每个炮塔都有一个它可以瞄准的方向, 你需要在它的瞄准方向上确定一个它的攻击位置,当然也可以不进行攻击. 一旦一个位置被攻击, 则在这个位置上的所有敌人都会被消灭.
保证对于任意一个炮塔, 它所有可能的攻击位置上不存在另外一个炮塔.
定义炮弹的运行轨迹为炮弹的起点和终点覆盖的区域. 你需要求出一种方案, 使得没有两条炮弹轨迹相交.

Input

第一行两个整数 n,m.
接下来 n 行, 每行 m 个整数, 0 表示空地, −1,−2,−3,−4 分别表示瞄准上下左右的炮塔, 正整
数 p 表示表示此位置有 p 个敌人.

Output

一行一个整数表示答案.

Sample Input

输入1:
3 2
0 9
-4 3
0 -1

输入2:
4 5
0 0 -2 0 0
-4 0 5 4 0
0 -4 3 0 6
9 0 0 -1 0

Sample Output

输出1:
9

输出2:
12

Data Constraint

对于前 20% 的数据, n,m ≤ 5;
对于另 20% 的数据, 朝向上下的炮塔至多有 2 个;
对于另 20% 的数据, 至多有 6 个炮塔;
对于 100% 的数据, 1 ≤ n,m ≤ 50, 每个位置的敌人数量 < 1000.

题解

先不考虑相交的限制,
很显然贪心最优解,
每个炮塔都选择对自己方向上的最多敌人处进行攻击。

现在考虑相交的情况,
先将每一个点拆成两个点,
分别表示竖直方向与水平方向,
中间的边权为正无穷。

超级源向所有竖直方向的炮塔连一条正无穷的边,
同理,所有水平方向的炮塔都向超级汇来一条正无穷的边。

如果超级源与超级汇连通,就是说出现了相交的情况。
割掉一条边的意义就是将原本某个炮塔攻击最大值的炮塔的攻击目标改为割掉这条边的位置,
权值就是最大值减去当前点的权值。

最后的答案就是所有最大值的和减去最小割。

code

#include <queue>
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <string.h>
#include <cmath>
#include <math.h>
#include <time.h>
#define ll long long
#define N 100003
#define M 53
#define db double
#define P putchar
#define G getchar
#define inf 998244353
using namespace std;
char ch;
void read(int &n)
{
    n=0;
    ch=G();
    while((ch<'0' || ch>'9') && ch!='-')ch=G();
    ll w=1;
    if(ch=='-')w=-1,ch=G();
    while('0'<=ch && ch<='9')n=(n<<3)+(n<<1)+ch-'0',ch=G();
    n*=w;
}

int max(int a,int b){return a>b?a:b;}
int min(int a,int b){return a<b?a:b;}
ll abs(ll x){return x<0?-x:x;}
ll sqr(ll x){return x*x;}
void write(ll x){if(x>9) write(x/10);P(x%10+'0');}

int nxt[N*2],to[N*2],v[N*2],last[N],cur[N],tot;
int q[N],h[N],S,T,ans;
int n,m,t,sum,a[M][M],mx,pos,x;

bool bfs()
{
    int head=0,tail=1;
    for(int i=0;i<=T;i++)h[i]=-1;
    q[0]=S;h[S]=0;
    while(head!=tail)
    {
        int now=q[head];head++;
        for(int i=last[now];i;i=nxt[i])
            if(v[i] && h[to[i]]==-1)
            {
                h[to[i]]=h[now]+1;
                q[tail++]=to[i];
            }
    }
    return h[T]!=-1;
}

int dfs(int x,int f)
{
    if(x==T)return f;
    int w,used=0;
    for(int i=cur[x];i;i=nxt[i])
        if(h[to[i]]==h[x]+1)
        {
            w=f-used;
            w=dfs(to[i],min(w,v[i]));
            v[i]-=w;v[i^1]+=w;
            if(v[i])cur[x]=i;
            used+=w;
            if(used==f)return f;
        }
    if(!used)h[x]=-1;
    return used;
}

void dinic()
{
    while(bfs())
    {
        for(int i=0;i<=T;i++)
            cur[i]=last[i];
        ans+=dfs(S,inf);
    }
}

void lb(int x,int y,int z)
{
    nxt[++tot]=last[x];
    to[tot]=y;
    v[tot]=z;
    last[x]=tot;
}

void ins(int x,int y,int z)
{
    lb(x,y,z);
    lb(y,x,0);
}

int get(int x,int y)
{
    return x*m-m+y;
}

int main()
{
    freopen("cti.in","r",stdin);
    freopen("cti.out","w",stdout);

    read(n);read(m);
    S=2*n*m+1;T=S+1;tot=1;

    for(int i=1;i<=n;i++)
        for(int j=1;j<=m;j++)
            read(a[i][j]),x=get(i,j),ins(x,n*m+x,inf);

    for(int i=1;i<=n;i++)
        for(int j=1;j<=m;j++)
            if(a[i][j]==-1)     
            {
                ins(S,get(i,j),inf);
                mx=0;
                for(int p=i-1;p;p--)
                {
                    if(a[p][j]<0)break;
                    if(a[p][j]>=mx)mx=a[p][j],pos=p;
                }
                for(int p=i-1;p>=pos;p--)
                    ins(get(p+1,j),get(p,j),mx-max(a[p+1][j],0));
                sum+=mx;
            }
            else
            if(a[i][j]==-2)     
            {
                ins(S,get(i,j),inf);
                mx=0;
                for(int p=i+1;p<=n;p++)
                {
                    if(a[p][j]<0)break;
                    if(a[p][j]>=mx)mx=a[p][j],pos=p;
                }
                for(int p=i+1;p<=pos;p++)
                    ins(get(p-1,j),get(p,j),mx-max(a[p-1][j],0));
                sum+=mx;
            }
            else
            if(a[i][j]==-3)     
            {
                ins(get(i,j)+n*m,T,inf);
                mx=0;
                for(int p=j-1;p;p--)
                {
                    if(a[i][p]<0)break;
                    if(a[i][p]>=mx)mx=a[i][p],pos=p;
                }
                for(int p=j-1;p>=pos;p--)
                    ins(get(i,p)+n*m,get(i,p+1)+n*m,mx-max(a[i][p+1],0));
                sum+=mx;
            }
            else
            if(a[i][j]==-4)     
            {
                ins(get(i,j)+n*m,T,inf);
                mx=0;
                for(int p=j+1;p<=m;p++)
                {
                    if(a[i][p]<0)break;
                    if(a[i][p]>=mx)mx=a[i][p],pos=p;
                }
                for(int p=j+1;p<=pos;p++)
                    ins(get(i,p)+n*m,get(i,p-1)+n*m,mx-max(a[i][p-1],0));
                sum+=mx;
            }

    dinic();
    write(sum-ans);

    return 0;
}