51nod_1677:treecnt

时间:2023-01-11 04:12:10

题目是求一棵n节点树中对于C(n,k)颗子树,每棵子树为在n个节点中选不同的k个节点作为树的边界点,这样的所有子树共包含多少条边。

问题可以转化一下,对每一条边,不同的子树中可能包含可能不包含这条边,显然,只有子树那k个节点在该边的两侧均有分布时该边才被包含在子树中。所有边的被包含次数的和,即为answer。对于一条边的被包含次数,设该边两侧分别有a,b个节点,那么,该边被包含的次数为C(a+b,k)-C(a,k)-C(b,k)(也可以借助母函数函数求C(a,i)*C(b,k-i),i从1到min{a,b,k-1},结果一样)。

//dfs写的太搓了,调了半天才好。。。

题目链接

 #include<bits/stdc++.h>
using namespace std; typedef long long LL;
const LL mod=1e9+;
const LL M=1e5+; LL fac[]; //阶乘
LL inv_of_fac[]; //阶乘的逆元 LL qpow(LL x,LL n)
{
LL ret=;
for(; n; n>>=)
{
if(n&) ret=ret*x%mod;
x=x*x%mod;
}
return ret;
}
void init()
{
fac[]=;
for(int i=; i<=M; i++)
fac[i]=fac[i-]*i%mod;
inv_of_fac[M]=qpow(fac[M],mod-);
for(int i=M-; i>=; i--)
inv_of_fac[i]=inv_of_fac[i+]*(i+)%mod;
}
LL C(LL a,LL b)
{
if(b>a) return ;
if(b==) return ;
return fac[a]*inv_of_fac[b]%mod*inv_of_fac[a-b]%mod;
}
/////////////////////////////////////////////////////////////
vector<int> adj[M];
int vis[M];
LL n,k,ans,du[M],hh;
void init1()
{
ans=;
memset(vis,,sizeof(vis));
memset(du,,sizeof(du));
du[]=n;
hh=C(n,k);
for(int i=; i<=n; i++)
adj[i].clear();
}
LL dfs(int s)
{
if(adj[s].size()==&&s!=) return du[s]=;
if(du[s]&&s!=) return du[s];
vis[s]=;
LL ret,cnt=;
for(int i=; i<adj[s].size(); i++)
{
if(!vis[adj[s][i]])
{
// printf("%d -> %d\n",s,adj[s][i]);
cnt+=dfs(adj[s][i]);
ans=(ans+hh-C(dfs(adj[s][i]),k)-C(n-dfs(adj[s][i]),k))%mod;
}
}
return du[s]=cnt+;
} int main()
{
init();
while(~scanf("%lld%lld",&n,&k))
{
init1();
for(int i=; i<n; i++)
{
LL u,v;
scanf("%d%d",&u,&v);
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs();
// for(int i=1; i<=n; i++)
// printf("%d:%lld=========\n",i,du[i]);
// for(int i=1; i<=n; i++)
// {
// printf("i=%d:\n",i);
// for(int j=0; j<adj[i].size(); j++)
// printf("%d ",adj[i][j]);
// puts("");
// }
printf("%lld\n",(ans+mod)%mod);
}
}

// 2017.8.15 更

回头翻一下之前自己写的博客,发现连个dfs都写这么挫,就算这样居然也有人看。重新改了一下代码贴在下面。

#include<bits/stdc++.h>
using namespace std; typedef long long LL;
const LL mod=1e9+;
const LL M=1e5+; LL fac[M+]; //阶乘
LL inv_of_fac[M+]; //阶乘的逆元 LL qpow(LL x,LL n)
{
LL ret=;
for(; n; n>>=)
{
if(n&) ret=ret*x%mod;
x=x*x%mod;
}
return ret;
}
void init()
{
fac[]=;
for(int i=; i<=M; i++)
fac[i]=fac[i-]*i%mod;
inv_of_fac[M]=qpow(fac[M],mod-);
for(int i=M-; i>=; i--)
inv_of_fac[i]=inv_of_fac[i+]*(i+)%mod;
}
LL C(LL a,LL b)
{
if(b>a) return ;
if(b==) return ;
return fac[a]*inv_of_fac[b]%mod*inv_of_fac[a-b]%mod;
}
/////////////////////////////////////////////////////////////
vector<int> adj[M];
LL n,k,ans,hh;
void init1()
{
ans=;
hh=C(n,k);
for(int i=; i<=n; i++)
adj[i].clear();
} LL dfs(int s,int pre)
{
LL ret=;
for(int i=; i<adj[s].size(); i++)
{
if(adj[s][i]==pre) continue;
LL t=dfs(adj[s][i],s);
ret+=t;
ans=(ans+hh-C(t,k)-C(n-t,k))%mod;
}
return ret;
} int main()
{
init();
while(~scanf("%lld%lld",&n,&k))
{
init1();
for(int i=; i<n; i++)
{
LL u,v;
scanf("%d%d",&u,&v);
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs(,-);
printf("%lld\n",(ans+mod)%mod);
}
}

相关文章