Skip to content

树状数组

2021-04-05

问题

数组 A 中共 n 个元素,对其反复进行以下操作共 m 次:

  • 单点修改:将 A[x] 加上 k
  • 区间查询:查询 A[l]A[r] 的和。
cpp
int a[];

// 单点修改
void set(int id, int v) {
    a[id] = v;
}

// 区间查询
int ask(int l, int r) {
    int ans = 0;
    for(int i = l; i <= r; i ++)
        ans += a[i];
    return ans;
}
暴力算法(×树状数组(
单点修改O(1)O(logn)
区间查询O(n)O(logn)
m 次操作O(mn)O(mlogn)

构造

在原数组的上方构建树型结构,每个节点表示一段区间和:

  • C1=A1
  • C2=A1+A2
  • C3=A3
  • C4=A1+A2+A3+A4

我们不用立即知道每个 Ci 代表的是哪一段的和。

初始时,将所有节点值设置成 0,就可以直接进行下文的 单点修改区间查询

父节点

如何求 Ci 的父节点?

C 数组的下标转换成二进制数,观察该图。

  • C[0001] 的父节点为 C[0010]
  • C[0100] 的父节点为 C[1000]
  • C[0101] 的父节点为 C[0110]

总结出规律:

左邻节点

Ci 的「左邻节点」与 Ci 的左端相邻。例如 C5 的左邻节点为 C4

观察同一张图:

  • C[0011] 的左邻节点为 C[0010]
  • C[0101] 的左邻节点为 C[0100]
  • C[0110] 的左邻节点为 C[0100]

总结规律:

  • Ci 的左邻节点为 C[ilowbit(i)]

单点修改

A[x] 增加 kA[x] 的所有祖先都会跟着变动。以 A3 为例:

  • C3=A3
  • C4=A1+A2+A3+A4
  • C8=A1+A2+A3+A4+A5+A6+A7+A8

因此,修改 A[3] 的同时,C3,C4,C8 也需要加上 k

对于给定的 x,从 Cx 开始逐层访问 父节点,并给其值加上 k

时间复杂度为 O(logn)

cpp
int lowbit(int x) {
    return x & -x;
}

// 将 a[p] 增加 k
void add(int p, int k) {
    for(; p <= n; p += lowbit(p)) c[p] += k;
}

区间查询

参考 前缀和 思想:设 sum[x]=A[1]+A[2]++A[x],则:

i=lrA[i]=sum[r]sum[l1]

自此,问题转换为求任意的 sum[x]

sum[7] 为例:

  • C4=A1+A2+A3+A4
  • C6=A5+A6
  • C7=A7

因而 sum[7]=C4+C6+C7

查询 sum[x] 时,从 Cx 开始依次遍历 左邻节点,累加其值。

时间复杂度为 O(logn)

cpp
// 查询 A[1...p] 的和
int ask(int p) {
    int sum = 0;
    for (; p; p -= lowbit(p)) sum += c[p];
    return sum;
}

// 查询 A[l] ... A[r] 的和
int get(int l, int r) {
    return ask(r) - ask(l - 1);
}

模板

cpp
int lowbit(int x) {
    return x & -x;
}

void add(int p, int k) {
    for (; p <= n; p += lowbit(p)) c[p] += k;
}

int ask(int p) {
    int sum = 0;
    for (; p; p -= lowbit(p)) sum += c[p];
    return sum;
}

int get(int l, int r) {
    return ask(r) - ask(l - 1);
}

拓展

一般的树状数组只能实现「单点修改区间查询」。

通过 差分 技巧,树状数组也可以实现「区间修改单点查询」。

  1. 在数组 A 的差分数组 f 上建立树状数组,其中 f[i]=A[i]A[i1]
  2. A[l]A[r] 都加上 v 时,f[l] 增加了 vf[r+1] 减少了 v

由于 A[i]=f[1]+f[2]++f[i],所以可以通过区间查询得到 A[i] 的值。

cpp
// 将 A[l] ... A[r] 加上 v
void seg_add(int l, int r, int v) {
    add(l, v);
    add(r + 1, -v);
}

// 查询 A[p] 的值
int ask(int p) {
    int sum = 0;
    for (; p; p -= lowbit(p)) sum += c[p];
    return sum;
}

int main() {
    
    /* 输入部分省略 */
    
    for (int i = 1; i <= n; i ++) {
        // 在差分数组上建立树状数组
        add(i, a[i] - a[i - 1]);
    }
}