一、树状数组介绍
树状数组是一种用数组模拟的树形结构,修改和查询的复杂度都是O(logN),常用于解决区间更新以及求和问题.
上图中:
C[1] = A[1];
C[2] = A[1] + A[2];
C[3] = A[3];
C[4] = A[1] + A[2] + A[3] + A[4];
C[5] = A[5];
C[6] = A[5] + A[6];
C[7] = A[7];
C[8] = A[1] + A[2] + A[3] + A[4] + A[5] + A[6] + A[7] + A[8];
由此可得这颗树的规律:
C[i]=A[i-
2
k
2^k
2k+1]+A[i-
2
k
2^k
2k+2]+…+A[i],其中k为i的二进制中从最低位到高位连续零的长度,
2
k
2^k
2k也叫lowbit,可以利用位运算计算,如下
int lowbit(int x){
return x&(-x);
}
二、代码
树状数组的单点修改,单点查询和区间查询的完整代码如下:
#include<iostream>
#include<algorithm>
#include<queue>
#include<set>
#include<stdlib.h>
#include<time.h>
#include<unordered_map>
#include<stack>
#include<cstdio>
#include<vector>
#include<cstring>
#include<string>
#include<map>
#include<cmath>
#include<bitset>
#define ll long long
#define inf 0x3f3f3f3f
#define bug(a) cout<<"* "<<a<<endl;
#define bugg(a,b) cout<<"* "<<a<<" "<<b<<endl;
#define buggg(a,b,c) cout<<"* "<<a<<" "<<b<<" "<<c<<endl;
using namespace std;
const int N=2e6+10;
const ll mod=1e9+7;
typedef pair<double,double> P;
int n;
int A[N],c[N];//对应原数组和树状数组
int lowbit(int x){
return x&(-x);
}
void updata(int i,int k){ //在i位置加上k
while(i<=n){
c[i]+=k;
i+=lowbit(i);
}
}
int getsum(int i){ //求前i项的和
int res=0;
while(i>0){
res+=c[i];
i-=lowbit(i);
}
return res;
}
int main(){
cin>>n;
for(int i=1;i<=n;i++){
cin>>A[i];
updata(i,A[i]);//第i个位置加上A[i];
}
cout<<getsum(3)<<endl;//求前三项的和
updata(2,2);//把第二个元素加上2
cout<<getsum(3)<<endl;
cout<<getsum(4)-getsum(1)<<endl;//求区间2~4之间的和
cout<<getsum(2)-getsum(1)<<endl;//求第二个元素的值
return 0;
}
区间修改,单点查询则需要借助差分数组,完整代码如下:
#include<iostream>
#include<algorithm>
#include<queue>
#include<set>
#include<stdlib.h>
#include<time.h>
#include<unordered_map>
#include<stack>
#include<cstdio>
#include<vector>
#include<cstring>
#include<string>
#include<map>
#include<cmath>
#include<bitset>
#define ll long long
#define inf 0x3f3f3f3f
#define bug(a) cout<<"* "<<a<<endl;
#define bugg(a,b) cout<<"* "<<a<<" "<<b<<endl;
#define buggg(a,b,c) cout<<"* "<<a<<" "<<b<<" "<<c<<endl;
using namespace std;
const int N=2e6+10;
const ll mod=1e9+7;
typedef pair<double,double> P;
int n,m;
ll A[N],c[N];//对应原数组和树状数组
int lowbit(int x){
return x&(-x);
}
void updata(int i,ll k){ //在i位置加上k
while(i<=n){
c[i]+=k;
i+=lowbit(i);
}
}
ll getsum(int i){ //求d数组前i项的和,也就是A[i]的值
ll res=0;
while(i>0){
res+=c[i];
i-=lowbit(i);
}
return res;
}
int main(){
cin>>n>>m;
for(int i=1;i<=n;i++){
cin>>A[i];
updata(i,A[i]-A[i-1]);//第i个位置加上d[i];
}
while(m--){
char opt;
cin>>opt;
if(opt=='Q'){
int x;
cin>>x;
cout<<getsum(x)<<endl;
}
if(opt=='C'){
int l,r,d;
cin>>l>>r>>d;//把区间l~r之间每个元素都加上d
updata(l,d);
updata(r+1,-d);
}
}
return 0;
}
区间修改,区间查询,原理如下:
设d[i]为原数组的差分数组,则可推导出原数组前n项和满足下面公式:
∑
i
=
1
n
a
[
i
]
=
∑
i
=
1
n
d
[
i
]
−
∑
i
=
1
n
(
i
−
1
)
∗
d
[
i
]
\sum_{i=1}^{n}a[i]=\sum_{i=1}^{n}d[i]-\sum_{i=1}^{n}(i-1)*d[i]
i=1∑na[i]=i=1∑nd[i]−i=1∑n(i−1)∗d[i]
代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1e6+10;
ll n,a[N],sum1[N],sum2[N];
//sum1[i]:d[i],sum2[i]:(i-1)*d[i]
int lowbit(int x){
return x&-x;
}
void updata(int i,ll x){
//单点修改
ll p1=i;
while(i<=n){
sum1[i]+=x;
sum2[i]+=x*(p1-1);
i+=lowbit(i);
}
}
void range_up(int l,int r,ll x){
//使区间l到r中每一个数都加上x
updata(l,x);
updata(r+1,-x);
}
ll getsum(int i){
//求前i项和
ll res=0,p=i;
while(i>0){
res+=p*sum1[i]-sum2[i];
i-=lowbit(i);
}
return res;
}
ll range_sum(int l,int r){
//区间求和
return getsum(r)-getsum(l-1);
}
int main()
{
cin>>n;
ll x;
for(int i=1;i<=n;i++){
cin>>a[i];
updata(i,a[i]-a[i-1]);
}
cout<<getsum(2)<<endl;
range_up(2,4,1);
cout<<getsum(2)<<endl;
range_up(3,4,-1);
cout<<range_sum(1,4)<<endl;
return 0;
}