Description
Solution
考虑用倍增来处理答案:
设 \(f[i][j]\) 表示长度恰好为 \(2^{i}\) 的哈希值为 \(j\) 的字符串的种数
\(dp[i][j]\) 表示长度小于等于 \(2^{i}\) 的哈希值为 \(j\) 的字符串的种数
容易得到转移式子:
\(f[i+1][j*base^{2^{i}}+k]=\sum f[i][j]*f[i][k]\)
\(dp[i+1][j*base^{2^{i}}+k]=dp[i][j*base^{2^{i}}+k]+\sum f[i][j]*dp[i][k]\)
发现两个转移是一个卷积的形式,\(NTT\) 优化转移即可
最后就是得到长度 \(<=n\) 的答案
可以像 \(dp\) 数组的求法一样,直接倍增求出即可
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=200005,mod=998244353;
inline int qm(int x,int k){
int sum=1;
while(k){
if(k&1)sum=1ll*sum*x%mod;
x=1ll*x*x%mod;k>>=1;
}
return sum;
}
int L,R[N],inv,n,P,D,len,st[N],top=0,ans[N];ll b[N];
inline void init(){
for(n=1;n<=(P<<1);n<<=1)L++;
for(int i=0;i<n;i++)R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
inv=qm(n,mod-2);
}
inline void NTT(int *A,int o){
for(int i=0;i<n;i++)if(i<R[i])swap(A[i],A[R[i]]);
for(int i=1;i<n;i<<=1){
int t0=qm(3,(mod-1)/(i<<1)),x,y;
for(int j=0;j<n;j+=(i<<1)){
int t=1;
for(int k=0;k<i;k++,t=1ll*t0*t%mod){
x=A[j+k];y=1ll*A[i+j+k]*t%mod;
A[j+k]=(x+y)%mod;A[j+k+i]=(x-y+mod)%mod;
}
}
}
if(o==-1)reverse(A+1,A+n);
}
inline void mul(int *A,int *B){
NTT(A,1);NTT(B,1);
for(int i=0;i<=n;i++)A[i]=1ll*A[i]*B[i]%mod;
NTT(A,-1);
}
int f[20][N],dp[20][N],A[N],B[N];
inline void Modify(int i){
for(int j=0;j<n;j++)A[j]=B[j]=0;
for(int j=0;j<P;j++)A[j*b[i]%P]=(A[j*b[i]%P]+ans[j])%mod;
for(int j=0;j<P;j++)B[j]=(B[j]+f[i][j])%mod;
mul(A,B);
for(int j=0;j<P;j++)ans[j]=dp[i][j];
for(int j=0;j<n;j++)ans[j%P]=(ans[j%P]+1ll*A[j]*inv)%mod;
}
int main(){
freopen("pp.in","r",stdin);
freopen("pp.out","w",stdout);
cin>>len>>b[0]>>P>>D;
init();
for(int i='a';i<='z';i++)dp[0][i%P]++,f[0][i%P]++;
for(int i=0;(1<<(i+1))<=len;i++){
b[i+1]=b[i]*b[i]%P;
for(int j=0;j<n;j++)A[j]=B[j]=0;
for(int j=0;j<P;j++)A[j*b[i]%P]=(A[j*b[i]%P]+f[i][j])%mod;
for(int j=0;j<P;j++)B[j]=(B[j]+f[i][j])%mod;
mul(A,B);
for(int j=0;j<n;j++)f[i+1][j%P]=(f[i+1][j%P]+1ll*A[j]*inv)%mod;
for(int j=0;j<n;j++)A[j]=B[j]=0;
for(int j=0;j<P;j++)A[j*b[i]%P]=(A[j*b[i]%P]+dp[i][j])%mod;
for(int j=0;j<P;j++)B[j]=(B[j]+f[i][j])%mod;
mul(A,B);
for(int j=0;j<n;j++)dp[i+1][j%P]=(dp[i+1][j%P]+1ll*A[j]*inv)%mod;
for(int j=0;j<P;j++)dp[i+1][j]=(dp[i+1][j]+dp[i][j])%mod;
}
for(int i=20;i>=0;i--)
if((1<<i)<=len)len-=(1<<i),st[++top]=i;
for(int i=0;i<P;i++)ans[i]=dp[st[top]][i];
while(--top)Modify(st[top]);
printf("%d\n",ans[D]);
return 0;
}