树状数组
P3374 【模板】树状数组 1 - 区间求值&元素操作
题目描述
如题,已知一个数列,你需要进行下面两种操作:
- 将某一个数加上
x
- 求出某区间每一个数的和
输入格式
第一行包含两个正整数n
,m
分别表示该数列数字的个数和操作的总个数。
第二行包含n
个用空格分隔的整数,其中第i
个数字表示数列第i
项的初始值。
接下来m
行每行包含3
个整数,表示一个操作,具体如下:
1 x k
含义:将第x
个数加上k
2 x y
含义:输出区间[x,y]
内每个数的和
输出格式
输出包含若干行整数,即为所有操作2
的结果。
输入输出样例
输入 | 输出 |
---|---|
5 5 1 5 4 2 3 1 1 3 2 2 5 1 3 -1 1 4 2 2 1 4 | 14 16 |
说明/提示
【数据范围】
对于30%
的数据,1≤n≤8
,1≤m≤10
;
对于70%
的数据,1≤n,m≤1E4
;
对于100%
的数据,1≤n,m≤5×1E5
。
样例说明:
故输出结果14
、16
。
思想
…
c[]数组(树状数组)表示自己所管辖的a[]数组(原数组)的数的和。
比如 从图中可以看出来:
c[4]
=a[1]
+a[2]
+a[3]
+a[4]
;c[6]
=a[5]
+a[6]
;
然后我们看一下4和6的二进制数:
- 4 -> 100
- 6 -> 110
我们发现:二进制的最低位的1转成十进制后(我们设转成十进制后是x)就是这个数组所管辖的个数:即
c[4]
管4个,c[6]
管2个;如果我们仔细仔细看一下这个图,看一下每个c[ ]
管辖的特点,我们发现每个c[ ]
管辖的数都是连续的。从另一个角度来说,每个
c[ ]
管从自己开始往前x个连续的数。…
很容易得到一个结论:节点编号+区间长度=父亲编号
对于树状数组中的lowbit
:
lowbit
是一种返回最后一个为1的bit位的操作。
当想要查询一个
sum(n)
(求a[n]
的和),可以依据如下算法即可:
- 令
sum
= 0;- 假如
n
<= 0,算法结束,返回sum
值。否则sum
=sum
+c[n]
;- 令
n
=n
–lowbit(n)
,转步骤2
。
n
=n
–lowbit(n)
这一步实际上等价于将n
的二进制的最后一个1
减去。而n
的二进制里最多有log(n)
个1
,所以查询效率是log(n)
的。…
那么修改呢,修改一个节点,必须修改其所有祖先,最坏情况下为修改第一个元素,最多有
log(n)
的祖先。修改算法如下(给某个结点
i
加上x
):
- 当
i
<=n
时,c[i]
=c[i]
+x
,i
=i
+lowbit(i)
循环该过程直到i
>n
i
=i
+lowbit(i)
这个过程实际上也只是一个把末尾1补为0的过程。
代码
#include <cstdio>
typedef int datatype;
struct bintree_base
{
//树状数组指针
datatype *a;
//元素数量
int n;
//初始化
explicit bintree_base(const int &_n){
n = _n;
a = new datatype[n + 1]();
}
~bintree_base(){
delete[] a;
}
datatype lowbit(const datatype &num){
return num & (-num);
}
//求从1到index的和
datatype accumulate(int index){
datatype ans = 0;
while (index)
ans += a[index], index -= lowbit(index);
return ans;
}
//求区间和
datatype accumulate(const int &begin, const int &end){
return accumulate(end) - accumulate(begin - 1);
}
//在index位置上+value并更新到树上
void add(int index, const datatype &value){
while (index <= n)
a[index] += value, index += lowbit(index);
}
};
int main(){
int n, m, tmp1, tmp2, tmp3;
scanf("%d%d", &n, &m);
bintree_base bt(n);
for (int i = 1; i <= n; ++i){
scanf("%d", &tmp1);
bt.add(i, tmp1);
}
for (int i = 0; i < m; ++i){
scanf("%d%d%d", &tmp1, &tmp2, &tmp3);
if (tmp1 == 1)
bt.add(tmp2, tmp3);
else
printf("%d\n", bt.accumulate(tmp2, tmp3));
}
return 0;
}