NTT+多项式求逆+多项式开方(BZOJ3625)

时间:2022-11-18 15:30:47

定义多项式$h(x)$的每一项系数$h_i$,为i在c[1]~c[n]中的出现次数。

定义多项式$f(x)$的每一项系数$f_i$,为权值为i的方案数。

通过简单的分析我们可以发现:$f(x)=\frac{2}{\sqrt{1-4h(x)}+1}$

于是我们需要多项式开方和多项式求逆。

多项式求逆:

求$B(x)$,使得$A(x)*B(x)=1\;(mod\;x^m)$

考虑倍增。

假设我们已知$A(x)*B(x)=1\;(mod\;x^m)$,要求$C(x)$,使得$A(x)*C(x)=1\;(mod\;x^{2m})$

简单分析可得$C(x)=B(x)*(2-A(x)*B(x))$

多项式开方:

求$B(x)$,使得$B(x)*B(x)=A(x)\;(mod\;x^m)$

继续考虑倍增。

假设我们已知$B(x)*B(x)=A(x)\;(mod\;x^m)$,要求$C(x)$,使得$C(x)*C(x)=A(x)\;(mod\;x^{2m})$

简单分析可得$C(x)=\frac{B(x)^2+A(x)}{2B(x)}$,需要求$B(x)$的逆。

观察到以上两个式子都用到了多项式乘法,又因为在模意义下,我们需要NTT。

NTT和FFT的区别在于单位根不一样,NTT需要找一个p的原根,而且通常要求p是形如$k*2^q+1$的一个质数。

找原根暴力找就行了。

(这题卡常数,所以我都用的int)

#include <cstdio>
#include <cstring>
#include <algorithm> typedef long long ll;
const int N=,p=;
int n,m,x,ni,r[N],h[N],_h[N],__h[N],t[N];
ll pw(ll a,int b) {
ll r=;
for(;b;b>>=,a=a*a%p) if(b&) r=r*a%p;
return r;
} void ntt(int *a,int n,int f) {
for(int i=;i<n;i++) if(r[i]>i) std::swap(a[i],a[r[i]]);
for(int i=;i<=n;i<<=)
for(int j=,wn=pw(,((p-)/i*f+p-)%(p-)),m=i>>;j<n;j+=i)
for(int k=,w=;k<m;k++,w=(ll)w*wn%p) {
int x=a[j+k],y=(ll)a[j+k+m]*w%p;
a[j+k]=(x+y)%p,a[j+k+m]=(x-y+p)%p;
}
if(f==-) {
ll ni=pw(n,p-);
for(int i=;i<n;i++) a[i]=(ll)a[i]*ni%p;
}
}
void gt2(int n) {
if(n==) {__h[]=pw(_h[],p-); return;}
gt2(n>>),memcpy(t,_h,sizeof(int)*n),memset(t+n,,sizeof(int)*n);
int m=,l=-;
while(m<n<<) m<<=,l++;
for(int i=;i<m;i++) r[i]=(r[i>>]>>)|((i&)<<l);
ntt(t,m,),ntt(__h,m,);
for(int i=;i<m;i++) __h[i]=(ll)__h[i]*(-(ll)t[i]*__h[i]%p+p)%p;
ntt(__h,m,-),memset(__h+n,,sizeof(int)*n);
}
void gt(int n) {
if(n==) {_h[]=; return;}
gt(n>>),memset(__h,,sizeof(int)*n),gt2(n);
memcpy(t,h,sizeof(int)*n),memset(t+n,,sizeof(int)*n);
int l=-,m=;
while(m<n<<) m<<=,l++;
for(int i=;i<m;i++) r[i]=(r[i>>]>>)|((i&)<<l);
ntt(t,m,),ntt(_h,m,),ntt(__h,m,);
for(int i=;i<m;i++) _h[i]=((ll)_h[i]*_h[i]+t[i])%p*__h[i]%p*ni%p;
ntt(_h,m,-),memset(_h+n,,sizeof(int)*n);
} int main() {
scanf("%d%d",&m,&n),ni=pw(,p-);
for(int i=;i<=m;i++) scanf("%d",&x),h[x]++;
for(int i=;i<=n;i++) h[i]=(-h[i]*+p)%p;
for(m=n,n=;n<=m;n<<=);
h[]++,gt(n),_h[]=(_h[]+)%p,memset(__h,,sizeof __h),gt2(n);
for(int i=;i<=m;i++) printf("%d\n",__h[i]*%p);
return ;
}