洛谷P5206 [WC2019] 数树(生成函数+容斥+矩阵树)

时间:2023-01-21 08:26:54

题面

传送门

前置芝士

矩阵树,基本容斥原理,生成函数,多项式\(\exp\)

题解

我也想哭了……orz rqy,orz shadowice

我们设\(T1,T2\)为两棵树,并定义一个权值函数\(w(T1,T2)=y^{n-|T1\cap T2|}\),其中\(|T1\cap T2|\)为两棵树共同拥有的边的数目

显然,\(w(T1,T2)\)就是两棵树在该情况下的方案个数,因为\(T1\cap T2\)后的图中每个连通块只能用同一种颜色,而\(n-|T1\cap T2|\)就是连通块个数

子任务\(0\)就是给出\(T1,T2\),求\(w(T1,T2)\)

子任务\(1\)就是给定\(T1\),对所有的\(T2\)求和

子任务\(2\)就是对所有的\(T1,T2\)求和

子任务\(0\)

就是问有多少条边重合,直接暴力跑一下就行了

子任务\(1\)

\(y^{n-|T1\cap T2|}\)太麻烦了,把它化成\(y^{-|T1\cap T2|}\),最后再把答案乘上一个\(y^n\)就行了

我们现在已经知道了\(T1\),我们假设以下设计的所有集合都是\(T1\)的子集。设一个\(F(S)\)表示\(T1\cap T2=S\)的所有\(w(T1,T2)\)的权值之和,那么答案就是\(\sum_{S} F(S)\)。显然\(F(S)=G(S)y^{-|S|}\),其中\(G(S)\)表示\(T1\cap T2=S\)的\(T2\)的个数

然后我们再设一个\(A(S)=B(S)y^{-|S|}\),其中\(B(S)\)表示含有\(S\)这个边集的树的个数

首先我们很容易写出一个容斥式子

\[G(S)=\sum_{S\subseteq T} B(T)(-1)^{|T|-|S|}
\]

我们在两边同乘上\(y\),可以变成

\[G(S)y^{-|S|}=\sum_{S\subseteq T} B(T)y^{-|T|}(-1)^{|T|-|S|}y^{|T|-|S|}
\]

\[F(S)=\sum_{S\subseteq T} A(T)(-y)^{|T|-|S|}
\]

然后我们再来看一看答案

\[\begin{aligned}
ans
&=\sum_{S}F(S)\\
&=\sum_{S}\sum_{S\subseteq T}A(T)(-y)^{|T|-|S|}\\
&=\sum_{T}A(T)(-y)^{|T|}\sum_{S\subseteq T}(-y)^{-|S|}\\
&=\sum_{T}A(T)(-y)^{|T|}\sum_{i=0}^{|T|}{|T|\choose i}(-y)^{-i}\\
&=\sum_{T}A(T)(-y)^{|T|}\sum_{i=0}^{|T|}{|T|\choose i}(-{1\over y})^i1^{|T|-i}\\
&=\sum_{T}A(T)(-y)^{|T|}(1-{1\over y})^{|T|}\\
&=\sum_{T}C(T)({1\over y}-1)^{|T|}\\
\end{aligned}
\]

然后令\(p={1\over y}-1\),那么式子就变成了

\[ans=\sum_{T}C(T)p^{|T|}
\]

然而它的复杂度仍然是指数级的还是没用,还得继续

考虑\(C(T)\),表示至少含有\(T\)这个边集的树的个数,设这个边集有\(k\)个连通块,第\(i\)个连通块中有\(a_i\)个点,那么可以知道

\[C(T)=n^{k-2}\prod_{i=1}^k a_i
\]

那么我们来手屠矩阵吧

先来考虑一下\(C(T)\)该怎么计算。我们把所有连通块中的点缩到一起,对于两个连通块\(i,j\),在两个点之间连接\(a_ia_j\)条重边。那么这个新的图的矩阵树求出的答案就是\(C(T)\)了

那么这样建出来图的基尔霍夫矩阵应该长这样

\[\left[\begin{matrix}a_{1}(n-a_{1}) & -a_{1}a_{2} & \cdots& -a_{1}a_{k} \\ -a_{2}a_{1} & a_{2}(n-a_{2}) & \cdots & -a_{2}a_{k} \\ \vdots & \vdots & \ddots & \vdots \\ -a_{k}a_{1} & -a_{k}a_{2} & \cdots & a_{k}(n-a_{k}) \end{matrix} \right]
\]

然后我们删去一行一列之后会变成这样

\[\left[\begin{matrix}a_{1}(n-a_{1}) & -a_{1}a_{2} & \cdots& -a_{1}a_{k-1} \\ -a_{2}a_{1} & a_{2}(n-a_{2}) & \cdots & -a_{2}a_{k-1} \\ \vdots & \vdots & \ddots & \vdots \\ -a_{k-1}a_{1} & -a_{k-1}a_{2} & \cdots & a_{k-1}(n-a_{k-1}) \end{matrix} \right]
\]

接下来我们将第 \(i\) 行除去 \(a_{i}\) 这样最终的行列式需要乘上一个 \(\prod_{i=1}^{k-1}a_{i}\) ,矩阵会变成这样

\[\left[\begin{matrix}(n-a_{1}) & -a_{2} & \cdots& -a_{k-1} \\ -a_{1} & (n-a_{2}) & \cdots & -a_{k-1} \\ \vdots & \vdots & \ddots & \vdots \\ -a_{1} & -a_{2} & \cdots & (n-a_{k-1}) \end{matrix} \right]
\]

然后我们将第2列到第k-1列加到第1列上,会得到

\[\left[\begin{matrix}a_{k} & -a_{2} & \cdots& -a_{k-1} \\ a_{k} & (n-a_{2}) & \cdots & -a_{k-1} \\ \vdots & \vdots & \ddots & \vdots \\ a_{k} & -a_{2} & \cdots & (n-a_{k-1}) \end{matrix} \right]
\]

接下来我们用第1行去减其他的行,会得到

\[\left[\begin{matrix}a_{k} & -a_{2} & \cdots& -a_{k-1} \\ 0 & n & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & n \end{matrix} \right]
\]

这样矩阵就被我们削成了一个上三角阵,它的行列式是 \(n^{k-2}a_{k}\) 乘上 \(\prod_{i=1}^{k-1}a_{i}\) 就是

\[n^{k-2}\prod_{i=1}^{k}a_{i}
\]

这样我们就证明了我们的式子是正确的

那么式子可以继续展开

\[\begin{aligned}
ans
&=\sum_{T}C(T)p^{|T|}\\
&=\sum_{T}n^{n-|T|-2}p^{|T|}\prod_{i=1}^{n-|T|}a_i\\
&={p^n\over n^2}\sum_{T}\prod_{i=1}^{n-|T|}{a_in\over p}\\
\end{aligned}
\]

令\(k={n\over p}\),问题可以转化成如下形式:一个连通块权值为\(k\)乘上这个连通块的大小,一个边集\(T\)的权值等于其中所有连通块的权值之积,求所有边集的权值之和

设\(f(i,j)\)表示考虑到\(i\)这棵子树,其中\(i\)所在的连通块大小为\(j\),总的答案是多少。不过因为这里的权值并不包含\(i\)所在的连通块,所以最终的答案还要算上\(i\)的连通块的贡献,那么答案就是

\[k\sum_{1=1}^n f(1,j)j
\]

然而这样的复杂度是\(O(n^2)\),还是要\(T\),得继续优化

我们设一个生成函数

\[f_u(x)=\sum_i f(u,i)x^i
\]

发现这个生成函数有一个特点,就是\(f'_u(1)\)恰好就是\(\sum_{1=1}^n f(u,j)j\),其中前者表示对\(f_u(x)\)求导之后用\(1\)代入\(x\)。那么最后的答案就是\(kf'_1(1)\)

考虑转移,为

\[f_u(x)=f_u(x)(kf'_v(1)+f_v(x))
\]

设\(g_u=f'_u(1)\)

\[f'_u=(f_ug_v+f_ug_u)'
\]

\[f'_u=f'_ug_v+f'_uf_v+f_uf'_v
\]

\[f'_u(1)=f'_u(1)g_v+f'_u(1)f_v(1)+f_u(1)f'_v(1)
\]

\[g_u=g_ug_v+g_uf_v(1)+f_u(1)g_v
\]

如果设\(h_u=f_u(1)\),那么转移式子可以写成

\[g_u=g_ug_v+g_uh_v+h_ug_v
\]

\[h_u=h_ug_v+h_uh_v
\]

边界条件为\(h_u=1,g_u=k\)

那么就可以\(O(n)\)树形\(dp\)了

最后把输出\({g_1p^ny^n\over n^2}\)

子任务\(2\)

我们发现其实子任务\(1\)里的容斥依然可以用在这里,唯一的区别就是我们要把\(C(T)\)改成\(C^2(T)\),那么这里\(C^2(T)\)就表示交集至少为\(T\)的二元组\(T1,T2\)的对数。不过注意这里要枚举的是所有边的子集\(T\)

那么答案就变成了

\[ans={p^n\over n^4}\sum_{T}\prod_{i=1}^{n-|T|}{a_i^2n^2\over p}
\]

然后设\(k={n^2\over p}\),问题就变成了:一个大小为\(i\)的树的权值为\(ki^2\),一个森林的权值为所有树得权值之积,求所有森林的权值之和

我们先考虑树的权值,大小为\(i\)的无根树总共有\(i^{i-2}\)个,每一个贡献为\(ki^2\),那么大小为\(i\)的树的权值总和就是\(ki^i\),它的指数型生成函数就是

\[A(x)=\sum_{i=1}^\infty {ki^i\over i!}x^i
\]

对于每一个森林,都是由若干棵树构成的,且这些树之间是无序的。那么根据指数型生成函数的意义,我们可以得知森林权值的指数型生成函数为

\[B(x)=e^{A(x)}
\]

那么只要做一次多项式\(exp\)就行了,然后取第\(n\)项,把答案乘上\(\frac{p^nbas^{n}n!}{n^4}\)

以上是题解,以下是吐槽

这题知识点还真多……多项式板子打错之后只能干瞪眼果然不是瞎吹的……用了一天看懂题解,用了一天把代码调出来……这才是真正的数数题么……

//minamoto
#include<bits/stdc++.h>
#define R register
#define fp(i,a,b) for(R int i=a,I=b+1;i<I;++i)
#define fd(i,a,b) for(R int i=a,I=b-1;i>I;--i)
#define go(u) for(int i=head[u],v=e[i].v;i;i=e[i].nx,v=e[i].v)
using namespace std;
char buf[1<<21],*p1=buf,*p2=buf;
inline char getc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
int read(){
R int res,f=1;R char ch;
while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0');
return res*f;
}
const int N=1e5+5,P=998244353,Gi=332748118;
inline int add(R int x,R int y){return x+y>=P?x+y-P:x+y;}
inline int dec(R int x,R int y){return x-y<0?x-y+P:x-y;}
inline int mul(R int x,R int y){return 1ll*x*y-1ll*x*y/P*P;}
int ksm(R int x,R int y){
R int res=1;
for(;y;y>>=1,x=mul(x,x))if(y&1)res=mul(res,x);
return res;
}
int n,bas,op;
namespace solver0{
struct node{
int u,v;
node(R int uu=0,R int vv=0){u=min(uu,vv),v=max(uu,vv);}
inline bool operator <(const node &b)const{return u==b.u?v<b.v:u<b.u;}
inline bool operator ==(const node &b)const{return u==b.u&&v==b.v;}
}e[N<<1];int tot,u,v,res;
void MAIN(){
fp(i,1,n-1)u=read(),v=read(),e[++tot]=node(u,v);
fp(i,1,n-1)u=read(),v=read(),e[++tot]=node(u,v);
sort(e+1,e+1+tot),tot=unique(e+1,e+1+tot)-e-1;
res=(n-1<<1)-tot,res=n-res,printf("%d\n",ksm(bas,res));
}
}
namespace solver1{
struct eg{int v,nx;}e[N<<1];int head[N],tot;
inline void add_edge(R int u,R int v){e[++tot]={v,head[u]},head[u]=tot;}
int f[N],g[N],u,v,res,p,k;
void dfs(int u,int fa){
f[u]=1,g[u]=k;
go(u)if(v!=fa){
dfs(v,u),p=add(f[v],g[v]);
g[u]=(1ll*g[u]*p+1ll*g[v]*f[u])%P;
f[u]=mul(f[u],p);
}
}
void MAIN(){
if(bas==1)return printf("%d\n",ksm(n,n-2)),void();
res=dec(ksm(bas,P-2),1),k=mul(n,ksm(res,P-2));
fp(i,1,n-1)u=read(),v=read(),add_edge(u,v),add_edge(v,u);
dfs(1,0);
printf("%d\n",1ll*g[1]*ksm(mul(res,bas),n)%P*ksm(mul(n,n),P-2)%P);
}
}
namespace solver2{
const int N=5e5+5;
int r[N],O[N],inv[N],fac[N],ifac[N],f[N],g[N],l,lim,res;
void init(R int len){
lim=1,l=0;while(lim<len)lim<<=1,++l;
fp(i,0,lim-1)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
}
void NTT(int *A,int ty,int len=0){
fp(i,0,lim-1)if(i<r[i])swap(A[i],A[r[i]]);
for(R int mid=1;mid<lim;mid<<=1){
int I=(mid<<1),Wn=ksm(ty==1?3:Gi,(P-1)/I);O[0]=1;
fp(i,1,mid-1)O[i]=mul(O[i-1],Wn);
for(R int j=0;j<lim;j+=I)fp(k,0,mid-1){
int x=A[j+k],y=mul(O[k],A[j+k+mid]);
A[j+k]=add(x,y),A[j+k+mid]=dec(x,y);
}
}
if(ty==-1)for(R int i=0,inv=ksm(lim,P-2);i<lim;++i)A[i]=mul(A[i],inv);
}
void Inv(int *a,int *b,int len){
if(len==1)return b[0]=ksm(a[0],P-2),void();
Inv(a,b,len>>1);
static int A[N],B[N];init(len<<1);
fp(i,0,len-1)A[i]=a[i],B[i]=b[i];
fp(i,len,lim-1)A[i]=B[i]=0;
NTT(A,1),NTT(B,1);
fp(i,0,lim-1)A[i]=mul(A[i],mul(B[i],B[i]));
NTT(A,-1);
fp(i,0,len-1)b[i]=dec(add(b[i],b[i]),A[i]);
fp(i,len,lim-1)b[i]=0;
}
void Ln(int *a,int *b,int len){
static int A[N],B[N];
fp(i,1,len-1)A[i-1]=mul(a[i],i);A[len-1]=0;
Inv(a,B,len);init(len<<1);
fp(i,len,lim-1)A[i]=B[i]=0;
NTT(A,1),NTT(B,1);
fp(i,0,lim-1)A[i]=mul(A[i],B[i]);
NTT(A,-1);
fp(i,1,len-1)b[i]=mul(A[i-1],inv[i]);b[0]=0;
fp(i,len,lim-1)b[i]=0;
}
void Exp(int *a,int *b,int len){
if(len==1)return b[0]=1,void();
Exp(a,b,len>>1);
static int A[N];
Ln(b,A,len);init(len<<1);
A[0]=dec(a[0]+1,A[0]);
fp(i,1,len-1)A[i]=dec(a[i],A[i]);
fp(i,len,lim-1)A[i]=b[i]=0;
NTT(A,1),NTT(b,1);
fp(i,0,lim-1)b[i]=mul(b[i],A[i]);
NTT(b,-1);
fp(i,len,lim-1)b[i]=0;
}
void Pre(int len){
inv[0]=inv[1]=1;fp(i,2,len)inv[i]=mul(P-P/i,inv[P%i]);
fac[0]=fac[1]=ifac[0]=1;fp(i,2,len)fac[i]=mul(fac[i-1],i);
ifac[len]=ksm(fac[len],P-2);fd(i,len-1,1)ifac[i]=mul(ifac[i+1],i+1);
}
void MAIN(){
if(bas==1)return printf("%d\n",mul(ksm(n,n-2),ksm(n,n-2))),void();
int len=1;while(len<=n)len<<=1;Pre(len);
int p=dec(ksm(bas,P-2),1),k=1ll*n*n%P*ksm(p,P-2)%P;
fp(i,1,len-1)f[i]=1ll*k*ksm(i,i)%P*ifac[i]%P;
Exp(f,g,len);
res=1ll*g[n]*fac[n]%P*ksm(p,n)%P*ksm(bas,n)%P*ksm(1ll*n*n%P*n%P*n%P,P-2)%P;
printf("%d\n",res);
}
}
int main(){
// freopen("testdata.in","r",stdin);
n=read(),bas=read(),op=read();
switch(op){
case 0:solver0::MAIN();break;
case 1:solver1::MAIN();break;
case 2:solver2::MAIN();break;
}
return 0;
}