题意:
给一个DAG(有向无环图) 有q次操作, 每次操作把一个点变成黑色或者变回来
(这些点初始都是白色的), 问每次操作后这个图中起点到终点有路径且这条路径
上面的点都是白色的点有多少对?
思路:
这个题也可以有两种做法,首先我们最基本的暴力去做肯定超时,复杂度O(n*m*q),
官方题解给出的做法就是,我们可以记录f[x][y]为x到y的路径条数,然后维护这个路径.
当要将v变为黑点时 f[x][y]-=f[x][v]*f[v][y],反之变白就是+
。
最后O(n2)枚举任意两点之间的路径条数,>0就是一对点.
#include<bits/stdc++.h>
using namespace std;
typedef unsigned long long LL;
const int maxn = 350;
LL dp[maxn][maxn], vis[maxn];
int N, M, Q;
void init()
{
for(int i = 1; i <= N; i++)
{
vis[i] = 0;
for(int j = 1; j <= N; j++)
dp[i][j] = 0;
}
}
int main ()
{
while(~scanf("%d %d %d", &N, &M, &Q))
{
init();
for(int i = 1; i <= M; i++)
{
int u, v;
scanf("%d %d", &u, &v);
dp[u][v]++;
}
for(int k = 1; k <= N; k++ )
for(int st = 1; st <= N; st++)
for(int ed = 1; ed <= N; ed++)
dp[st][ed] += dp[st][k] * dp[k][ed];
for(int i = 1; i <= Q; i++)
{
int u;
scanf("%d", &u);
if(vis[u] == 0)
{
vis[u] = 1;
for(int st = 1; st <=N; st++)
for(int ed = 1; ed <= N; ed++)
dp[st][ed] -= dp[st][u] * dp[u][ed];
}
else
{
vis[u] = 0;
for(int st = 1; st <=N; st++)
for(int ed = 1; ed <= N; ed++)
dp[st][ed] += dp[st][u] * dp[u][ed];
}
int ans = 0;
for(int st = 1; st <= N; st++)
{
if(vis[st]) continue;
for(int ed = 1; ed <= N; ed++)
{
if(vis[ed]) continue;
if(dp[st][ed] > 0) ans++;
}
}
printf("%d\n", ans);
}
}
return 0;
}
另一种方法就是我们要想办法优化我们的暴力,这里用到一个很巧妙的bitset,
我们假设bitset b[i][j]表示第j个点有一条到i点的边.
然后在过程中跑一遍拓扑排序,如果当前u到v有边,那么mp[v]=mp[v] | mp[u]
即将所有可以到达u的点都可以到达v.
注意这个过程中只考虑白点不考虑黑点,所以还要将黑点标记.
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod=1e9+7;
const int maxn=1e5+10;
int n,m,q;
vector<int>vt[333];
int in[333],del[333],book[333];
int mid[333];
void solve()
{
bitset<333>mp[333];
queue<int>Q;
for(int i=1;i<=n;i++)
{
if(in[i]==0)
Q.push(i);
mid[i]=in[i];
mp[i][i]=1;
}
while(!Q.empty())
{
int x=Q.front();
Q.pop();
for(int i=0;i<vt[x].size();i++)
{
int y=vt[x][i];
if(del[x]==0&&del[y]==0)
{
mp[y]=mp[x]|mp[y];
}
mid[y]--;
if(mid[y]==0)
Q.push(y);
}
}
int ans=0;
for(int i=1;i<=n;i++)
{
if(del[i]==0)
ans+=mp[i].count()-1;
}
printf("%d\n",ans);
}
int main(){
while(~scanf("%d %d %d",&n,&m,&q))
{
memset(in,0,sizeof(in));
memset(del,0,sizeof(del));
for(int i=1;i<=n;i++)
vt[i].clear();
int a,b;
for(int i=1;i<=m;i++)
{
scanf("%d %d",&a,&b);
vt[a].push_back(b);
in[b]++;
}
while(q--)
{
scanf("%d",&a);
del[a]=1-del[a];
solve();
}
}
return 0;
}