Skip to content

快速傅里叶变换(FFT)

2023-12-03

曾经改变美苏冷战格局的算法。

快速傅里叶变换能在 O(nlogn) 的时间复杂度下计算两个 n 次多项式的乘法。

多项式的表示法

系数表示法

对于 n1 次多项式

A(x)=a0+a1x+a2x2++an1xn1

系数表示法用 n 个系数表示它。

{a0,a1,,an1}

点值表示法

两点确定 kx+b,三点确定 ax2+bx+c,以此类推。至少 n 点才能确定 n1 次多项式。

点值表示法用至少 n 个不同的点表示 n1 次多项式。

{(x0,y0),(x1,y1),,(xn1,yn1)}

多项式乘法

离散傅里叶变换(Discrete Fourier Transform,DFT)是将多项式从 系数表示法 转换为 点值表示法 的算法,IDFT 是 DFT 的逆过程。

运用 DFT 和 IDFT 计算多项式乘法 A(x)B(x)=C(x) 的主要步骤:

  1. A(x)B(x) 使用 DFT,得到两个点集:{(xi,A(xi))0i<N}{(xi,B(xi))0i<N}
  2. 计算出 C(x) 的点值表示法:{(xi,A(xi)B(xi))0i<N}
  3. 使用 IDFT 将其转化为系数表示法。

其中 NC(x) 的次数+1=A(x) 的次数+B(x) 的次数+1,否则第二步得到的点太少,不足以确定最终的答案。

一般的 DFT 和 IDFT 时间复杂度高达 O(n2),而快速傅里叶变换(Fast Fourier Transform,FFT)可以把这个过程优化到 O(nlogn)

让我们从一个例子入手,先熟悉一下 O(n2) 的普通算法是什么样的。

普通的算法

现在计算 A(x)=x2+x+1B(x)=x23 的乘积 C(x)

这个乘法的结果肯定是一个 4 次的多项式,这意味着我们至少要取 5 个点,那么就随便取 x=1,2,3,4,5 好了。

第一步:计算 A(x)B(x) 的点值表示法:

A(x){(1,3),(2,7),(3,13),(4,21),(5,31)}B(x){(1,2),(2,1),(3,6),(4,13),(5,22)}

第二步:计算 C(x)=A(x)B(x) 的点值表示法:

C(x){(1,6),(2,7),(3,78),(4,273),(5,682)}

第三步:转化为系数表示法。这一步的方法有很多,可以使用拉格朗日插值法等。总之最后算出来的结果是

C(x)=x4+x32x23x3

C(x)n1 次多项式时,需要取 n 个不同的 x。其中每次计算 A(x)B(x) 的时间复杂度为 O(n),并且这个过程要重复 n 次,那么总的时间复杂度为 O(n2)


上述流程存在一个比较现实的问题:在实际的应用场景中,多项式的次数往往很大。如果我们随意地取 x 的值,计算出的 A(x) 可能会超出基本变量类型的范围,那么这个算法很有可能在第一步就会搁浅。

所以 x 的取值其实很有门道。它既不能让 A(x) 太大,让计算机存不下;也不能让 A(x) 太小,导致计算过程中发生精度的损失。这么看来,似乎只有 011 这三个数比较合适。

但是只有这三个数是远远不够的。去哪里找其它的数呢?

这样的数,数学家们在复数域中找到了无穷多个。

单位复根

复习

形如 a+biab 均为实数)的数为复数。其中

  • a 被称为实部;
  • b 被称为虚部;
  • i 是虚数单位,i2=1
  • a2+b2 是这个复数的模。

在复平面上,a+bi 对应的坐标为 (a,b)。其中

  • a 表示的是复平面内的横坐标;
  • b 表示的是复平面内的纵坐标;
  • 表示实数 a 的点都在 x 轴上,所以 x 轴又称为「实轴」;
  • 表示纯虚数 bi 的点都在 y 轴上,所以 y 轴又称为「虚轴」。

如图 1,在复平面上画一个半径为 1 的单位圆。圆上的每一点 (cosθ,sinθ) 都可以表示复数 cosθ+isinθ,其中 θ 是幅角,即它和原点的连线与实轴正半轴的夹角。

如果把圆周角 N 等分,也就是令 θ=2πN,那么这个复数就被称作 单位复根,记作 ωN

ωN=cos2πN+isin2πN

中学课本告诉了我们复数乘法的规律:幅角相加模相乘。我们知道 ωN 的模为 1,那么 ωNk 的模也就还是 1,且其幅角从原来的 2πN 变成了 2kπN

根据上述规律,不难发现,ωN0,ωN1,ωN2,,ωNN1 在单位圆上的分布是均匀的。图 2 展示了 N=8 时的情况。

width=280pxwidth=280px
图 1图 2

单位复根还具有如下优异的性质:

  1. 周期性:ωNN=1
  2. 消去性:ω2N2k=ωNk
  3. 对称性:ωNk+N2=ωNk

证明并不困难。一方面可以用上文中 ωNk 的图像性质去推理;另一方面,直接套用欧拉公式

eiθ=cosθ+isinθωN=e2πNi

也能能轻易得证。


单位复根 恰好可以完美地解决 DFT 中取点的问题。代入 x=ωNk 既不会使 y=A(x) 大到溢出,也不会使其小到失真,唯一别扭的地方就是 xy 都是复数值。不过大部分编程语言都有支持复数运算的库,所以这不是大问题。因此当我们需要在 DFT 中取 N 个点时,不妨取

x=ωN0,ωN1,ωN2,,ωNN1

但是,单位复根仅解决了精度问题。目前为止,时间复杂度仍然是 O(n2)。真正使其成为「快速」傅里叶变换的,是接下来的「多项式分治」。

多项式分治

对一个 n 项的多项式

A(x)=a0+a1x+a2x2+a3x3++an2xn2+an1xn1

进行如下变换(假定 n=2k,kZ):

  1. 将偶数项留在前面,将奇数项移到后面:
A(x)=a0+a2x2++an2xn2+a1x+a3x3++an1xn1
  1. 对后一半提取公因式 x
A(x)=a0+a2x2++an2xn2+x(a1+a3x2++an1xn2)

Aeven(x)=a0+a2x++an2xn21Aodd(x)=a1+a3x++an1xn21

A(x)=Aeven(x2)+xAodd(x2)

注意这里我们将 x2 作为 AevenAodd 的变量。

x=ωnk 代入得:

A(ωnk)=Aeven(ωn2k)+ωnkAodd(ωn2k)=Aeven(ωn/2k)+ωnkAodd(ωn/2k)

x=ωnk+n/2=ωnk 代入得:

A(ωnk+n/2)=A(ωnk)=Aeven(ωn/2k)ωnkAodd(ωn/2k)

我们发现,Aeven(ωn/2k)Aodd(ωn/2k) 各自又都是 n/2 项多项式,它们也可以用同样的方法再往下拆分成更短的多项式之和。所以我们采用递归的方式去实现这个计算过程。

对每个 A(ωnk) 单独递归的效率太低。这里采用的递归策略是:先计算出

{Aeven(ωn/20),Aeven(ωn/21),,Aeven(ωn/2n/21)}{Aodd(ωn/20),Aodd(ωn/21),,Aodd(ωn/2n/21)}

再根据之前推出的公式

{A(ωnk)=Aeven(ωn/2k)+ωnkAodd(ωn/2k)A(ωnk+n/2)=Aeven(ωn/2k)ωnkAodd(ωn/2k)k=0,1,,n21

得出

{A(ωn0),A(ωn1),,A(ωnn1)}

进而得出我们所需要的 n 个点值。


a={a0,a1,,an1} 是存放 A(x) 系数的数组,函数 DFT(a) 的功能是算出并返回

{A(ωn0),A(ωn1),,A(ωnn1)}

以下是 DFT(a) 的伪代码:

DFT(a):n:=a.size()If n=1:return ayeven:=DFT({a0,a2,,an2})yodd :=DFT({a1,a3,,an1})y:={}ω:=1ωn:=cos2πn+isin2πnFor k=0n21:yk=ykeven+ωykoddyk+n2=ykevenωykoddω=ωωnreturn y

可以画出 DFT 算法的递归图:

容易看出,以上分治算法每次都能递归地将规模为 n 的问题拆分成两个规模为 n/2 的问题,总时间复杂度为 O(nlogn)


之所以在一开始假定 n=2k,kZ,就是为了确保每次对多项式进行拆分时都能恰好平均分为两半。如果多项式的项数 n 不符合这个要求,那么我们可以往后面补 0 项,直到达成这个要求。

例如对于 A(x)=1+x+4x2+5x3+x4+4x5 一共只有 6 项,还差 2 项就能符合要求,所以往后面补两个 0 项:

A(x)=1+x+4x2+5x3+x4+4x5+0x6+0x7

以上就是 FFT 加速算法的全部内容。不过我们只是用它加速了 DFT 的过程。多项式乘法的最后一步,也就是 IDFT,我们仍然没有提及。实际上 IDFT 也可以用同样的算法进行加速。

IDFT

前文所论述的 DFT 算法实际上是在解以下方程组:

{a0+a1+a2++an1=y0a0+a1(ωn)+a2(ωn)2++an1(ωn)n1=y1a0+a1(ωn2)+a2(ωn2)2++an1(ωn2)n1=y2a0+a1(ωnn1)+a2(ωnn1)2++an1(ωnn1)n1=yn1

A=(a0a1a2an1)Y=(y0y1y2yn1),则上述方程组可以写成矩阵乘法的形式:

(11111ωn(ωn)2(ωn)n11ωn2(ωn2)2(ωn2)n11ωnn1(ωnn1)2(ωnn1)n1)A=Y

DFT 的数学本质就是已知 A,求解 Y。而 IDFT 作为其逆过程,实际上就是已知 Y,求解 A。这正好对应了从点值表示法向系数表示法的转换。

为了实现 IDFT,我们可以很自然地对上式做出以下变换:

A=(11111ωn(ωn)2(ωn)n11ωn2(ωn2)2(ωn2)n11ωnn1(ωnn1)2(ωnn1)n1)1Y

那怎么求中间的这个逆矩阵呢?我们一眼就可以盯真出这是个特殊范德蒙矩阵的逆矩阵。

范德蒙矩阵求逆的特殊情况

在范德蒙矩阵中,当 x1=x2==xn=x 时,有

(11111xx2xn11x2x4x2(n1)1xn1x2(n1)x(n1)(n1))1=1n(11111xn1x2(n1)x(n1)(n1)1x(n2)x2(n2)x(n2)(n1)1xx2xn1)

再结合 单位复根 的性质就可以得出

(11111ωn(ωn)2(ωn)n11ωn2(ωn2)2(ωn2)n11ωnn1(ωnn1)2(ωnn1)n1)1=1n(11111ωn1(ωn1)2(ωn1)n11ωn2(ωn2)2(ωn2)n11ωn(n1)(ωn(n1))2(ωn(n1))n1)

再将这个结果代入进去,并还原成方程组:

{y0+y1+y2++yn1=na0y0+y1(ωn1)+y2(ωn1)2++yn1(ωn1)n1=na1y0+y1(ωn2)+y2(ωn2)2++yn1(ωn2)n1=na2y0+y1(ωn(n1))+y2(ωn(n1))2++y(n1)(ωn(n1))n1=nan1

这个方程组和原先的方程组极为相似。这意味着我们只需调整 DFT 代码中某些参数的正负号,并且将最终的结果除以 n,就能得到 IDFT 的代码。

模板

cpp
#include <bits/stdc++.h>
using namespace std;

typedef complex<double> Comp;
const double PI = acos(-1);

vector<Comp> DFT(vector<Comp> a, bool invert) {
    int n = a.size();
    if (n == 1) return a;

    vector<Comp> a0(n / 2), a1(n / 2);
    for (int i = 0; 2 * i < n; i ++) {
        a0[i] = a[2*i];
        a1[i] = a[2*i + 1];
    }

    vector<Comp> y0 = DFT(a0, invert);
    vector<Comp> y1 = DFT(a1, invert);
    vector<Comp> y(n);

    double angle = 2 * PI / n * (invert ? -1 : 1);
    Comp w(1), wn(cos(angle), sin(angle));
    
    for (int i = 0; i < n / 2; i++) {
        y[i] = y0[i] + w * y1[i];
        y[i + n/2] = y0[i] - w * y1[i];
        if (invert) {
            y[i] /= 2;
            y[i + n/2] /= 2;
        }
        w *= wn;
    }
    
    return y;
}

vector<Comp> multiply(vector<Comp> A, vector<Comp> B) {
    int n = 1;
    while (n < A.size() + B.size()) 
        n *= 2;
    A.resize(n);
    B.resize(n);

    vector<Comp> yA = DFT(A, false);
    vector<Comp> yB = DFT(B, false);
    vector<Comp> yC(n);

    for (int i = 0; i < n; i ++)
        yC[i] = yA[i] * yB[i];

    vector<Comp> C = DFT(yC, true);

    for (int i = 0; i < n; i ++)
        C[i] = round(C[i].real());

    while (C.size() && ! C.back().real())
        C.pop_back();

    return C;
}

int main() {
    vector<Comp> A = {1, 2, 3}; // Represents the polynomial 1 + 2x + 3x^2
    vector<Comp> B = {4, 5};    // Represents the polynomial 4 + 5x
    
    vector<Comp> C = multiply(A, B);
    
    for (auto i : C)
        cout << i.real() << ' ';
    cout << endl;
    
    return 0;
}