[转载]树状数组学习!
http://www.cnblogs.com/zichi/p/4806998.html
最近在学习位运算,正好把树状数组总结下,也算是能正式给data structure
建个分类。
那么,树状数组到底有什么用呢?诚然,一样没什么卵用的东西我们学它干嘛。
下面举个树状数组的经典应用:区间求和。
假设我们有如下数组(数组元素从 index=1
开始):
var a = [X, 1, 2, 3, 4, 5, 6, 7, 8, 9];
我们设定两种操作,modify(index, x)
表示将 a[index]
元素加上x, query(n, m)
表示求解 a[n] ~ a[m]
之间元素的和。如果不了解树状数组(当然假设更不了解线段树等其他数据结构),你可能会很容易地写下如下代码:
-
var a = [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9];
-
-
function query(n, m) {
-
var sum =
0;
-
for (
var i = n; i <= m; i++)
-
sum += a[i];
-
return sum;
-
}
-
-
function modify(index, x) {
-
a[index] += x;
-
}
Ok,复杂度为O(1)的删改和复杂度为O(n)的查询。如果数据量很大,这样反复的查询是相当耗时的。我们退一步想,如果只有 query(n, m)
这个操作,很容易想到用sum数组预处理前n项的和,然后用 sum[m] - sum[n-1]
获得答案。但是如果要修改 a[index]
的值,因为该项影响所有index之后的sum数组元素,所以如果这样做复杂度变为O(1)的查询和O(n)的删改,并没有什么卵用。
但是这个思路是美好的,我们可以用一个sum数组保存一段特定的区间段的值。假设我们有 a[1] ~ a[9]
9个元素,我们根据一个特定的规则:
-
sum[
1] = a[
1];
-
sum[
2] = a[
1] + a[
2];
-
sum[
3] = a[
3];
-
sum[
4] = a[
1] + a[
2] + a[
3] + a[
4];
-
sum[
5] = a[
5];
-
sum[
6] = a[
5] + a[
6];
-
sum[
7] = a[
7];
-
sum[
8] = a[
1] + a[
2] + a[
3] + a[
4] + a[
5] + a[
6] + a[
7] + a[
8];
-
sum[
9] = a[
9];
如果要求 a[1] ~ a[9]
的和,即为 sum[9] + sum[8]
,如果要求 a[1] ~ a[7]
的和,即为 sum[7] + sum[6] + sum[4]
,如果要改变 a[1]
的值,改变sum数组中和 a[1]
有关的项即可(即 sum[1]
sum[2]
sum[4]
sum[8]
)。 这就是树状数组!实现了O(logn)的查询和删改。但是如何将a数组和sum数组联系起来?
来观察这个图:
令这棵树的结点编号为C1,C2...Cn。令每个结点的值为这棵树的值的总和,那么容易发现(如上所说):
-
C1 = A1
-
C2 = A1 + A2
-
C3 = A3
-
C4 = A1 + A2 + A3 + A4
-
C5 = A5
-
C6 = A5 + A6
-
C7 = A7
-
C8 = A1 + A2 + A3 + A4 + A5 + A6 + A7 + A8
http://www.cnblogs.com/huangxincheng/archive/2012/12/05/2802858.html
从图中我们可以看到S[]的分布变成了一颗树,有意思吧,下面我们看看S[i]中到底存放着哪些a[i]的值。
S[1]=a[1];
S[2]=a[1]+a[2];
S[3]=a[3];
S[4]=a[1]+a[2]+a[3]+a[4];
S[5]=a[5];
S[6]=a[5]+a[6];
S[7]=a[7];
S[8]=a[1]+a[2]+a[3]+a[4]+a[5]+a[6]+a[7]+a[8];
之所以采用这样的分布方式,是因为我们使用的是这样的一个公式:S[i]=a[i-2k+1]+....+a[i]。
其中:2k 中的k表示当前S[i]在树中的层数,它的值就是i的二进制中末尾连续0的个数,2k也就是表示S[i]中包含了哪些a[],
举个例子: i=610=01102 ;可以发现末尾连续的0有一个,即k=1,则说明S[6]是在树中的第二层,并且S[6]中有21项,随后我们求出了起始项:
a[6-21+1]=a[5],但是在编码中求出k的值还是有点麻烦的,所以我们采用更灵巧的Lowbit技术,即:2k=i&-i 。
则:S[6]=a[6-21+1]=a[6-(6&-6)+1]=a[5]+a[6]。
二:代码
-
class NumArray {
-
public:
-
NumArray(
vector<
int> &nums) {
-
num.resize(nums.size() +
1);
-
bit.resize(nums.size() +
1);
-
for (
int i =
0; i < nums.size(); ++i) {
-
update(i, nums[i]);
-
}
-
}
-
void update(int i, int val) {
-
int diff = val - num[i +
1];
-
for (
int j = i +
1; j < num.size(); j += (j&-j)) {
-
bit[j] += diff;
-
}
-
num[i +
1] = val;
-
}
-
int sumRange(int i, int j) {
-
return getSum(j +
1) - getSum(i);
-
}
-
int getSum(int i) {
-
int res =
0;
-
for (
int j = i; j >
0; j -= (j&-j)) {
-
res += bit[j];
-
}
-
return res;
-
}
-
-
private:
-
vector<
int> num;
-
vector<
int> bit;
-
};
1:神奇的Lowbit函数
1 #region 当前的sum数列的起始下标 2 /// <summary> 3 /// 当前的sum数列的起始下标 4 /// </summary> 5 /// <param name="i"></param> 6 /// <returns></returns> 7 public static int Lowbit(int i) 8 { 9 return i & -i; 10 } 11 #endregion
2:求前n项和
比如上图中,如何求Sum(6),很显然Sum(6)=S4+S6,那么如何寻找S4呢?即找到6以前的所有最大子树,很显然这个求和的复杂度为logN。
1 #region 求前n项和 2 /// <summary> 3 /// 求前n项和 4 /// </summary> 5 /// <param name="x"></param> 6 /// <returns></returns> 7 public static int Sum(int x) 8 { 9 int ans = 0; 10 11 var i = x; 12 13 while (i > 0) 14 { 15 ans += sumArray[i - 1]; 16 17 //当前项的最大子树 18 i -= Lowbit(i); 19 } 20 21 return ans; 22 } 23 #endregion
3:修改
如上图中,如果我修改了a[5]的值,那么包含a[5]的S[5],S[6],S[8]的区间值都需要同步修改,我们看到只要沿着S[5]一直回溯到根即可,
同样它的时间复杂度也为logN。
1 public static void Modify(int x, int newValue) 2 { 3 //拿出原数组的值 4 var oldValue = arr[x]; 5 6 for (int i = x; i < arr.Length; i += Lowbit(i + 1)) 7 { 8 //减去老值,换一个新值 9 sumArray[i] = sumArray[i] - oldValue + newValue; 10 } 11 }
最后上总的代码:
这里有一个有趣的性质:设节点编号为x,那么这个节点管辖的区间为 2^k
(其中k为x二进制末尾0的个数)个元素。因为这个区间最后一个元素必然为Ax,所以很明显:Cn = A(n – 2^k + 1) + ... + An,算这个2^k有一个快捷的办法,定义一个函数如下即可(求解2^k即求二进制码右边第一位1的值):
-
int lowbit(int x) {
-
return x & (-x);
-
}
当想要查询一个SUM(n)(求a[1]~a[n]的和),可以依据如下算法即可:
- 令sum = 0,转第二步;
- 假如n <= 0,算法结束,返回sum值,否则sum = sum + Cn,转第三步;
- 令n = n – lowbit(n),转第二步。
可以看出,这个算法就是将这一个个区间的和全部加起来。
那么修改呢,修改一个节点,必须修改其所有祖先,最坏情况下为修改第一个元素,最多有log(n)的祖先。所以修改算法如下(给某个结点i加上x):
- 当i > n时,算法结束,否则转第二步;
- Ci = Ci + x, i = i + lowbit(i)转第一步。i = i + lowbit(i)这个过程实际上也只是一个把末尾1补为0的过程。 对于数组求和来说树状数组简直太快了!
关于这部分的代码,将在下文树状数组的具体三大应用中给出。
关于树状数组,有一点需要注意,为了方便,树状数组的a数组基本都是从 index=1
开始的。
下文中楼主会分析下树状数组的三大应用场景:改点求段,改段求点,改段求段。
前文我们探讨了树状数组的原理。树状数组就是一种数据结构,它天生用来维护数组的前缀和,从而可以快速求得某一个区间的和,并支持对元素的值进行修改。但是树状数组并非只有这一种功能,变形后它还能衍生出两个功能,本文我们就来分别讨论下树状数组这三大功能。
永远要记住,基本的树状数组维护的是数组的前缀和,所有的区间求值都可以转化成用 sum[m]-sum[n-1]
来解,这点无论是在改点还是接下来要说的改段中都非常重要。
改点求段
这也是树状数组的基本应用。我们可以来看一下这道题 敌兵布阵。
如果看了前文 【前端也要学点数据结构】 神奇的树状数组,解法也就呼之欲出了,直接给出代码:
-
#include<iostream>
-
#include<cstdio>
-
#include<cstring>
-
#include<string>
-
using
namespace
std;
-
#define N 50005
-
int lowbit(int x) {
return x & (-x); }
-
int sum[N], cnt;
-
-
void update(int index, int val) {
-
for (
int i = index; i <= cnt; i += lowbit(i))
-
sum[i] += val;
-
}
-
-
int getSum(int index) {
-
int ans =
0;
-
for (
int i = index; i; i -= lowbit(i))
-
ans += sum[i];
-
return ans;
-
}
-
-
int main() {
-
string str;
-
int n, m, t, tmp, cas =
1;
-
scanf(
"%d", &t);
-
while (t--) {
-
memset(sum,
0,
sizeof(sum));
-
scanf(
"%d", &cnt);
-
for (
int i =
1; i <= cnt; i++) {
-
scanf(
"%d", &tmp);
-
update(i, tmp);
-
}
-
-
printf(
"Case %d:\n", cas++);
-
-
while (
cin >> str) {
-
if (str ==
"End")
break;
-
scanf(
"%d%d", &n, &m);
-
if (str ==
"Query")
-
printf(
"%d\n", getSum(m) - getSum(n -
1));
-
else
if (str ==
"Add")
-
update(n, m);
-
else update(n, -m);
-
}
-
}
-
return
0;
-
}
改段求点
改段求点和改点求段恰好相反,比如有一个数组 a = [x, 0, 0, 0, 0, 0, 0, 0, 0, 0]
,每次的修改都是一段,比如让 a[1]~a[5]
中每个元素都加上10,让 a[6]~a[9]
中每个元素都减去2,求任意的元素的值。
看例题 Color the ball
跟改点求段不同,这里要转变一个思想。在改点求段中,sum[i]表示Ci节点所管辖的子节点的元素和,而在改段求点中,sum[i]表示Ci所管辖子节点的批量统一增量。
还是看这个经典的图:
比方说,C8管辖A1~A8这8个节点,如果A1~A8每个都染色一次,因为前面说了sum[i]表示i所管辖子节点的统一增量,那么也就是 sum[8]+=1
,A5~A7都染色两次,也就是 sum[6] +=2, sum[7] +=2
。如果要求A1被染色的次数,C8是能管辖到A1的,也就是说sum[8]的值和A1被染色的次数有关,仔细想想,也就是把能管辖到A1的父节点的sum值累积起来即可。两个过程正好和改点求段相反。
完整代码:
-
#include<iostream>
-
#include<cstdio>
-
#include<cstring>
-
#include<string>
-
using
namespace
std;
-
#define N 100005
-
int sum[N], n;
-
int lowbit(int x) {
return x & (-x); }
-
-
void update(int index, int val) {
-
while (index) {
-
sum[index] += val;
-
index -= lowbit(index);
-
}
-
}
-
-
int query(int index) {
-
int ans =
0;
-
while (index <= n) {
-
ans += sum[index];
-
index += lowbit(index);
-
}
-
return ans;
-
}
-
-
int main() {
-
int x, y;
-
while (
scanf(
"%d", &n) && n) {
-
memset(sum,
0,
sizeof(sum));
-
for (
int i =
1; i <= n; i++) {
-
scanf(
"%d%d", &x, &y);
-
update(y,
1);
-
update(x -
1,
-1);
-
}
-
-
for (
int i =
1; i < n; i++)
-
printf(
"%d ", query(i));
-
printf(
"%d\n", query(n));
-
}
-
return
0;
-
}
改段求段
改段求段也有道经典的模板题:A Simple Problem with Integers
我们还是从简单的例子入手,比如有如下数组(a[1]=1,..a[9]=9):
1 2 3 4 5 6 7 8 9 10
假设我们将 a[1]~a[4]
这段增加5,对于我们要求的区间和来说,要么是 [1,2]
这种属于所改段的子区间,要么是 [1,8]
这种属于所改段的父区间(前面说了,所有的区间求值都可以用sum[m]-sum[n-1]来解,所以我们只考虑前缀和),我们分别讨论。
如果所求是类似 [1,8]
这种,我们可以很开心地发现,我们将区间增量(4*5)全部加在 a[4]
这个元素上,对结果并没有什么影响!于是变成了一般的改点求段。
如果所求是类似 [1,2]
这种,我们可以用类似改段求点中染色的思想进行处理。譬如 [1,4]
成段加5,如果我们要计算 [1,2]
的和。我们将 [1,3]
进行“染色”(节点4加上了4*5的权重),因为 [1,3]
在树状数组的划分中可以分为两个区间,[1,2]
和 [3,3]
,所以我们用类似改段求点对这两块区域进行“染色”,染上的次数为5。我们要求的是 [1,2]
的区间和,我们只需找 2
被染色的次数,因为 [1,n]
进行染色。如果m(1<=m<=n)被染色,那么m的右边肯定都被染色了。求出被染色的次数,然后乘上区间宽度,就是整段的和了。
这样我们分别对两种情况进行了处理,更重要的是,这两种情况互不影响! 于是我们简单地把两个结果相加就ok了,而这两个过程,分别正是改点求段和改段求点!
完整代码:
-
#include<iostream>
-
#include<cstdio>
-
#include<cstring>
-
using
namespace
std;
-
#define N 100005
-
#define ll __int64
-
ll b[N], c[N];
-
int n;
-
-
int lowbit(int x) {
-
return x & (-x);
-
}
-
-
void update_backwards(int index, ll val) {
-
for (
int i = index; i <= n; i += lowbit(i))
-
b[i] += val;
-
}
-
-
void update_forward(int index, ll val) {
-
for (
int i = index; i; i -= lowbit(i))
-
c[i] += val;
-
}
-
-
void update(int index, ll val) {
-
update_backwards(index, index * val);
-
update_forward(index -
1, val);
-
}
-
-
ll query_forward(int index) {
-
ll ans =
0;
-
for (
int i = index; i; i -= lowbit(i))
-
ans += b[i];
-
return ans;
-
}
-
-
ll query_backwards(int index) {
-
ll ans =
0;
-
for (
int i = index; i <= n; i += lowbit(i))
-
ans += c[i];
-
return ans;
-
}
-
-
ll query(int index) {
-
return query_forward(index) + query_backwards(index) * index;
-
}
-
-
//---------------- main -------------- //
-
int main() {
-
int t, x, y;
-
ll z;
-
char str[
2];
-
memset(b,
0,
sizeof(b));
-
memset(c,
0,
sizeof(c));
-
scanf(
"%d%d", &n, &t);
-
n +=
1;
-
for (
int i =
1; i < n; i++) {
-
scanf(
"%I64d", &z);
-
x = i +
1, y = i +
1;
-
update(y, z);
-
update(x -
1, -z);
-
}
-
-
while (t--) {
-
scanf(
"%s", str);
-
if (str[
0] ==
'C') {
-
scanf(
"%d%d%I64d", &x, &y, &z);
-
x +=
1, y +=
1;
-
update(y, z);
-
update(x -
1, -z);
-
}
else {
-
scanf(
"%d%d", &x, &y);
-
x +=
1, y +=
1;
-
printf(
"%I64d\n", query(y) - query(x -
1));
-
}
-
}
-
return
0;
-
}
这里有一点需要注意:一般的用数组数组来解的题,都是不用a[0]的,也就是元素是从a[1]~a[n],因为 sum[n~m]=sum[m]-sum[n-1]
,避免 n-1
为负数。而本题中的改段求段中的元素是从 a[2]~a[n+1]
,因为 update()
函数中的子函数 update_forward()
函数中 index-1
不能为负,所以参数 index
最小是1,所以 sum[n-1]
中 n-1
最小是1,所以n最小是2,所以元素下标必须从 2
开始。