Skip to content

快速数论变换(FNTT)

2023-12-23

快速傅里叶变换的变体。

快速傅里叶变换中,单位复根用于多项式的分治过程,作为变量 x 的取值。然而,使用单位复根有几个缺点:

  1. 单位复根定义在三角函数之上:ωN=cos2πN+isin2πN在运算过程中,会频繁地跟浮点复数 complex<double> 打交道,导致精度损失;
  2. 不适用于模环境(复数无法取模);
  3. 算法常数较大,影响效率。

幸运的是,在整数域中,我们找到了可替代单位复根的数值,它们仅在模环境中有效。这就是快速数论变换(Fast Number-Theoretic Transform, FNTT)的核心。

原根

对于质数 p,如果存在整数 g 使

g0modp,g1modp,g2modp,,gp2modp

p1 个数互不相同,则称 gp原根


根据费马小定理:

对任意质数 pP 和任意整数 aZ,都有 ap11(modp)

因此 gp11=g0(modp)。进一步地:

gpg1(modp)gp+1g2(modp)gp+2g3(modp)

这意味着,计算 g 的更高幂次会重复之前的模 p 序列,形成一个轮回。

我们可以将此轮回用环状结构表示。

模意义下原根的轮回单位复根的轮回

是不是跟单位复根如出一辙。

基于原根的该性质,我们可以构造与单位复根同构的表达式。

现对质数 p 作进一步的约束:令 p 满足 p=1+N 的形式,其中 N=2ξ2 的幂次。从而构造

GN=gp1Nmodp

可以发现,GN 具有单位复根 ωN 的所有性质:

  1. 周期性GNN=1

    证明:GNN=gp1modp=1

  2. 消去性G2N2k=GNk

    证明:G2N2k=gp12N2kmodp=gp1Nkmodp=GNk

  3. 对称性GNk+N/2=GNk

    证明:

    GNk+N/2=gp1N(k+N2)modp=gp1Nkgp12modp=GNkgp12modp

    由于 (gp12)2=gp11(modp),故 gp12±1(modp)

    根据原根的定义,g0,g1,,gp2 在模 p 意义下互不相同,因此 gp12g0=1(modp),即 gp12modp 只能是 1。代入得 GNk+N/2=GNk

因此 GN 可以完全替代单位复根,参与 FFT 的运算过程。

原根表

质数 p原根 g质数 p原根 g
321228911
52409613
173655373
97578643310
193557671693
257373400333
768117230686733
10485760131677721613
46976204939982443533
10045358093201326592131
2281701377332212254735
751619276813773094113297
2061584302092220615843020817
2748779069441365970697666575
395824185999375791648371998735
263882790666241712314530231091203
1337006139375610337999121855938505
42221246506598401978812993478983606
3152519739159340031801439850948190006
1945555039024050000541793404541998200003

注:一个质数可能对应多个原根。例如 998244353 的原根可以是 3,也可以是 114514

模板

cpp
#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const LL Mod = 998244353;
const LL G = 114514;
const LL Gi = 137043501;

LL powmod(LL a, LL b) {
    LL res = 1;
    a %= Mod;
    while (b > 0) {
        if (b % 2 == 1)
            res = res * a % Mod;
        a = a * a % Mod;
        b /= 2;
    }
    return res;
}

vector<LL> NTT(vector<LL> a, bool invert) {
    int n = a.size();
    if (n == 1) return a;

    vector<LL> a0(n / 2), a1(n / 2);
    for (int i = 0; i < n / 2; i ++) {
        a0[i] = a[2 * i];
        a1[i] = a[2 * i + 1];
    }

    vector<LL> y0 = NTT(a0, invert);
    vector<LL> y1 = NTT(a1, invert);
    vector<LL> y(n);

    LL w = 1;
    LL root = powmod(!invert ? Gi : G, (Mod - 1) / n);

    for (int i = 0; i < n / 2; i ++) {
        LL u = y0[i];
        LL v = y1[i] * w % Mod;
        y[i] = (u + v) % Mod;
        y[i + n / 2] = (u - v + Mod) % Mod;
        w = w * root % Mod;
    }

    return y;
}

vector<LL> multiply(vector<LL> A, vector<LL> B) {
    int n = 1;
    while (n < A.size() + B.size()) 
        n *= 2;

    A.resize(n);
    B.resize(n);

    vector<LL> yA = NTT(A, false);
    vector<LL> yB = NTT(B, false);
    vector<LL> yC(n);

    for (int i = 0; i < n; i ++)
        yC[i] = yA[i] * yB[i] % Mod;

    vector<LL> C = NTT(yC, true);

    LL inv_n = powmod(n, Mod - 2);
    for (LL &x : C)
        x = x * inv_n % Mod;

    while (C.size() && ! C.back())
        C.pop_back();

    return C;
}

int main() {
    vector<LL> A = {1, 2, 3};
    vector<LL> B = {4, 5, 6};

    vector<LL> res = multiply(A, B);

    for (LL x : res)
        cout << x << " ";
    cout << endl;

    return 0;
}