Skip to content

树状数组

2021-04-05

问题

数组 A 中共 n 个元素,对其反复进行以下操作共 m 次:

  • 单点修改:将 A[x] 加上 k
  • 区间查询:查询 A[l]A[r] 的和。
cpp
int a[];

// 单点修改
void set(int id, int v) {
    a[id] = v;
}

// 区间查询
int ask(int l, int r) {
    int ans = 0;
    for(int i = l; i <= r; i ++)
        ans += a[i];
    return ans;
}
暴力算法(×树状数组(
单点修改O(1)O(logn)
区间查询O(n)O(logn)
m 次操作O(mn)O(mlogn)

构造

在原数组的上方构建树型结构,每个节点表示一段区间和:

  • C1=00A1
  • C2=00A1+00A2
  • C3=00A3
  • C4=00A1+00A2+00A3+00A4

我们不用立即知道每个 Ci 代表的是哪一段的和。

初始时,将所有节点值设置成 0,就可以直接进行下文的 单点修改区间查询

父节点

如何求 Ci 的父节点?

C 数组的下标转换成二进制数,观察该图。

  • C[0001] 的父节点为 C[0010]
  • C[0100] 的父节点为 C[1000]
  • C[0101] 的父节点为 C[0110]

总结出规律:

左邻节点

Ci 的「左邻节点」与 Ci 的左端相邻。例如 C5 的左邻节点为 C4

观察同一张图:

  • C[0011] 的左邻节点为 C[0010]
  • C[0101] 的左邻节点为 C[0100]
  • C[0110] 的左邻节点为 C[0100]

总结规律:

  • Ci 的左邻节点为 C[i00lowbit(i)]

单点修改

A[x] 增加 kA[x] 的所有祖先都会跟着变动。以 A3 为例:

  • C3=00A3
  • C4=00A1+00A2+00A3+00A4
  • C8=00A1+00A2+00A3+00A4+00A5+00A6+00A7+00A8

因此,修改 A[3] 的同时,C3,00C4,00C8 也需要加上 k

对于给定的 x,从 Cx 开始逐层访问 父节点,并给其值加上 k

时间复杂度为 O(logn)

cpp
int lowbit(int x) {
    return x & -x;
}

// 将 a[p] 增加 k
void add(int p, int k) {
    for(; p <= n; p += lowbit(p)) c[p] += k;
}

区间查询

参考 前缀和 思想:设 sum[x]=00A[1]+00A[2]+00+00A[x],则:

i=lrA[i]=sum[r]sum[l1]

自此,问题转换为求任意的 sum[x]

sum[7] 为例:

  • C4=00A1+00A2+00A3+00A4
  • C6=00A5+00A6
  • C7=00A7

因而 sum[7]=00C4+00C6+00C7

查询 sum[x] 时,从 Cx 开始依次遍历 左邻节点,累加其值。

时间复杂度为 O(logn)

cpp
// 查询 A[1...p] 的和
int ask(int p) {
    int sum = 0;
    for (; p; p -= lowbit(p)) sum += c[p];
    return sum;
}

// 查询 A[l] ... A[r] 的和
int get(int l, int r) {
    return ask(r) - ask(l - 1);
}

模板

cpp
int lowbit(int x) {
    return x & -x;
}

void add(int p, int k) {
    for (; p <= n; p += lowbit(p)) c[p] += k;
}

int ask(int p) {
    int sum = 0;
    for (; p; p -= lowbit(p)) sum += c[p];
    return sum;
}

int get(int l, int r) {
    return ask(r) - ask(l - 1);
}

拓展

一般的树状数组只能实现「单点修改区间查询」。

通过 差分 技巧,树状数组也可以实现「区间修改单点查询」。

  1. 在数组 A 的差分数组 f 上建立树状数组,其中 f[i]=00A[i]00A[i001]
  2. A[l]A[r] 都加上 v 时,f[l] 增加了 vf[r+001] 减少了 v

由于 A[i]=00f[1]+00f[2]+00+00f[i],所以可以通过区间查询得到 A[i] 的值。

cpp
// 将 A[l] ... A[r] 加上 v
void seg_add(int l, int r, int v) {
    add(l, v);
    add(r + 1, -v);
}

// 查询 A[p] 的值
int ask(int p) {
    int sum = 0;
    for (; p; p -= lowbit(p)) sum += c[p];
    return sum;
}

int main() {
    
    /* 输入部分省略 */
    
    for (int i = 1; i <= n; i ++) {
        // 在差分数组上建立树状数组
        add(i, a[i] - a[i - 1]);
    }
}