树状数组是一个查询和修改复杂度都为log(n)的数据结构。主要用于查询任意两位置之间的所有元素之和,每次可以修改某一处元素的值。
以下是树状数组的存储方式(图片来源于互联网)
可以看出:
C[1]=A[1]
C[2]=A[1]+A[2]
C[3]=A[3]
C[4]=A[1]+A[2]+A[3]+A[4]
C[5]=A[5]
C[6]=A[5]+A[6]
C[7]=A[7]
C[8]=A[1]+A[2]+A[3]+A[4]+A[5]+A[6]+A[7]+A[8]
……
严格的定义为:C[n] = A[n-2^k+1] + …… + A[n],其中,k为n的二进制表示中最低位的1。即:C[n]为A[n]在内的前2^k个数之和,C[n]的辖域为2 ^ k。
则查询前x个数之和可以如下进行:首先确定C[x]的辖域为low_bit(x),则累加完这些元素之和C[x]后,继续求解1~C[x]-low_bit(x)这个区间的元素之和,x = 0时停止迭代。函数如下:
int getsum(int idx){
int sum = 0;
for(int i = idx; i > 0; i -= lowbit(i))
sum += c[i];
return sum;
}
其中low_bit求解如下:
int lowbit(int x){
return x & (-x);
}
结合负数的补码表示方式:取反加一,若最低位为第i位,则取反后第i位为0,第0~i - 1位均为1,加一后使得第i位为1,则再进行&操作,得到的就是2 ^ i。
修改时,对某个元素进行修改,则要沿着该元素所在路径一直向父节点修改,复杂度为O(logn)。关键是父节点的确定。其实树状数组可以理解为减少了一半节点的线段树,如下图所示(图片来源于http://www.cppblog.com/Ylemzy/articles/98322.html):
其中的空白节点,即是树状数组相对线段树节省的节点,空白节点可以理解为该节点的兄弟节点,这两个子树拥有同样数目的元素(辖域相同),则父节点的辖域为左右子树的元素个数之和,父节点的下标为c[x] + low_bit(x)。则修改的代码如下:
void update(int idx, int delta){
for(int i = idx; i <= n; i += low_bit(i))
c[i] += delta;
}
例题:hdu1166 http://acm.hdu.edu.cn/showproblem.php?pid=1166 代码如下:
#include <cstdio>
#include <cstring>
using namespace std;
#define N 50005
int n, a[N], c[N] ,sum[N];
char opt[10];
int low_bit(int x){
return x & (-x); //负数取反加1,最低位的1处被set
}
int get_sum(int k){
int res = 0;
for(int i = k; i > 0; i -= low_bit(i))
res += c[i];
return res;
}
void update(int idx, int delta){
for(int i = idx; i <= n; i += low_bit(i))
c[i] += delta;
}
int main(){
int tc, ca = 0;
scanf("%d", &tc);
while(tc --){
scanf("%d", &n);
sum[0] = 0;
for(int i = 1; i <= n; ++i){
scanf("%d", &a[i]);
sum[i] = sum[i - 1] + a[i]; //也可直接在此处一个一个update,不过需要将c清零,且复杂度为nlogn
}
for(int i = 1; i <= n; ++i)
c[i] = sum[i] - sum[i - low_bit(i)];
getchar();
printf("Case %d:\n", ++ca);
while(scanf("%s", opt) == 1 && strcmp(opt, "End")){
int op1, op2;
scanf("%d %d", &op1, &op2);
if(strcmp(opt, "Query") == 0){
printf("%d\n", get_sum(op2) - get_sum(op1 - 1));
}
else{
int delta = strcmp(opt, "Add") == 0 ? op2 : -op2;
update(op1, delta);
}
}
}
return 0;
}
除此之外,当元素值在某一范围之内时,可以用来求第k小/大的数。类似哈希的思想,a[i]中存储的是值为a[i]的元素个数,则sum(x)为不大于x的元素个数。
方法①:常规方法,使用二分查找第一个sum不小于k的sum(x),即sum(x - 1) <k, sum(x) >= k,但速度上比方法②略慢一些,因为每次二分时都调用了get_sum函数;
方法②:二进制增量法不断set x的每一位,使其向目标值ans靠近,从最高位开始,试探该位置一后不大于x的元素个数是否小于k,若小于k则该位可以置一,并记录当前范围内元素的个数(此处的处理节约了get_sum的计算开销);注意c[x]代表的是处于区间[x - low_bit(x) + 1, x]的元素个数。代码如下:
int getkth(int a, int k){
int ans = 0, cnt = -getsum(a);
for(int i = 20; i >= 0; i --){
ans += 1 << i;
if(ans >= MAXN || cnt + c[ans] >= k)
ans -= 1 << i;
else
cnt += c[ans]; //将新扩展的区间中元素个数累加到当前总个数中
}
return ans + 1;
}
上述代码中,在每轮的循环中,由于变量i是递减的,若点ans可行,则c[ans]的辖域为(1 << i),代表的是处于区间[ans - 1 << i + 1, ans]之间的元素个数,而ans' = ans - 1 << i,即是上一轮循环确定的范围,cnt为[1, ans']区间的元素个数。
例题:hdu2852 http://acm.hdu.edu.cn/showproblem.php?pid=2852 代码如下:
#include <cstdio>
#include <cstring>
using namespace std;
#define MAXN 100000
int c[MAXN];
int lowbit(int x){
return x & (-x);
}
void add(int idx, int delta){
for(int i = idx; i < MAXN; i += lowbit(i))
c[i] += delta;
}
int getsum(int idx){
int sum = 0;
for(int i = idx; i > 0; i -= lowbit(i))
sum += c[i];
return sum;
}
int getkth(int a, int k){
int ans = 0, cnt = -getsum(a);
for(int i = 20; i >= 0; i --){
ans += 1 << i;
if(ans >= MAXN || cnt + c[ans] >= k)
ans -= 1 << i;
else
cnt += c[ans];
}
return ans + 1;
}
int main(){
int m;
while(scanf("%d", &m) == 1){
int ops;
memset(c, 0, sizeof(c));
for(int i = 0; i < m; i ++){
scanf("%d", &ops);
if(ops != 2){
int val;
scanf("%d", &val);
if(ops == 0)
add(val, 1);
else{
if(getsum(val - 1) != getsum(val))
add(val, -1);
else
printf("No Elment!\n");
}
}
else{
int a, k;
scanf("%d %d", &a, &k);
int ans = getkth(a, k);
if(ans == MAXN)
printf("Not Find!\n");
else
printf("%d\n", ans);
}
}
}
}