快速沃尔什变换(FWT)

一个很邪门的算法。

二进制卷积 #

快速傅里叶变换 (FFT)可以高效地计算两个多项式的乘积,进而加速了卷积运算。

info

序列 $A$ 和序列 $B$ 的卷积记作 $C=A*B$,其中

$$C_k=\sum_{i+j=k}A_iB_j$$

构造多项式

$$A(x)=A_0+A_1x+A_2x^2+\cdots\\ B(x)=B_0+B_1x+B_2x^2+\cdots$$

我们知道 $A(x)\cdot B(x)$ 中 $x^k$ 前的系数就是 $C_k$,而 $A(x)\cdot B(x)$ 的系数序列可由 FFT 得到。因此 FFT 也被认为是快速卷积算法。

现在考虑一类奇怪的卷积:

$$\DeclareMathOperator{\or}{or} \DeclareMathOperator{\and}{and} \DeclareMathOperator{\xor}{xor} \sum_{i \or j=k}A_iB_j \qquad \sum_{i \and j=k}A_iB_j \qquad \sum_{i \xor j=k}A_iB_j$$

其中 $\or,\and,\xor$ 是二进制按位或、与、异或。右(下)图附上了一位二进制运算的真值表。

面对这类奇葩的卷积时,FFT 就束手无策了。

$A$ $B$ $A\or B$ $A\and B$ $A\xor B$
0 0 0 0 0
0 1 1 0 1
1 0 1 0 1
1 1 1 1 0

还有高手? #

快速沃尔什变换(Fast Walsh-Hadamard Transform,FWT)是加速 二进制卷积 运算的算法,其大方向与 FFT 非常相似。

回忆一下 FFT 如何加速多项式乘法:

FFT
  1. 对 $A(x)$ 和 $B(x)$ 使用 DFT,得到两个点集:

    $$\left\{\big(x_i,A(x_i)\big)\mid 0\le i< N\right\} \\ \left\{\big(x_i,B(x_i)\big)\mid 0\le i< N\right\}$$

  2. 计算出 $C(x)=A(x)\cdot B(x)$ 的点值表示法:

    $$\left\{\big(x_i,A(x_i)\cdot B(x_i)\big)\mid 0\le i< N\right\}$$

  3. 使用 IDFT 将其转化为系数表示法。

我们在 DFTIDFT 中使用 FFT 的思想来加速这个过程。

那么其实 FWT 用到的也是类似的思路:

FWT
  1. 对序列 $A,B$ 使用 FWT,得到两个中间态(中间态也是序列):

    $$\text{FWT}(A),\text{FWT}(B)$$

  2. 使用点对点乘法,计算出 $C$ 的中间态:

    $$\text{FWT}(C)=\left\{\text{FWT}(A)_i\times\text{FWT}(B)_i \ \Big| \ i=0,1,2,\cdots\right\}$$

  3. 使用 UFWT(也称 IFWT)将其还原为序列 $C$。

只不过 FWT 的中间态略显抽象,并且 $\or,\and,\xor$ 卷积所对应的中间态还各不相同。

OR 卷积 #

在 $\or$ 卷积中,序列 $X$ 对应的中间态为与之等长的序列 $\text{FWT}(X)$,其中

$$\text{FWT}(X)_k=\sum_{i\or k=k}X_i$$

🤔沃尔什究竟是怎么想出来的?

设有两个序列 $A,B$,待求的序列是 $C$,其中 $C_k=\sum_{i \or j=k}A_iB_j$。

可以证明,对 $A$ 和 $B$ 的中间态进行点对点乘法可以得出 $C$ 的中间态。

proof

容易发现

$$a\or c=c,b\or c=c\intro (a\or b)\or c=c$$

可以从集合意义理解:若 $a\in c$ 且 $b\in c$,则 $a$ 和 $c$ 的并集也是 $c$ 的子集。

接下来的证明步骤会用到该结论。

$$\begin{aligned} & \ \text{FWT}(A)_ k\times\text{FWT}(B)_ k\\ =& \ \left(\sum_{i\or k=k}A_i\right)\left(\sum_{j\or k=k}B_j\right)\\ =& \ \sum_{i\or k=k,j\or k=k}A_iB_j\\ =& \ \sum_{(i\or j)\or k=k}A_iB_j\\ =&\!\!\!\overset{x=i\or j}{=\!=\!=\!=}\!\!\!=\sum_{x\or k=k} \ \sum_{i\or j=x}A_iB_j\\ =& \ \sum_{x\or k=k}C_x\\ =& \ \text{FWT}(C)_k\\ \end{aligned}$$

证毕。

现在推导一下中间态的求法。

类似于 FFT,此处也规定所有序列的长度 $n$ 都是 $2$ 的幂次。

我们将序列 $A=\{A_0,A_1,\cdots,A_{n-1}\}$ 均分为两段

$$A_\text{prev}=\left\{A_0,A_1,\cdots,A_{\frac{n}{2}-1}\right\} \quad A_\text{back}=\left\{A_\frac{n}{2},A_{\frac{n}{2}+1},\cdots,A_{n-1}\right\}$$

然后仔细盯真 $A$ 的中间态:

$$\text{FWT}(A)_ k=\sum_{i\or k=k}A_i$$

  • 当 $\textstyle 0\le k\le \frac{n}{2}-1$ 时,由 $\or$ 运算的性质可知,$i\or k=k$ 中 $i$ 的大小不可能超过 $k$。因此 $i$ 也落在 $\textstyle 0\cdots \frac{n}{2}-1$ 的范围内,也就是说 $A_i$ 只能来自 $A_\text{prev}$,即 $\sum_{i\or k=k}A_i=\sum_{i\or k=k}(A_\text{prev})_i$,即

    $$\text{FWT}(A)_ k=\text{FWT}\left(A_\text{prev}\right)_k$$

  • 当 $\textstyle\frac{n}{2}\le k\le n-1$ 时,$A_i$ 既可以来自 $A_\text{prev}$,又可以来自 $A_\text{back}$。由于 $\text{FWT}(A)_k$ 是一个求和式,我们可以顺理成章地把它拆成两个部分的和

    $$\text{FWT}(A)_ k=\text{FWT}\left(A_\text{prev}\right)_{k-\frac{n}{2}}+\text{FWT}\left(A_\text{back}\right)_{k-\frac{n}{2}}$$

    注:此处对 $k$ 减去 $\textstyle\frac{n}{2}$ 是为了调整下标,使其适应相应的分段序列。或者说,$A_\text{prev}$ 和 $A_\text{back}$ 的长度都只有 $\textstyle\frac{n}{2}$,作此调整可以防止数组访问越界。

$$\text{FWT}(A)_k=\begin{cases} \text{FWT}\left(A_\text{prev}\right)_k,&0\le k\le \frac{n}{2}-1\\ \\ \text{FWT}\left(A_\text{prev}\right)_{k-\frac{n}{2}}+\text{FWT}\left(A_\text{back}\right)_{k-\frac{n}{2}},&\frac{n}{2}\le k\le n-1 \end{cases}$$

更进一步地,以序列为基本单位看问题:

$$\DeclareMathOperator{\merge}{merge} \text{FWT}(A)=\begin{cases} A,&n=1\\ \\ \merge\left(\text{FWT}\left(A_\text{prev}\right),\text{FWT}\left(A_\text{back}\right) \oplus \text{FWT}\left(A_\text{prev}\right)\right),&n>1 \end{cases} \tag{1}$$

其中 $\merge$ 是拼接两个序列的操作,$\oplus$ 表示两个序列的点对点加法。


上面我们已经剖析了 FWT 的算法,现在还需要讨论其反演 UFWT

其实只要把点对点加法 $\oplus$ 改为点对点减法 $\ominus$ 就可以得到 UFWT 的递推式。

$$\text{UFWT}(A)=\begin{cases} A,&n=1\\ \\ \merge\left(\text{UFWT}\left(A_\text{prev}\right),\text{UFWT}\left(A_\text{back}\right) \ominus \text{UFWT}\left(A_\text{prev}\right)\right),&n>1 \end{cases} \tag{2}$$

容易证明 $(2)$ 式是 $(1)$ 式的逆变换。

proof

此论断等价于 $\text{UFWT}\big(\text{FWT}(A)\big)=A$,这里采用数学归纳法证明。

设序列 $A$ 的长度为 $n$。当 $n=1$ 时命题显然成立。

由于 $n$ 始终是 $2$ 的幂次,即只需证明当 $n>1$ 时

$$\textstyle 命题对 \ \frac{n}{2} \ 成立 \intro 命题对 \ n \ 成立$$

现假设命题对 $\textstyle\frac{n}{2}$ 成立。由于 $A_\text{prev}$ 和 $A_\text{back}$ 都是长为 $\textstyle\frac{n}{2}$ 的序列,因此有

$$\text{UFWT}\big(\text{FWT}(A_\text{prev})\big)=A_\text{prev}, \text{UFWT}\big(\text{FWT}(A_\text{back})\big)=A_\text{back}$$

由式 $(1)$ $(2)$ 得(将 $\merge$ 拆开):

$$\begin{aligned} &\begin{cases} \text{FWT}(A)_\text{prev}=\text{FWT}(A_\text{prev}) \\ \text{FWT}(A)_\text{back}=\text{FWT}(A_\text{back})\oplus\text{FWT}(A_\text{prev}) \end{cases} \\ \\ &\begin{cases} \text{UFWT}(A)_\text{prev}=\text{UFWT}(A_\text{prev}) \\ \text{UFWT}(A)_\text{back}=\text{UFWT}(A_\text{back})\ominus\text{UFWT}(A_\text{prev}) \end{cases} \end{aligned}$$

因此

$$\begin{aligned} \text{UFWT}\big(\text{FWT}(A)\big)_\text{prev}&=\text{UFWT}\big(\text{FWT}(A)_\text{prev}\big)=\text{UFWT}\big(\text{FWT}(A_\text{prev})\big)=A_\text{prev}\\ \\ \text{UFWT}\big(\text{FWT}(A)\big)_\text{back}&=\text{UFWT}\big(\text{FWT}(A)_\text{back}\big)\ominus\text{UFWT}\big(\text{FWT}(A)_\text{prev}\big)\\ &=\text{UFWT}\big(\text{FWT}(A_\text{back}) \oplus \text{FWT}(A_\text{prev})\big)\ominus \text{UFWT}\big(\text{FWT}(A_\text{prev})\big)\\ &=\text{UFWT}\big(\text{FWT}(A_\text{back})\big) \oplus \text{UFWT}\big(\text{FWT}(A_\text{prev})\big)\ominus \text{UFWT}\big(\text{FWT}(A_\text{prev})\big)\\ &=A_\text{back}\oplus A_\text{prev}\ominus A_\text{prev}\\ &=A_\text{back} \end{aligned}$$

因此

$$\text{UFWT}\big(\text{FWT}(A)\big)=\merge(A_\text{prev},A_\text{back})=A$$

即命题对 $n$ 成立。证毕。

模板 #

void FWT_OR(vector<int>& a, int l, int r, bool inv) {
    if (l == r)
        return;
    int mid = (l + r) >> 1;
    FWT_OR(a, l, mid, inv);
    FWT_OR(a, mid + 1, r, inv);
    for (int i = 0; i <= mid - l; i ++) {
        if (!inv)
            a[mid + 1 + i] += a[l + i];
        else
            a[mid + 1 + i] -= a[l + i];
    }
}
void FWT_OR(vector<int>& a, bool inv) {
    int n = a.size();
    for (int d = 1; d < n; d <<= 1) {
        for (int m = d << 1, i = 0; i < n; i += m) {
            for (int j = 0; j < d; j ++) {
                if (!inv)
                    a[i + j + d] += a[i + j];
                else
                    a[i + j + d] -= a[i + j];
            }
        }
    }
}

AND 卷积 #

$\and$ 卷积和 $\or$ 卷积在本质上是类似的。

$$\text{FWT}(X)_k=\sum_{i\and k=k}X_i$$

类似可证 $\text{FWT}(A)_k\times\text{FWT}(B)_k=\text{FWT}(C)_k$。

由于 $\and$ 运算的性质,递推式的点对点加减法操作都集中在前面。

$$\begin{aligned} \text{FWT}(A)&=\begin{cases} A,&n=1\\ \\ \merge\left(\text{FWT}\left(A_\text{prev}\right) \oplus \text{FWT}\left(A_\text{back}\right),\text{FWT}\left(A_\text{back}\right)\right),&n>1 \end{cases} \newline \newline \text{UFWT}(A)&=\begin{cases} A,&n=1\\ \\ \merge\left(\text{UFWT}\left(A_\text{prev}\right) \ominus \text{UFWT}\left(A_\text{back}\right),\text{UFWT}\left(A_\text{back}\right)\right),&n>1 \end{cases} \end{aligned}$$

模板 #

void FWT_AND(vector<int>& a, int l, int r, bool inv) {
    if (l == r)
        return;
    int mid = (l + r) >> 1;
    FWT_AND(a, l, mid, inv);
    FWT_AND(a, mid + 1, r, inv);
    for (int i = 0; i <= mid - l; i ++) {
        if (!inv)
            a[l + i] += a[mid + 1 + i];
        else
            a[l + i] -= a[mid + 1 + i];
    }
}
void FWT_AND(vector<int>& a, bool inv) {
    int n = a.size();
    for (int d = 1; d < n; d <<= 1) {
        for (int m = d << 1, i = 0; i < n; i += m) {
            for (int j = 0; j < d; ++j) {
                if (!inv)
                    a[i + j] += a[i + j + d];
                else
                    a[i + j] -= a[i + j + d];
            }
        }
    }
}

XOR 卷积 #

$\xor$ 卷积的构造相比前两个更加麻烦。

设序列 $X$(长为 $n$)的中间态为 $\text{FWT}(X)$,其中

$$\text{FWT}(X) _ k=\sum_{i=0}^{n-1}f(i,k)X_i$$

其中 $f(x,y)$ 是待确定的函数,需要通过以下过程反推出来。


要使 $\text{FWT}(A)_k\times\text{FWT}(B)_k=\text{FWT}(C)_k$,即

$$\sum_{i=0}^{n-1}f(i,k)A_i \times \sum_{j=0}^{n-1}f(j,k)B_j = \sum_{x=0}^{n-1}f(x,k)C_x$$

将 $C_x=\sum_{i\xor j=x}A_iB_j$ 代入

$$\sum_{i=0}^{n-1}f(i,k)A_i \times \sum_{j=0}^{n-1}f(j,k)B_j = \sum_{x=0}^{n-1}f(x,k) \sum_{i\xor j=x}A_iB_j$$

$$\sum_{i=0}^{n-1}\sum_{j=0}^{n-1} f(i,k) f(j,k) A_iB_j = \sum_{x=0}^{n-1} \sum_{i\xor j=x} f(x,k) A_iB_j$$

在 $0\cdots n-1$ 范围内,对于每个可能的 $i,j$ 组合,总有一个 $x$ 满足 $i\xor j=x$,所以 $\sum_{x=0}^{n-1} \sum_{i\xor j=x}$ 实际上遍历了所有的 $i$ 和 $j$,因此

$$\sum_{i=0}^{n-1}\sum_{j=0}^{n-1} f(i,k) f(j,k) A_iB_j = \sum_{i=0}^{n-1} \sum_{j=0}^{n-1} f(i\xor j,k) A_iB_j$$

即得

$$f(i,k)f(j,k)=f(i\xor j,k)\tag{i}$$


注意到

$$\DeclareMathOperator{\bitcount}{bitcount} a\xor b=c\intro(-1)^{\large\bitcount(a)}(-1)^{\large\bitcount(b)}=(-1)^{\large\bitcount(c)}\tag{ii}$$

其中 $\bitcount(x)$ 表示 $x$ 的二进制中 $1$ 的个数。

这句话实际是在说:当 $a\xor b=c$ 时 $\bitcount(a)+\bitcount(b)$ 的奇偶性和 $\bitcount(c)$ 相同。

这个结论是显然的。二进制按位异或实际上是一位一位地异或:

  • 若 $a,b$ 的当前位都是 $1$,则 $c$ 的该位为 $0$,相当于少了两个 $1$,$1$ 的个数的奇偶性没有变;
  • 若一个是 $1$,另一个是 $0$,则 $c$ 的该位为 $1$,相当于没差;
  • 若都是 $0$ 也没差。

既然奇偶性相同,那么 $(-1)^{\large\bitcount(a)+\bitcount(b)}=(-1)^{\large\bitcount(c)}$,变形可得 $(\text{ii})$ 式。

又注意到

$$(i\and k)\xor(j\and k)=(i\xor j)\and k$$

代入 $(\text{ii})$ 式得

$$(-1)^{\large\bitcount(i\and k)}(-1)^{\large\bitcount(j\and k)}=(-1)^{\large\bitcount((i\xor j)\and k)}$$

于是便可以确定函数 $f$ 的形式:

$$f(x,y)=(-1)^{\large\bitcount(x \and y)}$$

该形式完美符合 $(\text i)$ 式的结论。

从而序列 $X$ 的中间态的全貌应该是这样的:

$$\text{FWT}(X) _ k=\sum_{i=0}^{n-1}(-1)^{\large\bitcount(i\and k)}X_i$$


现对 $\text{FWT}(A)_k$ 进行分治。

  • 当 $\textstyle 0\le k\le \frac{n}{2}-1$ 时,直接把求和式拆分成两半

$$\begin{aligned} \text{FWT}(A) _ k&=\sum_{i=0}^{n-1}(-1)^{\large\bitcount(i\and k)}A_i\\ &=\sum_{i=0}^{\frac{n}{2}-1}(-1)^{\large\bitcount(i\and k)}A_i+\sum_{i=\frac{n}{2}}^{n-1}(-1)^{\large\bitcount(i\and k)}A_i\\ &=\sum_{i=0}^{\frac{n}{2}-1}(-1)^{\large\bitcount(i\and k)}(A_\text{prev})_i+\sum_{i=0}^{\frac{n}{2}-1}(-1)^{\large\bitcount(i\and k)}(A_\text{back})_i\\ &=\text{FWT}(A_\text{prev})_k+\text{FWT}(A_\text{back})_k \end{aligned}$$

  • 当 $\textstyle \frac{n}{2}\le k\le n-1$ 时,需要对下标 $k$ 进行偏移,偏移量为 $\textstyle\frac{n}{2}$,理由同 OR 卷积

需要注意的是,当 $i$ 取 $\textstyle\frac{n}{2}\cdots n-1$ 时,$\textstyle\bitcount\left(i\and \left(k-\frac{n}{2}\right)\right)=-\bitcount(i\and k)$,下标偏移导致了此处符号的改变。

$$\begin{aligned} \text{FWT}(A) _ k&=\sum_{i=0}^{n-1}(-1)^{\large\bitcount(i\and k)}A_i\\ &=\sum_{i=0}^{\frac{n}{2}-1}(-1)^{\large\bitcount\left(i\and \left(k-\frac{n}{2}\right)\right)}A_i-\sum_{i=\frac{n}{2}}^{n-1}(-1)^{\large\bitcount\left(i\and \left(k-\frac{n}{2}\right)\right)}A_i\\ &=\sum_{i=0}^{\frac{n}{2}-1}(-1)^{\large\bitcount\left(i\and \left(k-\frac{n}{2}\right)\right)}(A_\text{prev})_i-\sum_{i=0}^{\frac{n}{2}-1}(-1)^{\large\bitcount\left(i\and \left(k-\frac{n}{2}\right)\right)}(A_\text{back})_i\\ &=\text{FWT}(A_\text{prev})_{k-\frac{n}{2}}-\text{FWT}(A_\text{back})_{k-\frac{n}{2}} \end{aligned}$$

$$\text{FWT}(A)_k=\begin{cases} \text{FWT}(A_\text{prev})_k+\text{FWT}(A_\text{back})_k, &0\le k\le \frac{n}{2}-1\\ \\ \text{FWT}(A_\text{prev})_{k-\frac{n}{2}}-\text{FWT}(A_\text{back})_{k-\frac{n}{2}}, &\frac{n}{2}\le k\le n-1 \end{cases}$$

进一步地

$$\text{FWT}(A)=\begin{cases} A,&n=1\\ \merge(\text{FWT}(A_\text{prev})\oplus \text{FWT}(A_\text{back}),\text{FWT}(A_\text{prev})\ominus \text{FWT}(A_\text{back})),&n>1 \end{cases}$$

对应地

$$\text{UFWT}(A)=\begin{cases} A,&n=1\\ \merge\left(\displaystyle\frac{\text{UFWT}(A_\text{prev})\oplus \text{UFWT}(A_\text{back})}{2},\frac{\text{UFWT}(A_\text{prev})\ominus \text{UFWT}(A_\text{back})}{2}\right),&n>1 \end{cases}$$

其中对序列的除法运算也是点对点除法。

模板 #

void FWT_XOR(vector<int>& a, int l, int r, bool inv) {
    if (l == r)
        return;
    int mid = (l + r) >> 1;
    FWT_XOR(a, l, mid, inv);
    FWT_XOR(a, mid + 1, r, inv);
    for (int i = 0; i <= mid - l; i ++) {
        int u = a[l + i], v = a[mid + 1 + i];
        a[l + i] = u + v;
        a[mid + 1 + i] = u - v;
    }
    if (inv) {
        for (int i = l; i <= r; i ++) {
            a[i] /= 2;
        }
    }
}
void FWT_XOR(vector<int>& a, bool inv) {
    int n = a.size();
    for (int d = 1; d < n; d <<= 1) {
        for (int i = 0; i < n; i += (d << 1)) {
            for (int j = 0; j < d; j ++) {
                int u = a[i + j], v = a[i + j + d];
                a[i + j] = u + v;
                a[i + j + d] = u - v;
            }
        }
    }
    if (inv) {
        for (int i = 0; i < n; i ++) {
            a[i] /= n;
        }
    }
}