首先我们得知道一个问题,那就是线段树得作用并不只是用来存储线段的,也可以存储点的值等等。对于静态的线段树,空间上需要的数组有:当前结点的数据值,左儿子编号,右儿子编号,至少这么三个数组。而在时间上虽然是NlogN的复杂度,但是系数很大。实现起来的时候编程复杂度大,空间复杂度大,时间效率也不是很理想。针对于这些缺点,树状数组便有了自己的优势。
下面从一个例题开始:
题目大意:
数列操作。给定一个初始值都为0的序列,动态地修改一些位置上的数字,加上一个数,减去一个数,或者乘上一个数,然后动态地提出问题,问题的形式是求出一段区间数字的和。
1.用线段树可以这样解:
若要维护的序列范围是0..5,先构造下面的一棵线段树:
可以看出,这棵树的构造用二分便可以实现,复杂度是2*N。
修改一个位置上数字的值,就是修改一个叶子结点的值,而当程序由叶子结点返回根节点的同时顺便修改掉路径上的结点的a数组的值。对于询问的回答,可以直接查找i..j范围内的值,遇到分叉时就兵分两路,最后在合起来。也可以先找出0..i-1的值和0..j的值,两个值减一减就行了。后者的实际操作次数比前者小一些。
这样修改与维护的复杂度是logN。询问的复杂度也是logN,对于M次询问,复杂度是MlogN。
->缺点:线段树的编程复杂度大,空间复杂度大,时间效率也不高。
2.树状数组的介绍。
树状数组是一个查询和修改复杂度都为log(n)的数据结构,可以很高效的进行区间统计。在思想上类似于线段树,比线段树节省空间,编程复杂度比线段树低,但适用范围比线段树小。
来观察这个图[数组下标是从1开始]:
令这棵树的结点编号为C1,C2...Cn。令每个结点的值为这棵树的值的总和,那么容易发现:
C1 = A1
C2 = A1 + A2
C3 = A3
C4 = A1 + A2 + A3 + A4
C5 = A5
C6 = A5 + A6
C7 = A7
C8 = A1 + A2 + A3 + A4 + A5 + A6 + A7 + A8
...
C16 = A1 + A2 + A3 + A4 + A5 + A6 + A7 + A8 + A9 + A10 + A11 + A12 + A13 + A14 + A15 + A16
这里有一个有趣的性质:
设节点编号为x,那么这个节点管辖的区间为2^k(其中k为x二进制末尾0的个数)个元素。因为这个区间最后一个元素必然为Ax。(管辖区间就是记录的区间个数,1,3,5,7,9为1。2,6为2。4为4。8为8。)
所以很明显:Cn = A[n – 2^k + 1] + ... + A[n]
算这个2^k有一个快捷的办法,定义一个函数如下即可:
int lowbit(int x){
return x & (x^(x–1));
}
想要查询一个SUM(n),可以依据如下算法即可(求a[0]~a[n]的和):
step1: 令sum = 0,转第二步;
step2: 假如n <= 0,算法结束,返回sum值,否则sum = sum + Cn,转第三步;
step3: 令 n = n – lowbit(n),转第二步。
可以看出,这个算法就是将这一个个区间的和全部加起来,为什么是效率是log(n)的呢?以下给出证明:
n = n – lowbit(n)这一步实际上等价于将n的二进制的最后一个1减去。而n的二进制里最多有log(n)个1,所以查询效率是log(n)的。
int sum(int i){
int ans = 0;
while(i > 0){
ans += ar[i];
i -= lowbit(i);
}
return ans;
}
那么修改UPDATE()呢,修改一个节点,必须修改其所有祖先,最坏情况下为修改第一个元素,最多有log(n)的祖先。
所以修改算法如下(给某个结点i加上x):
step1: 当i > n时,算法结束,否则转第二步;
step2: Ci = Ci + x, i = i + lowbit(i)转第一步。
i = i +lowbit(i)这个过程实际上也只是一个把末尾1补为0的过程。
void add(int i, int w){
while(i <= n){
ar[i] += w;
i += lowbit(i);
}
}
扩展:二维树状数组(要学会运用)。
对ar[1][1]增加的效果:
void add(int i, int j, int w){
int tmpj;
while(i <= row){
tmpj = j;
while(tmpj <= col){
ar[i][tmpj] += w;
tmpj += lowbit(tmpj);
}
i += lowbit(i);
}
}
int sum(int i, int j){
int tmpj, ans = 0;
while(i > 0){
tmpj = j;
while(tmpj > 0){
ans += ar[i][tmpj];
tmpj -= lowbit(tmpj);
}
i -= lowbit(i);
}
return ans;
}