[洛谷P3321][SDOI2015]序列统计

时间:2024-11-09 17:03:38

题目大意:给定整数$n,m,x$和集合$S$($S_i\in[0,m)$,$m\leqslant 8000$且$m$为质数,$1\leqslant x<m$,$n\leqslant10^9$),求满足$Q_i\in S$且$\prod\limits_{i=1}^nQ_i\equiv x\pmod m$的长度为$n$的序列$Q$的个数,答案对$1004535809$取模。

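题解:乘法不好直接处理。可以先求出$m$的原根$g$,把每个非零数换成它关于$g$的离散对数(指标)。这样"乘积模$m$等于$x$"就转化为"指标之和模$m-1$等于$x$的指标"。原根可以暴力$O(m^2)$求。然后用$S$中非零元素的指标构造生成函数$F(x)$,算出$F^n(x)$,其中每次多项式乘法都是长度为$m-1$的循环卷积(指数对$m-1$取模)。发现$n$较大,多项式快速幂即可。$S$中的$0$乘出来一定是$0\neq x$,直接忽略即可。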

卡点:

C++ Code:

#include <cstdio>
#include <algorithm>
#include <cstring>
// Array capacities as typed constants rather than macros: the old
// `#define maxn 16384 | 3` was unparenthesised, so e.g. `maxn * 2` silently
// meant `16384 | 6`. maxn bounds every NTT buffer; main() calls init(m << 1),
// and for m <= 8000 the transform length is at most 16384 (+3 slack).
const int maxn = 16384 | 3;
const int maxm = 8010;  // vis[] is indexed by residues modulo m (m <= 8000)
// NTT modulus (the answer is printed modulo it) and the generator used for its roots of unity.
const int mod = 1004535809, G = 3;
// Input values (sequence length, prime modulus, target residue, |S|) and g = the root returned by get().
int n, m, x, S, g;
// Discrete-log table filled by get(): vis[r] = log_g(r) for nonzero residues r, -1 otherwise.
int vis[maxm];
// b^e mod m for small moduli: operands stay below m, so products fit in int
// as long as m is at most a few tens of thousands (here m <= maxm).
static int pow_mod_small(int b, int e, int m) {
    int r = 1 % m;
    for (b %= m; e; e >>= 1, b = b * b % m)
        if (e & 1) r = r * b % m;
    return r;
}

/*
 * Find the smallest primitive root of the prime m and build its discrete-log
 * table in vis[]: vis[root^j mod m] = j for 0 <= j < m-1; every other entry
 * (in particular vis[0]) is -1.
 *
 * Uses the standard criterion: root is primitive iff root^((m-1)/p) != 1 (mod m)
 * for every prime p dividing m-1. This replaces a brute-force order check that
 * re-cleared vis[] for every candidate.
 *
 * Requires m prime and m <= maxm (vis[] is indexed by residues). For m == 2 the
 * only unit is 1, so the root returned is 1. If no candidate in [1, m) passes
 * (e.g. m < 2), returns -1 and leaves vis[] all -1.
 */
int get(int m) {
    memset(vis, -1, sizeof vis);
    const int order = m - 1;  // size of the multiplicative group modulo prime m

    // Distinct prime factors of the group order (an int has fewer than 16).
    int primes[16], cnt = 0, rest = order;
    for (int p = 2; p * p <= rest; p++)
        if (rest % p == 0) {
            primes[cnt++] = p;
            while (rest % p == 0) rest /= p;
        }
    if (rest > 1) primes[cnt++] = rest;

    // Candidate 1 fails whenever m > 2 (there is at least one prime factor),
    // so for m >= 3 this yields the smallest primitive root >= 2.
    for (int root = 1; root < m; root++) {
        bool primitive = true;
        for (int k = 0; k < cnt && primitive; k++)
            if (pow_mod_small(root, order / primes[k], m) == 1) primitive = false;
        if (!primitive) continue;
        for (int j = 0, t = 1; j < order; j++, t = t * root % m) vis[t] = j;
        return root;
    }
    return -1;
}
// Transform state shared by init(), NTT() and MUL():
// lim = transform length, ilim = lim^{-1} modulo `mod`, s = log2(lim) - 1,
// rev = bit-reversal permutation of indices 0..lim-1.
int lim, ilim, s;
int rev[maxn];
// base = indicator polynomial over discrete logs (built in main), ans = accumulated power,
// Wn[0..lim] = successive powers of the root of unity prepared by init().
int base[maxn];
int ans[maxn];
int Wn[maxn + 1];
// Square-and-multiply: returns base^p modulo `mod` (p must be >= 0; p == 0 yields 1).
inline int pw(int base, int p) {
    long long acc = 1;
    long long cur = base;
    while (p) {
        if (p & 1) acc = acc * cur % mod;
        cur = cur * cur % mod;
        p >>= 1;
    }
    return static_cast<int>(acc);
}
// Modular inverse by Fermat's little theorem (mod is prime): x^(mod - 2) mod `mod`.
inline int Inv(int x) {
    const int fermatExponent = mod - 2;
    return pw(x, fermatExponent);
}
// Prepare the NTT tables for a transform able to hold n coefficients:
//   lim  = smallest power of two >= n, s = log2(lim) - 1, ilim = lim^{-1} mod `mod`;
//   rev  = bit-reversal of each index over log2(lim) bits;
//   Wn[k] = step^k for k = 0..lim, where step = G^((mod - 1) / lim)
//           (a lim-th root of unity, assuming G generates the multiplicative group mod `mod`).
inline void init(int n) {
    for (lim = 1, s = -1; lim < n; lim <<= 1) ++s;
    ilim = Inv(lim);

    for (int i = 0; i < lim; ++i) {
        const int lowBit = i & 1;
        rev[i] = (rev[i >> 1] >> 1) | (lowBit << s);
    }

    const int step = pw(G, (mod - 1) / lim);
    Wn[0] = 1;
    for (int k = 1; k <= lim; ++k) Wn[k] = static_cast<int>(1ll * Wn[k - 1] * step % mod);
}
// a <- (a + b) mod `mod` with one conditional subtraction; valid for 0 <= a < mod, 0 <= b <= mod.
inline void up(int &a, int b) {
    a += b;
    if (a >= mod) a -= mod;
}
// In-place radix-2 iterative NTT of length `lim` over integers modulo `mod`,
// using the rev[], Wn[] and ilim tables prepared by init().
//   op != 0 : forward transform, twiddle factors Wn[j * t] = w^(j*t).
//   op == 0 : inverse transform, twiddle factors Wn[lim - j * t] = w^(-j*t)
//             (w^lim == 1), then every entry is multiplied by ilim = lim^{-1}.
// A must hold at least lim entries, each in [0, mod); results stay in [0, mod).
inline void NTT(int *A, int op) {
// Permute A into bit-reversed index order so the butterflies can run in place.
for (int i = 0; i < lim; i++) if (i < rev[i]) std::swap(A[i], A[rev[i]]);
// Each pass merges adjacent blocks of size mid into blocks of size 2*mid.
for (int mid = 1; mid < lim; mid <<= 1) {
// Stride into Wn: t = lim / (2*mid), so Wn[j * t] runs over powers of a (2*mid)-th root of unity.
int t = lim / mid >> 1;
for (int i = 0; i < lim; i += mid << 1) {
for (int j = 0; j < mid; j++) {
int W = op ? Wn[j * t] : Wn[lim - j * t];
// Butterfly: with Y = A[i+j+mid] * W, (X, .) -> (X + Y, X - Y) mod `mod`;
// adding mod - Y (in (0, mod]) keeps the subtraction non-negative for up().
int X = A[i + j], Y = 1ll * A[i + j + mid] * W % mod;
up(A[i + j], Y), up(A[i + j + mid] = X, mod - Y);
}
}
}
// Inverse transform only: scale by lim^{-1} so that NTT(A, 1) then NTT(A, 0) restores A.
if (!op) for (int i = 0; i < lim; i++) A[i] = 1ll * A[i] * ilim % mod;
}
// Scratch buffers for MUL(): C receives the left operand (and then the product), D the right operand.
int C[maxn];
int D[maxn];
// A <- A * B as a cyclic convolution of period m - 1: exponents are discrete logs,
// so adding them modulo m - 1 corresponds to multiplying residues modulo m.
// The linear product is computed with NTT (lim must exceed its degree, 2m - 4;
// main() sets this up via init(m << 1)). Terms of degree >= m - 1 are then folded
// onto degree k mod (m - 1) and cleared. A and B may be the same array, since both
// are copied into scratch buffers before A is overwritten.
inline void MUL(int *A, int *B) {
    std::copy(A, A + lim, C);
    std::copy(B, B + lim, D);
    NTT(C, 1);
    NTT(D, 1);
    for (int k = 0; k < lim; ++k) C[k] = static_cast<int>(1ll * C[k] * D[k] % mod);
    NTT(C, 0);
    std::copy(C, C + lim, A);

    const int period = m - 1;
    for (int k = period; k < lim; ++k) {
        int &slot = A[k % period];
        slot = (slot + A[k]) % mod;
        A[k] = 0;
    }
}
// Reads n, m, x, |S| and the elements of S. Prints the number of length-n sequences
// over S whose product is congruent to x modulo m, taken modulo `mod`.
// Products become sums of discrete logs (modulo m - 1). The answer is the coefficient
// of log_g(x) in F^n, where F is the indicator polynomial of { log_g(s) : s in S, s != 0 }.
int main() {
    if (scanf("%d%d%d%d", &n, &m, &x, &S) != 4) return 1;
    g = get(m);  // also fills vis[] with discrete logs base g
    for (int i = 0, tmp; i < S; i++) {
        if (scanf("%d", &tmp) != 1) return 1;
        // 0 has no discrete log, and any product containing it is 0 != x.
        if (tmp) base[vis[tmp]] = 1;
    }
    init(m << 1);  // room for the linear product of two period-(m-1) polynomials
    ans[0] = 1;    // log_g(1) = 0: the multiplicative identity
    // Binary exponentiation ans = base^n. Squaring only while bits remain avoids
    // a wasted MUL(base, base) (three NTTs) after the top bit.
    for (; n; n >>= 1) {
        if (n & 1) MUL(ans, base);
        if (n > 1) MUL(base, base);
    }
    printf("%d\n", ans[vis[x]]);
    return 0;
}