以前都是用线段树做的,感觉比较复杂。今天做了一下午的逆序数,都没做出来好伤心,结果学长而我说那是一道简单的树状数组的题目。今天继续着树状数组,选了到简单而求熟悉的排兵布阵。
1、概述
树状数组(binary indexed tree),是一种设计新颖的数组结构,它能够高效地获取数组中连续n个数的和。概括说,树状数组通常用于解决以下问题:数组{a}中的元素可能不断地被修改,怎样才能快速地获取连续几个数的和?
2、树状数组基本操作
传统数组(共n个元素)的元素修改和连续元素求和的复杂度分别为O(1)和O(n)。树状数组通过将线性结构转换成伪树状结构(线性结构只能逐个扫描元素,而树状结构可以实现跳跃式扫描),使得修改和求和复杂度均为O(lgn),大大提高了整体效率。
给定序列(数列)A,我们设一个数组C满足
C[i] = A[i–2^k+ 1] + … + A[i]
其中,k为i在二进制下末尾0的个数,i从1开始算!
则我们称C为树状数组。
下面的问题是,给定i,如何求2^k?
答案很简单:2^k=i&(i^(i-1)) ,也就是i&(-i)
下面进行解释:
以i=6为例(注意:a_x表示数字a是x进制表示形式):
(i)_10 = (0110)_2
(i-1)_10=(0101)_2
i xor (i-1) =(0011)_2
i and (i xor (i-1)) =(0010)_2
2^k = 2
C[6] = C[6-2+1]+…+A[6]=A[5]+A[6]
数组C的具体含义如下图所示:
当我们修改A[i]的值时,可以从C[i]往根节点一路上溯,调整这条路上的所有C[]即可,这个操作的复杂度在最坏情况下就是树的高度即O(logn)。另外,对于求数列的前n项和,只需找到n以前的所有最大子树,把其根节点的C加起来即可。不难发现,这些子树的数目是n在二进制时1的个数,或者说是把n展开成2的幂方和时的项数,因此,求和操作的复杂度也是O(logn)。
树状数组能快速求任意区间的和:A[i] + A[i+1] + … + A[j],设sum(k) = A[1]+A[2]+…+A[k],则A[i] + A[i+1] + … + A[j] = sum(j)-sum(i-1)。
下面给出树状数组的C语言实现:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
|
//求2^k
int lowbit( int t)
{
return t & ( t ^ ( t - 1 ) );
}
//求前n项和
int sum( int end)
{
int sum = 0;
while (end > 0)
{
sum += in[end];
end -= lowbit(end);
}
return sum;
}
//增加某个元素的大小
void plus( int pos, int num)
{
while (pos <= n)
{
in[pos] += num;
pos += lowbit(pos);
}
}
|
3、扩展——二维树状数组
一维树状数组很容易扩展到二维,二维树状数组如下所示:
C[x][y] = sum(A[i][j])
其中,x-lowbit[x]+1 <= i<=x且y-lowbit[y]+1 <= j <=y
4、应用
(1) 一维树状数组:
参见:http://hi.baidu.com/lilu03555/blog/item/4118f04429739580b3b7dc74.html
(2) 二维树状数组:
一个由数字构成的大矩阵,能进行两种操作
1) 对矩阵里的某个数加上一个整数(可正可负)
2) 查询某个子矩阵里所有数字的和
要求对每次查询,输出结果
5、总结
树状数组最初是在设计压缩算法时发现的(见参考资料1),现在也会经常用语维护子序列和。它与线段树(具体见:数据结构之线段树)比较在思想上类似,比线段树节省空间且编程复杂度低,但使用范围比线段树小(如查询每个区间最小值问题)。
代码:
#include<iostream>
#include<string>
using namespace std;
int ss[50005];
int n;
int lowbit(int a)
{
return a&(-a);
}
void update(int pos,int num)
{
while(pos<=n)
{
ss[pos]+=num;
pos+=lowbit(pos);
}
}
int sum(int a)
{
int total=0;
while(a>0)
{
total+=ss[a];
a-=lowbit(a);
}
return total;
}
int main()
{
int t,i,j,x;
scanf("%d",&t);
for(i=1;i<=t;i++)
{
memset(ss,0,sizeof(ss));
printf("Case %d:\n",i);
scanf("%d",&n);
for(j=1;j<=n;j++)
{
scanf("%d",&x);
update(j,x);
}
string s;
while(cin>>s)
{
int a,b;
if(s[0]=='Q')
{
scanf("%d%d",&a,&b);
int x=sum(b)-sum(a-1);
printf("%d\n",x);
}
if(s[0]=='A')
{
scanf("%d%d",&a,&b);
update(a,b);
}
if(s[0]=='S')
{
scanf("%d%d",&a,&b);
update(a,-b);
}
if(s[0]=='E') break;
}
}
return 0;
}