算法学习笔记:NTT

时间:2022-12-26 10:51:40

算法学习笔记

N T T ( N u m b e r T h e o r e t i c T r a n s f o r m s ) ,即为日本电报电话公司快速数论变换。
快速数论变换 ( N T T ) 与快速傅里叶变换 ( F F T ) 实际上相类,可以说两者拥有相同的基础思想。 N T T 有较强的针对性,需要进行变换的数据拥有原根,不过因为 F F T 应用单位复根导致大量浮点操作,经常溢出,所以在部分情况(拥有原根)下运用 N T T 解决能得到更精确的答案的。

首先,学习 N T T 要先学习 F F T

学习 F F T 可以参考 F F T

接着, N T T 要用数学方法转化 F F T

N T T 是魔芋下的 F F T ,这样 N T T 才能不用复数所以没有精度误差。那么,我们需要找到一个 g n 来替换 ω n
g n 需要满足两个性质:
1. g n n 1 ( m o d p )
2. i , j , g n i g n j ( m o d p )
因为 g p 1 1 ( m o d p ) ,所以可以使得 g n = g p 1 n m o d p
因为 n = 2 k ,所以需要使得 p = a 2 k + 1
一般来说,会取 p = 998244353 = 119 2 23 + 1 ,此时会取 g = 3
一般来说,还会取 p = 1004535809 = 479 2 21 + 1 ,此时也会取 g = 3


l u o g u 3803


题目

3803 F F T


标程

#include <bits/stdc++.h>
using namespace std;
const int MOD = 998244353;
const int N = 1 << 21;
const int M = 21;
const int G = 3;
int power(int a, int b)
{
    int c = 1;
    for (; b; b >>= 1, a = 1LL * a * a % MOD)
        if (b & 1) c = 1LL * c * a % MOD;
    return c;
}
int n, m, l, inv;
int c[N], d[N], g[22], f[22];
void ntt(int *a, int p)
{
    for (int i = 0; i < n; i++)
        if (i < d[i]) swap(a[i], a[d[i]]);
    for (int i = 0, i1 = 1, i2 = 2; i1 < n; i++, i1 <<= 1, i2 <<= 1)
    {
        int u = p == 1 ? g[i + 1] : f[i + 1];
        for (int j = 0; j < n; j += i2)
        {
            for (int k = 0, v = 1; k < i1; k++, v = 1LL * v * u % MOD)
            {
                int x = a[j + k], y = 1LL * v * a[j + k + i1] % MOD;
                a[j + k] = x + y; if (a[j + k] >= MOD) a[j + k] -= MOD;
                a[j + k + i1] = x - y; if (a[j + k + i1] < 0) a[j + k + i1] += MOD;
            }
        }
    }
    if (p == -1)
        for (int i = 0; i < n; i++)
            a[i] = 1LL * a[i] * inv % MOD;
}
int main()
{
    ios::sync_with_stdio(false);
    g[M] = power(G, (MOD - 1) / N);
    f[M] = power(g[M], MOD - 2);
    for (int i = M - 1; i; i--)
    {
        g[i] = 1LL * g[i + 1] * g[i + 1] % MOD;
        f[i] = 1LL * f[i + 1] * f[i + 1] % MOD;
    }
    static int a[N], b[N];
    cin >> n >> m;
    for (int i = 0; i <= n; i++)
        cin >> a[i];
    for (int i = 0; i <= m; i++)
        cin >> b[i];
    m += n; for (n = 1; n <= m; n <<= 1) l++;
    for (int i = 0; i < n; i++)
        d[i] = d[i >> 1] >> 1 | (i & 1) << l - 1;
    inv = power(n, MOD - 2);
    ntt(a, 1); ntt(b, 1);
    for (int i = 0; i < n; i++)
        a[i] = 1LL * a[i] * b[i] % MOD;
    ntt(a, -1);
    for (int i = 0; i <= m; i++)
        cout << a[i] << ' ';
    cout << endl;
    return 0;
}