BZOJ 3992: [SDOI2015]序列统计 NTT+快速幂

时间:2021-12-03 08:00:07

3992: [SDOI2015]序列统计

Time Limit: 30 Sec  Memory Limit: 128 MB
Submit: 1155  Solved: 532
[Submit][Status][Discuss]

Description

小C有一个集合S,里面的元素都是小于M的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为N的数列,数列中的每个数都属于集合S。
小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:给定整数x,求所有可以生成出的,且满足数列中所有数的乘积mod M的值等于x的不同的数列的有多少个。小C认为,两个数列{Ai}和{Bi}不同,当且仅当至少存在一个整数i,满足Ai≠Bi。另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案mod 1004535809的值就可以了。

Input

一行,四个整数,N、M、x、|S|,其中|S|为集合S中元素个数。第二行,|S|个整数,表示集合S中的所有元素。

Output

一行,一个整数,表示你求出的种类数mod 1004535809的值。

Sample Input

4 3 1 2
1 2

Sample Output

8

HINT

【样例说明】
可以生成的满足要求的不同的数列有(1,1,1,1)、(1,1,2,2)、(1,2,1,2)、(1,2,2,1)、(2,1,1,2)、(2,1,2,1)、(2,2,1,1)、(2,2,2,2)。
【数据规模和约定】
对于10%的数据,1<=N<=1000;
对于30%的数据,3<=M<=100;
对于60%的数据,3<=M<=800;
对于全部的数据,1<=N<=109,3<=M<=8000,M为质数,1<=x<=M-1,输入数据保证集合S中元素不重复
 

Source

Round 1 感谢yts1999上传

想法:

设a[i]表示数字i是否属于集合S,C[i]表示数列之积%M=i的方案数。  当n=2时:
  c[(i*j)%M]=∑a[i]*a[j]
  令A[i]=a[g^i],C[i]=c[g^i]//∵g为M原根,遍历0~M-1,而将数组映射到另一个数组,并不影响答案,只要改变运算规则。
  由 c[g^(i+j)%M]=∑a[g^i]*a[g^j] 得到:
  C[(i+j)%(M-1)]=∑A[i]*A[j]//费马小定理:g^(M-1)

  ∵j+i≤m*2
  ∴每次FFT后将后面的累加到前面来就行了。
  当n=y时,C=A^y,找到g^j=x,输出C[j] 于是NTT+快速幂O(nlog^2n)
 #include<cstdio>
#define ll long long
const int MP(),lem(),g();
int n,m,x,size,gm;
int a[lem+],num,p[],tp;
struct data{int a[lem+];}A,C,B;
int power(int a,int b,int MP)
{
ll t=,y=a;b+=b<?MP-:;
for(;b;b>>=,y=(y*y)%MP)if(b&)t=(t*y)%MP;
return (int)t;
}
bool check(int y)
{
for(int j=;j<=tp;j++)
if(power(y,(m-)/p[j],m)==)return false;
return true;
}
void Get_g(int x)
{
if(!(x&))
{
p[++tp]=;
while(!(x&))x>>=;
}
for(int i=;i*i<=x;i+=)
{
if(x%i==)
{
p[++tp]=i;
while(x%i==)x/=i;
}
}
if(x>)p[++tp]=x;
for(int i=;i<=m-;i++)
if(check(i)){gm=i;break;}
}
int R[lem+],w[lem+],wn,l,il,h;
void deal()
{
l=;w[]=;
while(l<=m+m)l<<=,h++;
for(int i=;i<l;i++)R[i]=(R[i>>]>>|(i&)<<(h-));
il=power(l,MP-,MP);
}
void swap(int &a,int &b){if(a==b)return;a^=b;b^=a;a^=b;}
void NTT(int *a,int l,int ty)
{
for(int i=;i<l;i++)if(i<R[i])swap(a[i],a[R[i]]);
for(int leng=;leng<=l;leng<<=)
{
int M=leng>>;
wn=power(g,ty*(MP-)/leng,MP);
for(int i=;i<M;i++)w[i]=(1ll*w[i-]*wn)%MP;
for(int i=;i<l;i+=leng)
{
for(int j=;j<M;j++)
{
int x=a[i+j],y=(1ll*w[j]*a[i+j+M])%MP;
a[i+j]=x+y;a[i+j+M]=x-y;
a[i+j]-=a[i+j]>=MP?MP:;
a[i+j+M]+=a[i+j+M]<?MP:;
}
}
}
if(ty==-)
for(int i=;i<l;i++)a[i]=(1ll*a[i]*il)%MP;
}
void three(data &A,data &B)
{
NTT(A.a,l,);NTT(B.a,l,);
for(int i=;i<l;i++)B.a[i]=(1ll*B.a[i]*A.a[i])%MP;
NTT(B.a,l,-);
for(int i=m-;i<l;i++)B.a[i%(m-)]=(B.a[i%(m-)]+B.a[i])%MP,B.a[i]=;
}
void run()
{
C.a[]=;
while(n)
{
if(n&)
{
B=A;
three(B,C);
}
B=A;
three(B,A);
n>>=;
}
}
int main()
{
scanf("%d%d%d%d",&n,&m,&x,&size);
for(int i=;i<=size;i++){scanf("%d",&num);a[num]=;}
Get_g(m-);deal();
for(int i=,j=;i<m-;i++,j=(j*gm)%m)
{
A.a[i]=a[j];
if(j==x)num=i;
}
run();
printf("%d",C.a[num]);
return ;
}