http://acm.hdu.edu.cn/showproblem.php?pid=4635
我们把缩点后的新图(实际编码中可以不建新图 只是为了概念上好理解)中的每一个点都赋一个值
表示是由多少个点缩成的 我们需要找所有端点 也可能出发点(只有出度) 也可能是结束点 (只有入度)
这个端点和外界(其它所有点)的联通性是单向的(只入或只出) 也可能没有联通
在保持这个端点与外界的单向联通性的情况下 任意加边
所以 当端点的值越小(包含点越少) 结果越优
代码:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<string>
#include<cstring>
#include<cmath>
#include<set>
#include<vector>
#include<list>
#include<stack>
#include<queue>
using namespace std; typedef pair<int,int> pp;
typedef long long ll;
const int N=100005;
const int M=100005;
int head[N],I;
struct node
{
int j,next;
}edge[M];
int low[N],dfn[N],f[N],deep;
bool in[N],visited[N];
stack<int>st;
pp p[M];
void add(int i,int j)
{
edge[I].j=j;
edge[I].next=head[i];
head[i]=I++;
}
bool ok(vector<int>& vt)
{
for(unsigned int i=0;i<vt.size();++i)
{
int x=vt[i];
for(int t=head[x];t!=-1;t=edge[t].next)
{
int y=edge[t].j;
if(f[x]!=f[y])
return false;
}
}
return true;
}
void tarjan(int x,int &M)
{
visited[x]=true;
in[x]=true;
st.push(x);
low[x]=dfn[x]=deep++;
for(int t=head[x];t!=-1;t=edge[t].next)
{
int j=edge[t].j;
if(visited[j]==false)
{
tarjan(j,M);
low[x]=min(low[x],low[j]); }else if(in[j]==true)
{
low[x]=min(low[x],dfn[j]);
}
}
if(low[x]==dfn[x])
{
vector<int>vt; int tmp=1;
while(st.top()!=x)
{
int k=st.top(); st.pop();
vt.push_back(k);
in[k]=false;
f[k]=x;
++tmp;
} int k=st.top(); st.pop();
vt.push_back(k);
in[k]=false;
f[k]=x;
if(ok(vt))
{
M=min(M,tmp);
}
}
}
void init(int n,int m)
{
memset(head,-1,sizeof(head));
I=0;
for(int i=0;i<m;++i)
add(p[i].first,p[i].second);
}
int solve(int n,int m)
{
init(n,m);
while(!st.empty()) st.pop();
for(int i=1;i<=n;++i)
{f[i]=i;}
memset(in,false,sizeof(in));
memset(visited,false,sizeof(visited));
deep=0;
int k=n+1;
for(int i=1;i<=n;++i)
if(!visited[i])
tarjan(i,k);
return k;
}
int main()
{
//freopen("data.in","r",stdin);
int T;
scanf("%d",&T);
for(int ca=1;ca<=T;++ca)
{
printf("Case %d: ",ca);
int n,m;
scanf("%d %d",&n,&m);
for(int i=0;i<m;++i)
scanf("%d %d",&p[i].first,&p[i].second);
int k=solve(n,m);
for(int i=0;i<m;++i)
swap(p[i].first,p[i].second);
k=min(solve(n,m),k);
if(k==n)
{cout<<"-1"<<endl;continue;}
ll ans=0;
ans=(ll)(n)*(ll)(n-1);
ans-=m;
ans-=(ll)(k)*(ll)(n-k);
cout<<ans<<endl;
}
return 0;
}