树状数组可以当成工具来用,用来计算任意连续区间的和或是最小或最大值,树状数组可以解决的问题,线段树都可以解决,但是线段树可以解决的问题,树状数组不一定能解决,那为什么还用树状数组呢?因为树状数组实现起来比较方便
树状数组用上懒惰标记可以减少递归的次数,懒惰标记实际上就是让子节点暂时处于不更新状态,用到的时候再更新,如本题中的visit[]就是懒惰标记,例如总长度是1-10,我们现在要想更新1-6,(将1-6的值都加3)那么update()会先找1-10,发现不合适,再找他的左右孩子,发现1<5,说明1-6的区间在1-10的左孩子中,同时6>5,1-6也在1-10的右孩子中,这样依次去找1-6在的区间。但是找到1-5的时候,我们发现整个1-5都在1-6中间,也就是说这一段都要更新,那么我们将1-5的sum值更新了,同时用add[rt]+=3记录下来1-5中的数字现在每个都 要加的数字,但是1-5下边还有1-3,4-5,3-3,4-4,5-5,这些我们就可以不用更新,因为这些我们暂时还用不到,假如现在又要将1-5区间的值都加5,那么add[rt]+=5,此时就是8了,但是还是不用更新他的子节点,假如我们现在要用到1-3区间了,我们就可以一次性给1-3区间加上8,而不用先加3,再加5,这样懒惰标记就使得每次的递归都少了好多
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
using namespace std;
const int maxn = 5005;
int n, c[maxn], a[maxn];
int sum[maxn << 2]; //sum线段树维护的区间和,数组的大小是a数组的4倍
int add[maxn << 2]; //懒惰标记数组
//树状数组
int lowbit(int x) {
return x & -x;
}
void add_(int k, int num) {
while(k <= n) {
c[k] += num; //修改自己,然后在一步步想上修改父节点,一直修改到根节点为止
k += lowbit(k); //k+lowbit(k)的值为k的父节点
}
}
//想要任意区间的和,只需要两个区间进行相减,比如要求8~10区间的和,可以Sum(10)-Sum(7)
int Sum(int k) { //求的是c[1] ~ c[k]的和
int ans = 0;
while(k > 0) {
ans += c[k];
k -= lowbit(k);
}
return ans;
}
//线段树,树状数组可以解决的问题,线段树都可以解决,但是线段树可以解决的问题,树状数组不一定能解决
void pushup(int rt) {
sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}
void build(int l, int r, int rt) { //建树
if( l== r) {
sum[rt] = a[l];
return;
}
int m = (l + r) >> 1; // 二分
build(l, m, rt << 1); //构建左子书
build(m + 1, r, rt << 1 | 1); //构建右子树
pushup(rt); //构建完左右子树后,回溯向上更新他们的父节点
}
void update_p(int k, int val, int l, int r, int rt) { //单点修改 ,这里是该点加val
if(l == r) {
sum[rt] += val; //找到要修改的点后,更新该点,加val ,通过一步步二分,从最上端的根节点找到最低端叶子节点,
return; //也就是要更新的点
}
int m = (l + r) >> 1;
if(k <= m) update_p(k, val, l, m, rt << 1);
else update_p(k, val, m + 1, r, rt << 1 | 1);
pushup(rt); //同样更新完了左右子节点,要想上回溯跟新父节点
}
void pushdown(int rt, int lnum, int rnum) {
if(add[rt]) {
add[rt << 1] += add[rt]; //向下跟新左右子节点的懒惰标记
add[rt << 1 | 1] += add[rt];
sum[rt << 1] += add[rt]; //左右节点各自的sum和是由他的父节点 向下更新得到
sum[rt << 1 | 1] += add[rt];
add[rt] = 0; //用完以后,懒惰标记清零
}
}
//区间修改用了懒惰标记,该代码是L到R区间每个数都加val
void update(int L, int R, int val, int l, int r, int rt) { //区间修改
if(l >= L && r <= R) {
sum[rt] += (r - l + 1) * val; //算一下区间的长度,然后乘以val,加到原区间上
add[rt] += val; //懒惰标记数组加上val
return;
}
int m = (l + r) >> 1;
pushdown(rt, m - l + 1, r - m); //m-l+1代表左子区间的长度, r-m代表右子区间的长度
if(L <= m) update(L, R, val, l, m, rt << 1);
if(R > m) update(L, R, val, m + 1, r, rt << 1 | 1);
pushup(rt);
}
int query(int L, int R, int l, int r, int rt) { //区间查询
if(l >= L && r <= R) {
return sum[rt];
}
int ans = 0;
int m = (l + r) >> 1;
pushdown(rt, m - l + 1, r - m);
if(L <= m) ans += query(L, R, l, m, rt << 1);
if(R > m) ans +=query(L, R, m + 1, r, rt << 1 | 1);
return ans;
}