和其他数据结构不同,树状数组非常好写代码,并且非常好调试,但是它的原理很难理解,因为涉及到二进制,所以很抽象(因此也希望读者能够仔细阅读,仔细思考,有充足的耐心)。

这类题的难点在于读到题目时想到需要用树状数组。

树状数组能解决的问题:

  1. 快速求前缀和 $O(\log n)$​

    传统可以直接维护一个前缀和数组,这样查询区间前缀和的时候的时间复杂度为$O(1)$(但此时如果要修改某个数,就要重新维护一次前缀和数组,时间复杂度为$O(n)$)

  2. 修改某一个数 $O(\log n)$​​

    传统可以直接用一个数组,修改数组中的一个数的时间复杂度为$O(1)$(但此时如果要得到某个区间的前缀和的时间复杂度是为$O(n)$)

它属于是兼顾两种操作。当同时要求修改某一个数和查询区间前缀和的时候就需要用树状数组。

基本原理

在引出概念前一定要复习一下lowbit()

lowbit(n) 定义为非负整数 n 在二进制表示下 “最低位的1及其后面的所有的0” 的二进制构成的数值。

比如当 n = 5 的时候,5 的二进制是 :0101 , 所以有:lowbit ( 5 ) = 1

比如当 n = 10 的时候,10 的二进制是 :1010,所以有: lowbit ( 10 ) = 2

因为二进制中求负数实际上是 取反加一

假设 n > 0 ,设 n 的二进制表示中,第 k 位为 1 ,第 0 至第 k-1 位都为 0

现在我们对 n 的二进制进行取反操作,可以得到,~n 的二进制表示中,第 k 位为 0 ,第 0 至第 k-1 位都为 1,然后我们再将 ~n 进行加 1 操作,可以得到一个结果,就是 ~n+1 的第 k+1 位至其最高位都为 n 的二进制表示中相反的数字,然后我们再将 ~n+1n 进行与运算,就可以得到我们想要的结果了。又因为 ~n=-1-n ,所以 -n = ~n+1,有:

lowbit( x ) = n & ( ~n + 1 ) = n & ( -n )

C++代码实现:

1
2
3
int lowbit(int x) {
return x & -x;
}
1
#define lowbit(x) ((x)&(-x))

引出直观概念

基于二进制的想法(任意一个数可以由二进制表示,通过它的第k位取1(或取0)进行一些操作的想法)。很多算法其实都用到了这个思想,例如快速幂,倍增法求LCA。

假设

其中

这样可以分成 $k$ 个区间:

假设已经将这$k$个区间的和预处理出来了,那么求下标为1~x的和就需要$O(k)$,因为$k \le \log x$,因此时间复杂度为$O(\log x)$​。

看一下「性质」

从上到下令这些区间为1号区间,2号区间…那么1号区间中有$2^{i_1}$个数,2号区间有$2^{i_2}$个数,…k号区间有$2^{i_k}$​个数。

令区间表示为$(L, R]$,那么区间个数一定为 $R$的最后一位1所在位置减一 的次幂(最后一位1可以用lowbit(R)来求(时间复杂度为$O(1)$),个数为$2^{lowbit(R)-1}$)。那么区间就有可以表达为$(R-2^{lowbit(R)-1}, R]$

这个时候一个区间已经用一个数$R$​表示了,因此就可以用一个数组来存这个区间的总和了,例如:C[R]表示以R作为结尾的区间的总和。

此时有了c[x]的定义了,其表示$\sum\limits_{i=x-2^{lowbit(x)-1}}^R A_i$。其中$A_i$表示数组$A$的第$i$​位。

这里画出c[x]的含义:

image-20240305194004416

其中令 每个垂直向下的第一个节点 为 节点的子节点。(c[16]的子节点有c[8],c[12],c[14],A[16])

通过父节点找子节点

考虑所有c[x]之间的关系,我们可以发现:

  • 第一行:

    c[16]=A[16]+c[15]+c[14]+c[12]+c[8]

  • 第二行:

    c[8]=A[8]+c[7]+c[6]+c[4]

  • 第三行:

    c[4]=A[4]+c[3]+c[2]

    c[12]=A[12]+c[11]+c[10]

  • 第四行:

    c[2]=A[2]+c[1]

    c[6]=A[6]+c[5]

  • 第五行:

    c[1]=A[1]

从这里就可以看出树的感觉了,其中x为子树的根节点,假设x的二进制表示为:$….100…00$(其中的$1$为最后一个$1$,假设后面有$k$个$0$),根据前面的定义,c[x]的区间有$2^k$个元素。

一定是要加A[x]的,因此需要考虑如何求$[x-2^k+1, x-1]$​的和。

此时,$x-1$的二进制表示一定为$11….1$(包含$k$个$1$)。根据前面的定义,一定可以分为$k$个区间,每个区间表示为:

简单用二进制表示一下当 $x=8{10}$ 时, $x-1=7{10}=111_2$ ,区间为:

所以此时求c[8]就是求这个三个区间的和加上A[8],即c[7]+c[6]+c[4]同时此时的8,7,6,4就是8的子节点。

感觉这一段像是废话,但实际上第一次表达区间是为了引出树状数组的概念,第二次引出区间,是与二进制建立连接,最终求和是和二进制有关的。

于是得到:c[x]=A[x]+c[x-1]+c[(x-1)-lowbit(x-1)]+...。同时这也是通过父节点找子节点的方法

通过子节点找父节点

这个是比较重要的,因为它是修改操作的关键。

注意,这里找的父节点不是祖宗节点。例如7的父节点为8。

根据上面找出的区间可以知道 $111_2, \, 110_2, \, 100_2$ 是 $1000_2$ 的子节点,可以发现若子节点为 $x$ ,则其父节点为 $x + lowbit(x)$ 。

然后就是递归的方式向上修改。并且向上能够影响到的区间最多就是 $\log x$ ​。

修改操作模版

假设需要对A[i]+=c

实现代码:

1
2
3
void add(int x, int c) { // 给下标为x的元素加上c
for (int i = x; i <= n; i += lowbit(i)) tr[i] += c;
}

解释:最开始令i等于节点x,然后让其对应的值加c,然后i += lowbit(i)之后的i其实就是x的父节点。可以归纳出,下一次的i就是上一次i的父节点。因此就是对每个节点维护的和加c就行了。终止条件为i<=n应该很好理解。

查询操作模版

假设查询区间1~x的和:

1
2
3
4
5
int sum(int x) { // 求下标1~x对应的元素的前缀和
int res = 0; // 前缀和
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}

查询 $[L, R]$ 的和实际上就是查询c[R]-c[L-1]

例题

「题目描述」:

在完成了分配任务之后,西部 314 来到了楼兰古城的西部。

相传很久以前这片土地上(比楼兰古城还早)生活着两个部落,一个部落崇拜尖刀 $(V)$ ,一个部落崇拜铁锹 $(∧)$ ,他们分别用 $V$ 和 $∧$ 的形状来代表各自部落的图腾。

西部 314 在楼兰古城的下面发现了一幅巨大的壁画,壁画上被标记出了 $n$ 个点,经测量发现这$n $个点的水平位置和竖直位置是两两不同的。

西部 314 认为这幅壁画所包含的信息与这$n$个点的相对位置有关,因此不妨设坐标分别为 $(1,y_1),(2,y_2),\cdots,(n,y_n)$ ,其中 $y_1$ ~ $y_n$ 是 $1$ 到 $n$ 的一个排列。

西部 314 打算研究这幅壁画中包含着多少个图腾。

如果三个点 $(i,y_i), (j, y_j ) , ( k , y_k )$ 满足 $1 \le i < j < k \le n$且 $y_i > y_j , y_j < y_k$,则称这三个点构成$V$图腾;

如三个点 $( i , y_i ) , ( j , y_j ) , ( k , y_k )$ 满足 $1 \le i < j < k \le n$,则称这三个点构成$∧$图腾;

西部 314 想知道,这 $n$ 个点中两个部落图腾的数目。

因此,你需要编写一个程序来求出 $V$ 的个数和 $∧$ 的个数。

「输入格式」:

第一行一个数 $n$ 。

第二行是 $n$ 个数,分别代表 $y_1 , y_2 , \cdots , y_n$

「输出格式」:

两个数,中间用空格隔开,依次为 $V$ 的个数和 $∧$ 的个数。

「数据范围」:

对于所有数据,$n \le 200000$,且输出答案不会超过int64。
$y_1$ ~ $y_n$ 是 $1$ 到 $n$ ​的一个排列。

「样例输入」:

1
2
5
1 5 3 2 4

「样例输出」:

1
3 4

「提示」:

根据样例画图:

image-20240305210831253

根据右边的图例,相同颜色的线条构成一个 $V$ 或 $∧$ 。

「思路」:

注意思考如何想到了可以用树状数组的,这是这类题的难点。

对于数列中的每一个数字(假设是 $k$ ),计算它前面比它大的数字的数量 $prevk$,后面比它大的数字的数量 $postv_k$,即可求出V的数量为 $prev_k \times postv_k$,注意,还要对每一个数进行累加,即 $\sum\limits{i=1}^n prev_i \times postv_i$。同理,也可以利用前面比它小的数字的数量 $pren_k$ 和后面比它大的数字的数量 $postn_k$ ,求出 $∧$ 的数量。

所以先从左往右求一遍,维护一个 $prev[k]$ ,表示 $1$ ~ $k-1$ 区间中比A[k]小的数的个数。再从右往左求一遍,维护一个 $postv[k]$ ,表示 $k+1$ ~ $n$ ​中比A[k]小的数的个数

这里查询维护 $prev$ 和 $postv$ 有两个操作:

  1. 遍历到 $y_k$ ,往树状数组中插入 $y_k$
  2. 求区间的和

所以想到了可以用树状数组来维护。

这里为了加速,用great[k]表示k左边比 $y_k$ 大的数的个数(区间 $(y_k, n]$ 的前缀和),low[k]表示k左边比 $y_k$ 小的数的个数(区间$(1, y_k-1]$​的前缀和)。第一次从左往右遍历时构造great[k]low[k]。但是第二次从右往左遍历时,直接相乘,不再构造另外两个数组。(上面的解释只是为了更加直观,因此用的是四个数组)。

注意:初始时可以认为树状数组中下标为k的元素的值为1(因为是维护个数)。

「实现代码」:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

typedef long long ll;

const int N = 200010;

int n, a[N], tr[N]; // a为原数组,tr为树状数组
int great[N], low[N];

int lowbit(int x) {
return x & -x;
}

void add(int x, int c) {
for (int i = x; i <= n; i += lowbit(i)) tr[i] += c;
}

int sum(int x) {
int res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}

int main() {
cin >> n;
for (int i = 1; i <= n; ++i) cin >> a[i];

// 从左往右遍历一次
for (int i = 1; i <= n; ++i) {
int y = a[i];
great[i] = sum(n) - sum(y);
low[i] = sum(y-1);
add(y, 1);
}

memset(tr, 0, sizeof(tr)); // 这里要维护新的树状数组了
ll res1 = 0, res2 = 0; // res1表示V的数量,res2表示^的数量
for (int i = n; i; --i) {
int y = a[i];
res1 += great[i] * (ll)(sum(n)-sum(y));
res2 += low[i] * (ll)(sum(y-1));
add(y, 1);
}

cout << res1 << " " << res2 << endl;

return 0;
}

扩展

树状数组+差分思想

「题目描述」:

给定长度为 $N$ 的数列 $A$ ,然后输入 $M$ 行操作指令。

第一类指令形如”C l r d”,表示把数列中第l~r个数都相加d。

第二类指令心如”Q x”,表示询问数列中第x个数的值。

对于每个询问,输出一个整数表示答案。

「输入格式」:

第一行包含两个整数 $N$ 和 $M$ 。

第二行包含 $N$ 个整数 $A[i]$

接下来 $M$ 条指令,每条指令的格式如题目描述所示。

「输出格式」:

对于每个询问,输出一个整数表示答案。每个答案占一行。

「数据范围」:

$1 \le N, M \le 10^5$,

$|d_| \le 10000$​,

$|A[i]| \le 10^9$

「样例输入」:

1
2
3
4
5
6
7
10 5
1 2 3 4 5 6 7 8 9 10
Q 4
Q 1
Q 2
C 1 6 3
Q 2

「样例输出」:

1
2
3
4
4
1
2
5

「思路」:

这道题感觉和树状数组反着来的。树状数组:单点加,求区间前缀和。这道题:区间加,求单点值。

看到区间加(并且加的值相同),应该就能想到差分(例如给区间[l, r]加2:d[l]+=2, d[r+1] -= 2)。而求点a[x]的元素值,其实就是求d[1]+d[2]+…+d[x],即x的前缀和。

那么到这里其实就思路清晰了。我们只需要读入原数组,根据原数组构成差分数组,然后在树状数组中维护差分数组的值,每一次修改区间,就在树状数组中修改两个值即可。最终要求第x个数的值就是求树状数组中1~x个数的前缀和。

如果不会差分可以参看文章:前缀和与差分

「实现代码」:

注意:查询数据读入时的操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <iostream>

using namespace std;

typedef long long ll;

const int N = 1e5+5;
int n, m, a[N];
ll tr[N];

int lowbit(int x) {
return x & -x;
}

void add(int x, int c) {
for (int i = x; i <= n; i += lowbit(i)) tr[i] += c;
}

ll sum(int x) {
ll res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}

int main() {
cin >> n >> m;

for (int i = 1; i <= n; ++i) cin >> a[i];

for (int i = 1; i <= n; ++i)
add(i, a[i] - a[i-1]);

while (m -- ) {
char op[2];
int l, r, x;
cin >> op >> l;
if (*op == 'C') {
cin >> r >> x;
add(l, x), add(r+1, -x);
} else {
ll res = sum(l);
cout << res << endl;
}
}

return 0;
}

树状数组+差分思想+公式

「题目描述」:

给定长度为 $N$ 的数列 $A$ ,然后输入 $M$ 行操作指令。

第一类指令形如”C l r d”,表示把数列中第l~r个数都相加d。

第二类指令心如”Q l r”,表示询问数列中第l~r个数的值。

对于每个询问,输出一个整数表示答案。

「输入格式」:

第一行包含两个整数 $N$ 和 $M$。

第二行包含 $N$ 个整数 $A[i]$

接下来 $M$ 条指令,每条指令的格式如题目描述所示。

「输出格式」:

对于每个询问,输出一个整数表示答案。每个答案占一行。

「数据范围」:

$1 \le N, M \le 10^5$,

$|d_| \le 10000$​,

$|A[i]| \le 10^9$

「样例输入」:

1
2
3
4
5
6
7
10 5
1 2 3 4 5 6 7 8 9 10
Q 4 4
Q 1 10
Q 2 4
C 3 6 3
Q 2 4

「样例输出」:

1
2
3
4
4
55
9
15

「思路」:

这道题变成了区间加,区间求和。

区间加一定还是用差分来做。

这里的区间求和其实就是求出1~l-1的和,1~r的和,然后后者减去前者。这里要根据差分数组d求原数组第x个数的值的公式为:$ax=\sum\limits{i=1}^x d_i$。

于是得出根据差分数组求前x数的前缀和的公式。

于是就是要求下图中黄色框内数的和:

image-20240306114906550

若补齐为如图形式:

image-20240306115101168

可以求出此时黄色框内的元素之和为:$x \sum\limits{i=1}^x d_i$。$\sum\limits{i=1}^x d_i$其实就是上一题求多前缀和,可以通过树状数组直接得到。这里再加上一行,变为:$(x+1) \sum\limits_{i=1}^x d_i$。

这样可以发现红色补足部分刚好为:$d1+2d_2+3d_3 + \cdots + x*d_x = \sum\limits{i=1}^x i*d_i$,也变成了一个关于d的前缀和。

这样就可以得到最终结果为:$(x+1)\sum\limits{i=1}^x d_i - \sum\limits{i=1}^x id_i$。

所以这道题需要维护两个前缀和:$\sum\limits{i=1}^x d_i$ 和 $\sum\limits{i=1}^x i*d_i$ 。因此用两个树状数组来维护这两个前缀和。

「实现代码」:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <iostream>

using namespace std;

typedef long long ll;

const int N = 1e5+5;
int n, m, a[N];
ll tr1[N], tr2[N]; // tr1维护d[i],tr2维护i*d[i]

int lowbit(int x) {
return x & -x;
}

void add(int x, ll c) {
for (int i = x; i <= n; i += lowbit(i))
tr1[i] += c, tr2[i] += x*c; // 注意不要写成i*c了
}

ll sum(int x) {
ll res1 = 0, res2 = 0;
for (int i = x; i; i -= lowbit(i))
res1 += tr1[i], res2 += tr2[i];

ll res = (x+1)*res1 - res2;
return res;
}

int main() {
cin >> n >> m;
for (int i = 1; i <= n; ++i) cin >> a[i];

for (int i = 1; i <= n; ++i)
add(i, (ll)(a[i]-a[i-1]));

while (m -- ) {
char op[2];
int l, r, d;
cin >> op >> l >> r;
if (*op == 'C') {
cin >> d;
add(l, d), add(r+1, -d);
} else {
ll res = sum(r) - sum(l-1);
cout << res << endl;
}
}

return 0;
}

应用题

迷一样的牛

「题目描述」:

有 $n$ 头奶牛,已知它们的身高为 $1$ ~ $n$ 且各不相同,但不知道每头奶牛的具体身高。

现在这 $n$ 头奶牛站成一列,已知第 $i$ 头牛前面有 $A_i$ 头牛比它低,求每头奶牛的身高。

「输入格式」:

第 $1$ 行:输入整数 $n$ 。

第 $2 \dots n$ 行:每行输入一个整数 $A_i$ ,第i行表示第i头牛前面有 $A_i$ 头牛比它低。
(注意:因为第1头牛前面没有牛,所以并没有将它列出)

「输出格式」:

输出包含 $n$ 行,每行输出一个整数表示牛的身高。

第 $i$ 行输出第 $i$ ​头牛的身高。

「数据范围」:

$1 \le n \le 10^5$

「样例输入」:

1
2
3
4
5
5
1
2
1
0

「样例输出」:

1
2
3
4
5
2
4
5
3
1

「思路」:

因为这道题的牛的身高是1~n的,并且互不相同,因此每个高度只能用一次。

这道题要从最后一头牛往前看。对于最后一头牛来说,如果前面有 $a_n$ 头牛比他矮,那么它的身高一定是 $a_n + 1$ 。那么剩下的身高就是: $1$~$a_n$,$(a_n+2)$ ~ $n$ 了。

对于第 $i$ 头牛如果它前面有 $a_i$ 头牛比他矮,那么它一定在剩下的身高中排 $a_i+1$ 高(注意这里的身高值不一定是 $a_i+ 1$ ​,它只是代表在剩下的数中的排名)。

于是分析出从后往前遍历到每一头的时候要进行两种操作:

  1. 从剩余的牛中找出第 $a_i + 1$ 小的数,作为第 $i$ 头牛的身高
  2. 将找到的数删除

实现这两种操作其实用平衡树会更好,但是这里是练习树状数组,就用树状数组。

用树状数组,就令树状数组的初始 $A_0 = A_1 = \cdots = A_i = \cdots = A_n = 1$ ​,表示该数还可以用1次。

然后找到第 $a_i + 1$ 小的数的时候的操作其实就是求前 $x$ 个数的前缀和,判断 $x$ 取多少时前缀和等于 $a_i+1$​。

删除找到的数其实就是 $A_x -= 1$​。

但是,这里的 $x$ 的取值大概率不是 $a_i + 1$,这里要加速找到第一个 $a_i + 1$ 小的数就可以考虑二分查找了。【因为在 第 $x$ 个数前的数 的前缀和一定小于 前 $x$ 数 的前缀和,第 $x$ ​个数后的数 的前缀和一定大于 前x个数 的前缀和(满足单调性)。】

这里还有一个注意点:考虑:假设当第2位和第4位被选择了,剩下的位置为110011。那么想要找第2高的数的时候,发现前2个,前3个,前4个数的前缀和都是2。我们最终要得到的结果为2,3、4显然不是正确答案,因此我们要找的就是最小的满足条件的位置,即二分查找下界。

「实现代码」:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <iostream>
#include <cstring>

using namespace std;

const int N = 1e5+5;

int n, a[N], tr[N], ans[N];

int lowbit(int x) {
return x & -x;
}

void add(int x, int c) {
for (int i = x; i <= n; i += lowbit(i)) tr[i] += c;
}

int sum(int x) {
int res = 0;
for (int i = x; i ; i -= lowbit(i)) res += tr[i];
return res;
}

int find(int x) {
int l = 1, r = n+1, mid;
while (l < r) { // 左闭右开
mid = (l + r) >> 1;
if (sum(mid) < x) l = mid + 1;
else r = mid;
}
return r;
}

int main() {
cin >> n;
for (int i = 1; i <= n; ++i) add(i, 1);
// 其实可以这样初始化加速:
// for (int i = 1; i <= n; ++i) tr[i] = lowbit(i);

for (int i = 2; i <= n; ++i) cin >> a[i];

for (int i = n; i >= 1; --i) {
int res = find(a[i]+1);
ans[i] = res;
add(res, -1);
}

for (int i = 1; i <= n; ++i) cout << ans[i] << endl;

return 0;
}

最后解释一下为什么初始化树状数组时可以用,这样初始化的时间复杂度为 $O(n)$ (上面那种初始化的时间复杂度为 $O(n*\log n)$ 稍慢一点):

1
for (int i = 1; i <= n; ++i) tr[i] = lowbit(i);

因为树状数组的A[i]都是1,因此树状数组中第i个数对应的前缀和其实就是区间长度,前面讲过了,区间长度其实就是lowbit(i)