3992: [SDOI2015]序列统计

时间:2024-11-09 17:36:38

3992: [SDOI2015]序列统计

链接

分析:

  给定一个集和s,求多少个长度为n的序列,满足序列中每个数都属于s,并且所有数的乘积模m等于x。

  设$f=\sum\limits_{i=0}^{n - 1} a_i x ^ i \ \ 如果集合中存在i,a_i = 1$

  那么答案的生成函数为f自乘n次,这里可以快速幂。这里"乘法"定义是:设多项式a乘多项式b等于c,$\sum\limits_{k=0}^{n - 1} c_k = \sum\limits_{i \times j = k} a_i \times b_j$ 每次“乘法”的复杂度是$m^2$,所以复杂度是$O(m^2logn)$。

  考虑优化“乘法”的部分,我们知道多项式乘法利用FFT/NTT可以做到$nlogn$的,看能否转化为多项式乘法,即多项式乘法的定义变为$\sum\limits_{k=0}^{n - 1} c_k = \sum\limits_{i + j = k} a_i \times b_j$。

  NTT中,有引入原根的概念,在NTT中,原根的用途相当于单位根。 原根有一个性质:对于mod p下的原根g,$g^1, g^2 \dots g^{p - 1}$互不相同,$g^{p - 1} \equiv 1 \mod p$。而且$g^1, g^2 \dots g^{p - 1}$可以分别表示$1,2 \dots p - 1$。

  那么我们对m求出单位根,集合S中出现的每个数,都可以表示为$s_i = g^{t_{s_i}}$

  此时对于原来的一个序列y,$\prod y_i = x \mod m$,就变成了$\prod g ^{t_{y_i}} = g^{t_x} \mod m$,即$\sum t_{y_i} = x \mod m - 1$

  现在我们求的就是长度为n的序列,序列中每个数都属于集合t,并且所有数的和模(m-1)等于x 如此按照上面的做法,将乘法的定义改为多项式乘法的定义,快速幂+NTT即可复杂度$mlogmlogn$。

  注意:多项式乘法中是没有取模的,而这里(i+j)%(m-1),直接将数组加倍,然后NTT完后,大于等于m的加到相应的模m后的位置上即可。

代码:

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<iostream>
#include<cmath>
#include<cctype>
#include<set>
#include<queue>
#include<vector>
#include<map>
using namespace std;
typedef long long LL; inline int read() {
int x=,f=;char ch=getchar();for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-;
for(;isdigit(ch);ch=getchar())x=x*+ch-'';return x*f;
} const int mod = ;
const int N = ;
int vis[N], rev[N], n = , m;
int f[N], g[N], a[N], b[N], inv; int ksm(int a,int b,int p) {
a %= p;
int ans = ;
while (b) {
if (b & ) ans = 1ll * ans * a % p;
a = 1ll * a * a % p;
b >>= ;
}
return ans % p;
}
int Calc(int x) {
if (x == ) return ;
for (int i = ; ; ++i) {
bool flag = ;
for (int j = ; j * j < x; ++j)
if (ksm(i, (x - ) / j, x) == ) { flag = false; break; }
if (flag) return i;
}
}
void NTT(int *a,int n,int ty) {
for (int i = ; i < n; ++i) if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int m = ; m <= n; m <<= ) {
int w1 = ksm(, (mod - ) / m, mod);
if (ty == -) w1 = ksm(w1, mod - , mod);
for (int i = ; i < n; i += m) {
int w = ;
for (int k = ; k < (m >> ); ++k) {
int u = a[i + k], t = 1ll * w * a[i + k + (m >> )] % mod;
a[i + k] = (u + t) % mod;
a[i + k + (m >> )] = (u - t + mod) % mod;
w = 1ll * w * w1 % mod;
}
}
}
}
void mul(int *g,int *f) {
for (int i = ; i < n; ++i) a[i] = g[i] % mod, b[i] = f[i] % mod;
NTT(a, n, );
NTT(b, n, );
for (int i = ; i < n; ++i) a[i] = 1ll * a[i] * b[i] % mod;
NTT(a, n, -);
for (int i = ; i < n; ++i) a[i] = 1ll * a[i] * inv % mod;
for (int i = ; i < m - ; ++i) g[i] = (a[i] + a[i + m - ]) % mod;
}
void solve(int b) {
inv = ksm(n, mod - , mod);
g[] = ;
while (b) {
if (b & ) mul(g, f);
b >>= ;
mul(f, f);
}
}
int main() {
int cnt = read(); m = read(); int x = read(), s = read();
for (int i = ; i <= s; ++i) vis[read()] = ;
int q = Calc(m), pos = -, L = ;
for (int i = , j = ; i < m - ; ++i, j = 1ll * j * q % m) {
if (vis[j]) f[i] = ;
if (j == x) pos = i;
}
int M = (m - ) * ;
while (n < M) n <<= , L ++;
for (int i = ; i < n; ++i) rev[i] = (rev[i >> ] >> ) | ((i & ) << (L - ));
solve(cnt);
if (pos != -) cout << g[pos] % mod;
else cout << ;
return ;
}