【BZOJ 4011】[HNOI2015]落忆枫音

时间:2022-12-16 16:08:41

题目描述

给出一个 n 个节点 m 条边的有向无环图,外加一条有向边 (x,y) ,求以 1 为根的生成树数量。(保证原 m 条边中不指向 1 号节点)

题目解析

GG,考试时看了一眼第一发现没有思路,于是果断暴力,开始打第二题的数据结构,谁知道我的代码那么丑,,,本来的40分只有10分,第一题暴力也gg了,后来想了很久也没想出来,只有找hcx,才发现自己智障了。
首先有向无环图的生成树数量为除根节点以外的入度乘积,由于本题已明确 1 为根,那么答案就是 Πni=2rdi 。但是这里又有一条边 (x,y) ,若加上这条边形成了环,那么方案数就可能包含环的情况,哪么怎么统计包含环的方案数呢?对于一条环路,包含他的方案数为 Πni=2rdiΠirdi 。事实上,每一个环都会通过 (x,y) ,最多只会有一条环路存在。所以我们可以从 x y 分别向正反遍历,找到同时覆盖的点,其必然在环上,对于在环上的点, dpi 表示从 y 到了 i 号节点时产生的方案数,那么有 dpi=(Σj,access(i,j)=truedpj)×inv(rdi) inv 表示逆元,那么 ans=Πni=2rdiΠni=2rdi×dpx

代码

#include<iostream>
#include<cstring>
#include<cmath>
#include<cstdlib>
#include<cstdio>
#include<queue>
using namespace std;

#define MAXN 100000
#define MAXM 200000
#define INF 0x3f3f3f3f
typedef long long int LL;
const LL MOD = 1000000007ll;

template<class T>
void Read(T &x){
    x=0;char c=getchar();bool flag=0;
    while(c<'0'||'9'<c){if(c=='-')flag=1;c=getchar();}
    while('0'<=c&&c<='9'){x=x*10+c-'0';c=getchar();}
    if(flag)x=-x;
}

int n,m,x,y;
int rd[MAXN+10];

struct node{
    int v;
    node *nxt;
}*adj1[MAXN+10],*adj2[MAXN+10],Edges[MAXM*2+100],*New=Edges;

void addedge(int u,int v){
    node *p=++New;
    p->v=v;
    p->nxt=adj1[u];
    adj1[u]=p;

    p=++New;
    p->v=u;
    p->nxt=adj2[v];
    adj2[v]=p;
}

LL ksm(LL a,LL p,LL mod){
    LL rn=1;
    while(p){
        if(p&1)rn=rn*a%mod;
        p>>=1;
        a=a*a%mod;
    }
    return rn;
}

LL getinv(LL x,LL mod){
    return ksm(x,mod-2,mod);
}

bool mark1[MAXN+10],mark2[MAXN+10];
void dfs1(int x){//正
    mark1[x]=true;
    for(node *p=adj1[x];p!=NULL;p=p->nxt)
        if(!mark1[p->v])dfs1(p->v);
}
void dfs2(int x){//反
    mark2[x]=true;
    for(node *p=adj2[x];p!=NULL;p=p->nxt)
        if(!mark2[p->v])dfs2(p->v);
}

LL dp[MAXN+10];
int tmp[MAXN+10];
queue<int>que;
LL getnum(int s,int t){
    for(int i=1;i<=n;++i)tmp[i]=rd[i];
    for(int i=1;i<=n;++i)
        if(!mark1[i]||!mark2[i]){
            for(node *p=adj1[i];p!=NULL;p=p->nxt)
                if(mark1[p->v]&&mark2[p->v])--tmp[p->v];
        }
    while(!que.empty())que.pop();

    que.push(s);
    dp[s]=1;

    int now;
    while(!que.empty()){
        now=que.front();que.pop();
        dp[now]=dp[now]*getinv(rd[now],MOD)%MOD;

        for(node *p=adj1[now];p!=NULL;p=p->nxt)
            if(mark1[p->v]&&mark2[p->v]){
                dp[p->v]=(dp[p->v]+dp[now])%MOD;
                if(!(--tmp[p->v]))que.push(p->v);
            }
    }

    return dp[t];
}

int main(){
    Read(n),Read(m),Read(x),Read(y);
    ++rd[y];

    int a,b;
    for(int i=1;i<=m;++i){
        Read(a),Read(b);
        ++rd[b];
        addedge(a,b);
    }

    LL ans=1;
    for(int i=2;i<=n;++i)
        ans=ans*rd[i]%MOD;

    if(x==1||y==1){
        printf("%lld\n",ans);
        return 0;
    }

    dfs1(y);
    dfs2(x);

    LL num=getnum(y,x);
    printf("%lld\n",((ans-ans*num%MOD)%MOD+MOD)%MOD);
}
/* 4 5 3 2 1 2 1 4 2 4 4 3 2 3 */