洛谷 3784(bzoj 4913) [SDOI2017]遗忘的集合——多项式求ln+MTT

时间:2022-11-07 04:29:24

题目:https://www.luogu.org/problemnew/show/P3784

   https://www.lydsy.com/JudgeOnline/problem.php?id=4913

和洛谷3489“付公主的背包”一样的套路。

要设 a[ i ] 表示第 i 个值有没有出现。

然后就有 \( \prod\limits_i(\frac{1}{1-x^i})^{a_i} = f(x) \)

因为有 \( \prod \) ,所以两边取 ln 。

\( \sum\limits_{i}a_{i}ln(\frac{1}{1-x^i}) = ln(f(x)) \)

现在想求一个 \( ln(\frac{1}{1-x^i}) \) 的更优美的形式(一般是形如 \( \sum \) 的),来更简单地刻画 a[ i ] 和 f[ i ] 的关系。(f[ i ] 是 ln( f(x) ) 的第 i 项系数)

因为有 \( ln \) ,所以先求导再积分来化式子。

并且 \( \frac{f'(x)}{f(x)} \) 了之后,把 \( f'(x) \) 写成 \( \sum \) 的形式,用 \( f(x) \) 和 \( \int \) 化出一个更好看的 \( \sum \) 的式子。

\( \int (1-x^i)\sum\limits_{j=1}i*j*x^{i*j-1} \) // j 从 1 开始

\( = \int \sum\limits_{j=1}i*j*x^{i*j-1} - \sum\limits_{j=1}i*j*x^{i*(j+1)-1} \)

\( = \int \sum\limits_{j=1}i*x^{i*j-1} \)

\( = \sum\limits_{j=0}\frac{1}{j}*x^{i*j} \)

所以 \( \sum\limits_{i=1}a_i\sum\limits_{j=0}\frac{1}{j}x^{i*j} = ln(f(x)) \)

  \( \sum\limits_{i=1}\sum\limits_{j=0}a_i*\frac{1}{j} = f[i*j] \)

\( f[i]=\sum\limits_{j\|i}a_j*\frac{j}{i} \)

把分母的 i 乘到左边,然后莫比乌斯反演一下就知道 \( a_i *i= \sum\limits_{j\|i}f[j]*j*u(i/j) \)

实现的时候要写 MTT 。写拆系数 FFT 的话需要 long double 。自己写的三模数 NTT 还没调出来,不知是哪里出错。

有许许多多的细节需要注意。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define db long double
#define ll long long
using namespace std;
int rdn()
{
int ret=;bool fx=;char ch=getchar();
while(ch>''||ch<''){if(ch=='-')fx=;ch=getchar();}
while(ch>=''&&ch<='')ret=ret*+ch-'',ch=getchar();
return fx?ret:-ret;
}
const int N=(<<)+;
int n,p,f[N],g[N],u[N],pri[N]; bool vis[N]; int upt(int x){if(x>=p)x-=p;if(x<)x+=p;return x;}
int pw(int x,int k)
{int ret=;while(k){if(k&)ret=(ll)ret*x%p;x=(ll)x*x%p;k>>=;}return ret;} namespace poly{
const db pi=acos(-);
struct cpl{
db x,y;
cpl(db x=,db y=):x(x),y(y) {}
cpl operator+ (const cpl &b)const
{return cpl(x+b.x,y+b.y);}
cpl operator- (const cpl &b)const
{return cpl(x-b.x,y-b.y);}
cpl operator* (const cpl &b)const
{return cpl(x*b.x-y*b.y,x*b.y+y*b.x);}
cpl operator/ (const int &b)const
{return cpl(x/b,y/b);}
};
cpl conj(cpl a){return cpl(a.x,-a.y);}
int len,r[N],inv[N]; cpl Wn[N];
int bs,pbs,bs2; cpl pa[N],pb[N],pc[N],pd[N];
int A[N],B[N],tp[N]; void init()
{
int tmp=sqrt(p);
for(bs=,pbs=;pbs<=tmp;bs++,pbs<<=);
bs2=bs<<; pbs--;
}
void fft_pre()
{
for(int i=,j=len>>;i<len;i++)
r[i]=(r[i>>]>>)+((i&)?j:);
for(int R=,m=;R<=len;m=R,R<<=)
Wn[R]=cpl( cos(pi/m),sin(pi/m) );
}
void fft(cpl *a,bool fx)
{
for(int i=;i<len;i++)
if(i<r[i])swap(a[i],a[r[i]]);
for(int R=;R<=len;R<<=)
{
cpl wn=fx?conj(Wn[R]):Wn[R];
for(int i=,m=R>>;i<len;i+=R)
{
cpl w=cpl(,);
for(int j=;j<m;j++,w=w*wn)
{
cpl x=a[i+j], y=w*a[i+m+j];
a[i+j]=x+y; a[i+m+j]=x-y;
}
}
}
if(!fx)return;
for(int i=;i<len;i++)a[i]=a[i]/len;
}
void mtt(int n1,int *a,int n2,int *b,int *c)
{
int n3=n1+n2-;
for(len=;len<n3;len<<=); fft_pre();
//for(int i=0;i<n1;i++) pa[i]=cpl(a[i]>>15,a[i]&32767);
//for(int i=0;i<n2;i++) pb[i]=cpl(b[i]>>15,b[i]&32767);
for(int i=;i<n1;i++) pa[i]=cpl(a[i]>>bs,a[i]&pbs);
for(int i=;i<n2;i++) pb[i]=cpl(b[i]>>bs,b[i]&pbs);
for(int i=n1;i<len;i++) pa[i]=cpl(,);
for(int i=n2;i<len;i++) pb[i]=cpl(,);
fft(pa,); fft(pb,);
pa[len]=pa[]; pb[len]=pb[];
for(int i=,j=len;i<len;i++,j--)//q[i]=conj(p[j])
{
cpl ta=(pa[i]+conj(pa[j]))*cpl(0.5,);//conj(*[j])!!
cpl tb=(pa[i]-conj(pa[j]))*cpl(,-0.5);
cpl tc=(pb[i]+conj(pb[j]))*cpl(0.5,);
cpl td=(pb[i]-conj(pb[j]))*cpl(,-0.5);
pc[i]=ta*tc+ta*td*cpl(,);
pd[i]=tb*tc+tb*td*cpl(,);
}
pa[]=pb[]=cpl(,);
fft(pc,); fft(pd,);
for(int i=;i<n3;i++)
{
ll ta=(ll)(pc[i].x+0.5)%p;
ll tb=(ll)(pc[i].y+0.5)%p;
ll tc=(ll)(pd[i].x+0.5)%p;
ll td=(ll)(pd[i].y+0.5)%p;
c[i]=((ta<<bs2)+((tb+tc)<<bs)+td)%p;
//c[i]=((ta<<30)+((tb+tc)<<15)+td)%p;
}
}
void get_dao(int n,int *a,int *b)
{
for(int i=;i<n;i++)b[i-]=(ll)a[i]*i%p;
b[n-]=;
}
void get_jf(int n,int *a,int *b)
{
inv[]=;
for(int i=;i<n;i++)inv[i]=(ll)(p-p/i)*inv[p%i]%p;//(p-..)!
for(int i=n-;i;i--)b[i]=(ll)a[i-]*inv[i]%p;//i-- for a==b
b[]=;
}
void get_inv(int n,int *a,int *b)
{
b[]=pw(a[],p-);
for(int l=;l<=n;l<<=)
{
for(int i=l>>;i<l;i++)b[i]=;/////
mtt(l,a,l,b,tp);
mtt(l,b,l,tp,tp);/////b*tp not a*tp
for(int i=;i<l;i++)
b[i]=((ll)b[i]*-tp[i]+p)%p;
}
}
void get_ln(int n,int *a,int *b)
{
get_dao(n,a,A); get_inv(n,a,B);
mtt(n,A,n,B,A);
get_jf(n,A,b);
}
}
void get_mu(int n)
{
int cnt=; u[]=;
for(int i=,d;i<=n;i++)
{
if(!vis[i])pri[++cnt]=i,u[i]=-;
for(int j=;j<=cnt&&(d=i*pri[j])<=n;j++)
{
vis[d]=; u[d]=-u[i];
if(i%pri[j]==){u[d]=; break;}
}
}
}
int main()
{
n=rdn();p=rdn(); poly::init();//
for(int i=;i<=n;i++)f[i]=rdn(); f[]=;//f[0]=1
int l=;for(;l<=n;l<<=);//<=n
poly::get_ln(l,f,f); get_mu(n);
for(int i=;i<=n;i++)f[i]=(ll)f[i]*i%p;
for(int i=;i<=n;i++)
for(int j=,k=i;k<=n;j++,k+=i)
g[k]=upt(g[k]+f[i]*u[j]);
int cnt=;
for(int i=;i<=n;i++)if(g[i])cnt++;
printf("%d\n",cnt);
for(int i=;i<=n;i++)if(g[i])printf("%d ",g[i]);
puts(""); return ;
}

拆系数FFT

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
int rdn()
{
int ret=;bool fx=;char ch=getchar();
while(ch>''||ch<''){if(ch=='-')fx=;ch=getchar();}
while(ch>=''&&ch<='')ret=ret*+ch-'',ch=getchar();
return fx?ret:-ret;
}
int upt(int x,int mod)
{while(x>=mod)x-=mod;while(x<)x+=mod;return x;}
int pw(int x,int k,int mod)
{int ret=;while(k){if(k&)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=;}return ret;} const int N=(<<)+; int p;
namespace poly{
const double eps=1e-;
int m[]={,,};
ll M=(ll)m[]*m[], A[N],B[N],C[][N];
int len,r[N],Wn[N][],inv[N];
int tp[N],ta[N],tb[N];
ll mul(ll a,ll b,ll mod)
{
a=(a%mod+mod)%mod; b=(b%mod+mod)%mod;/////
ll ret=(a*b- (ll)((long double)a/mod*b+eps) *mod)%mod;
if(ret<)ret+=mod; return ret;
}
void ntt_pre(int len,int mod)
{
for(int R=;R<=len;R<<=)
Wn[R][]=pw( ,(mod-)/R,mod ),
Wn[R][]=pw( ,(mod-)-(mod-)/R,mod );
}
void ntt(ll *a,bool fx,int mod)
{
for(int i=;i<len;i++)
if(i<r[i])swap(a[i],a[r[i]]);
for(int R=;R<=len;R<<=)
{
int wn=Wn[R][fx];
for(int i=,m=R>>;i<len;i+=R)
for(int j=,w=;j<m;j++,w=(ll)w*wn%mod)
{
int x=a[i+j], y=(ll)w*a[i+m+j]%mod;
a[i+j]=upt(x+y,mod); a[i+m+j]=upt(x-y,mod);
}
}
if(!fx)return; int inv=pw(len,mod-,mod);
for(int i=;i<len;i++)a[i]=(ll)a[i]*inv%mod;
}
void mtt(int n,int *a,int n2,int *b,int *c)//ok if c==a||c==b
{
for(len=;len<n+n2;len<<=); int mod;
for(int i=,j=len>>;i<len;i++)
r[i]=(r[i>>]>>)+((i&)?j:);
for(int i=;i<;i++)
{
mod=m[i];
for(int j=;j<n;j++)A[j]=a[j];
for(int j=n;j<len;j++)A[j]=;
for(int j=;j<n2;j++)B[j]=b[j];
for(int j=n2;j<len;j++)B[j]=;
ntt_pre(len,mod);
ntt(A,,mod); ntt(B,,mod);
for(int j=;j<len;j++)C[i][j]=(ll)A[j]*B[j]%mod;
ntt(C[i],,mod);
}
len=n+n2-;//n-1 + m-1 = n+m-2
mod=m[]; int tm=m[],inv=pw(tm,mod-,mod);
for(int i=;i<len;i++)
{
int tmp=(ll)upt(C[][i]-C[][i],mod)*inv%mod;
c[i]=((ll)tmp*tm+C[][i])%M;
}
mod=p; tm=m[]; inv=pw(M%tm,tm-,tm);
for(int i=;i<len;i++)
{
int tmp=mul((C[][i]-c[i])%tm+tm,inv,tm);
c[i]=(mul(tmp,M,mod)+c[i])%mod;
}
}
void get_dao(int n,int *a,int *b)
{
for(int i=;i<n;i++)b[i-]=(ll)a[i]*i%p;
b[n-]=;
}
void get_jf(int n,int *a,int *b)
{
inv[]=;
for(int i=;i<n;i++)inv[i]=(ll)(p-p/i)*inv[p%i]%p;//p/i
for(int i=n-;i;i--)b[i]=(ll)a[i-]*inv[i]%p;//i--:a==b
b[]=;
}
void get_inv(int n,int *a,int *b)//tb[]
{
b[]=pw(a[],p-,p);
for(int l=,tn=;tn<n;tn=l,l<<=)
{
for(int i=tn;i<l;i++)b[i]=;
mtt(l,a,l,b,tb);
mtt(l,b,l,tb,tb);
for(int i=;i<l;i++)
b[i]=((ll)b[i]*-tb[i]+p)%p;
}
}
void get_ln(int n,int *a,int *b)//ta[],tp[]//ok if b==a
{//%x^n
get_dao(n,a,ta); get_inv(n,a,tp);
mtt(n,ta,n,tp,ta);
get_jf(n,ta,b);
}
} int n,f[N],ans[N],mu[N],pri[N]; bool vis[N];
void get_mu(int n)
{
mu[]=; int cnt=;
for(int i=;i<=n;i++)
{
if(!vis[i])pri[++cnt]=i,mu[i]=-;
for(int j=,d;j<=cnt&&(d=i*pri[j])<=n;j++)
{
vis[d]=;
if(i%pri[j]==){mu[d]=;break;}
mu[d]=-mu[i];
}
}
}
int main()
{
n=rdn();p=rdn();
for(int i=;i<=n;i++)f[i]=rdn(); f[]=;//f[0]=1
poly::get_ln(n+,f,f);//n+1
for(int i=;i<=n;i++)f[i]=(ll)f[i]*i%p;
get_mu(n);
for(int i=;i<=n;i++)
for(int j=,k=i;k<=n;j++,k+=i)
ans[k]=upt(ans[k]+mu[j]*f[i],p);
int cnt=;
for(int i=;i<=n;i++)if(ans[i])cnt++;
printf("%d\n",cnt);
for(int i=;i<=n;i++)if(ans[i])printf("%d ",ans[i]);
puts(""); return ;
}

三模数NTT(TLE+WA)