树状数组模板
基本原理
基础功能 单点更新,区间查询
- 快速求前缀和 O(log(n))
- 修改某一个数字 O(log(n))
下面图片来自acwing的课程:
原始数据是A数组,用C(x)代表A[x-lowbit(x)+1,x]区间内的元素之和,然后用小矩形代替C(x)画图,C(x)覆盖范围就是他们的求和的范围, 下图中的数组可以形成一棵树,C(x)覆盖区间大小决定层高,红线代表上层可以由下层连线的部分求和得出。每一个点x都有儿子,儿子的数量是x-1的二进制组成中1的个数。
例如16(10000)的儿子是15(1111),14(1110),12(1100),8(1000),是16-1后的lowbit云梦算命
基本模板
lowbit函数:
int lowbit(int x){
return x&(-x);
}
操作分为两个,单点更新,求前缀和。
单点更新,例如在某个数字上面加上a,假设在第3位上加,那么依次更新C3,C4,C8,C16,对应的小矩形可以沿着红线寻找,好像一棵树从叶子到根的的一条路径。
void update(int val,int t){
for(int i=val;i<n+10;i+=lowbit(i)){
c[i]+=t;
}
}
求前缀和,假设求15的前缀和,那么依次加上 C15,C14,C12,C8, 就是活用lowbit的性质,将一个数x看作log(x)个区间累加。
int getsum(int val){
int t=0;
for(int i=val;i>0;i-=lowbit(i)){
t+=c[i];
}
return t;
}
建树
一般建树使用for循环,然后update,时间复杂度O(nlog(n)), 也可以减少到O(n),首先O(n)求数组A的前缀和S数组,然后使用c[x]=S[x]-S[x-lowbit(x)]计算,因为c[x]的意义就是A数组[x-lwobit+1,x]区间的和
区间更新,单点查询–差分
原来是单点加,求区间和。现在利用差分数组完成区间加,求单点和。
目标,对区间A[L,R]统一加上v,对于某一点x求对应的值A[x]。
方法:先利用原始A数组求出对应的差分数组B,B[L]+=v,B[R+1]-=v,求x上的值就是求B[0-x]的前缀和S。
区间更新,区间查询–差分
利用差分数组,可以很容易实现 区间更新,单点查询, 然后增加一重循环累加B的前缀和数组S就可以完成区间查询。但是这样时间复杂度会变高。
S1=B1
S2=B1+B2
S3=B1+B2+B3
…
Sx=B1+B2+…+Bx
S1+S2+…+Sx=(x+1)Sx-(1B1+2B2+3B3+…+x*Bx)
所以可以额外计算一个数组i*Bi,然后计算他的前缀和,然后就可以O(1)计算出(1B1+2B2+3B3+…+xBx),算出Sx的前缀和。
所以需要计算两个前缀和数组(树状数组)Bi和i*Bi
LL c[maxn],c2[maxn]; //c 维护 b的树状数组,才维护 i*b[i]的树状数组
int lowbit(int x){
return x&(-x);
}
void update(int pos,int x){
for(int i=pos;i<n+10;i+=lowbit(i)){
c[i]+=x;
c2[i]+=(LL)pos*x;
}
}
LL getsum(int pos){
LL res=0,res2=0;
for(int i=pos;i;i-=lowbit(i)){
res+=c[i];
res2+=c2[i];
}
return res*(pos+1)-res2;
}
LL getsum(int l,int r){
return getsum(r)-getsum(l-1);
}
离散化+树状数组
例题
#include<iostream>
#include<cstdio>
using namespace std;
typedef long long LL;
const int maxn=200010;
int a[maxn],c[maxn],lower[maxn],higher[maxn],n;
int lowbit(int x){
return x&(-x);
}
void update(int val,int t){
for(int i=val;i<n+10;i+=lowbit(i)){
c[i]+=t;
}
}
int getsum(int val){
int ans=0;
for(int i=val;i;i-=lowbit(i)){
ans+=c[i];
}
return ans;
}
int main(void){
cin>>n;
for(int i=1;i<=n;++i){
scanf("%d",&a[i]);
lower[i]=getsum(a[i]-1);
higher[i]=i-1-getsum(a[i]);
update(a[i],1);
}
for(int i=1;i<=n;++i) c[i]=0;
LL res1=0,res2=0;
for(int i=n;i>0;--i){
res1+=(LL)(getsum(n)-getsum(a[i]))*higher[i];
res2+=(LL)getsum(a[i]-1)*lower[i];
update(a[i],1);
}
printf("%lld %lld\n",res1,res2);
}
AcWing 242. 一个简单的整数问题
区间更新,单点查询
#include<iostream>
#include<cstdio>
using namespace std;
const int maxn=100010;
int n,m,a[maxn],b[maxn],c[maxn];
char ch;
int lowbit(int x){
return x&(-x);
}
void update(int pos,int x){
for(int i=pos;i<n+10;i+=lowbit(i)){
c[i]+=x;
}
}
int getsum(int pos){
int res=0;
for(int i=pos;i;i-=lowbit(i)){
res+=c[i];
}
return res;
}
int main(void){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i){
scanf("%d",&a[i]);
b[i]=a[i]-a[i-1];
update(i,b[i]);
}
for(int i=1;i<=m;++i){
scanf("%s",&ch);
if(ch=='Q'){
int x;
scanf("%d",&x);
printf("%d\n",getsum(x));
}else{
int l,r,d;
scanf("%d%d%d",&l,&r,&d);
update(l,d);
update(r+1,-d);
; }
}
}
AcWing 243. 一个简单的整数问题2
区间更新,区间查询
#include<iostream>
#include<cstdio>
using namespace std;
typedef long long LL;
const int maxn=100010;
int n,m,a[maxn],b[maxn];
LL c[maxn],c2[maxn]; //c 维护 b的树状数组,才维护 i*b[i]的树状数组
char ch[10];
int lowbit(int x){
return x&(-x);
}
void update(int pos,int x){
for(int i=pos;i<n+10;i+=lowbit(i)){
c[i]+=x;
c2[i]+=(LL)pos*x;
}
}
LL getsum(int pos){
LL res=0,res2=0;
for(int i=pos;i;i-=lowbit(i)){
res+=c[i];
res2+=c2[i];
}
return res*(pos+1)-res2;
}
int main(void){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i){
scanf("%d",&a[i]);
b[i]=a[i]-a[i-1];
update(i,b[i]);
}
for(int i=1;i<=m;++i){
scanf("%s",&ch);
int l,r,d;
if(ch[0]=='Q'){
scanf("%d%d",&l,&r);
// printf("%lld %lld\n",getsum(r),getsum(l-1));
printf("%lld\n",getsum(r)-getsum(l-1));
}else{
scanf("%d%d%d",&l,&r,&d);
update(l,d);
update(r+1,-d);
; }
}
}
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=100010;
int a[maxn],c[maxn],n,ans[maxn];
int lowbit(int x){
return x&(-x);
}
void update(int pos,int x){
for(int i=pos;i<n+10;i+=lowbit(i)){
c[i]+=x;
}
}
int getsum(int pos){
int res=0;
for(int i=pos;i;i-=lowbit(i)){
res+=c[i];
}
return res;
}
//找到第一个位置pos,是的getsum(pos)>=x
int findk(int x){
int l=1,r=n;
while(l<r){
int mid=(l+r)>>1,t=getsum(mid);
// cout<<mid<<" "<<t<<endl;
if(t<x) l=mid+1;
else r=mid;
}
return r;
}
int main(void){
scanf("%d",&n);
for(int i=1;i<=n;++i){
update(i,1);
}
for(int i=2;i<=n;++i){
scanf("%d",&a[i]);
}
for(int i=n;i>0;--i){
ans[i]=findk(a[i]+1);
// cout<<ans[i]<<endl;
update(ans[i],-1);
}
for(int i=1;i<=n;++i){
printf("%d\n",ans[i]);
}
}