本文介绍树状数组的三类基础用法。
1.单点修改,区间查询。
2.区间修改,单点查询。
3.区间修改,区间查询。
树状数组的建立所需时间复杂度为 O ( n log n ) O(n\log n) O(nlogn),空间复杂度为 O ( n ) O(n) O(n),每一次操作的时间复杂度为 O ( log n ) O(\log n) O(logn),空间复杂度为 O ( 1 ) O(1) O(1)。因其代码简单,常数较小,是比较优秀的数据结构之一。
下图为树状数组模板代码。
struct Binary_Inedxed_Tree{
ll ans[1000010];
inline int lowbit(int x){
return x&(-x);
}
inline void add(int pos,ll val,ll x[]){
for(register int i=pos;i<=n;i+=lowbit(i)){
x[i]+=val;
}
}
inline ll query(int pos,ll x[]){
ll cnt=0;
for(register int i=pos;i;i-=lowbit(i)){
cnt+=x[i];
}
return cnt;
}
}a;
结合代码,下面介绍三种树状数组的基础用法:
1.单点修改,区间查询。
ans数组在此时作为前缀和数组,原理较为简单。
2.区间修改,单点查询。
ans数组在此时作为差分数组。
在区间修改时,虽然树状数组会在被修改区间的左端点l的右边继续加上被修改的值,但由lowbit运算保证了每一次差分在最终查询单点值时只会被计算一次。
举个例子,区间[3,6]加1,对应代码为add(3,1),add(7,-1),此时add(3,1)时,ans[3]与ans[4]都加上1,但在查询i=5处的数值时,先lowbit到i=4,取出此处的差分值,然后再次进行i-lowbit(i),则此时位置为0,故3处对应的差分值未被重复计算。
3.区间修改,区间查询。
始终牢记,树状数组的本质是单点修改,区间查询。因此着三个基本操作都离不开单点修改,在区间修改时,单点修改只能使用差分数组,故设ans[i]为差分数组,则最终i处的数值为
∑
1
i
a
n
s
[
i
]
\sum_{1}^{i}ans[i]
∑1ians[i],设数组val[i]表示i的前缀和,则有
v
a
l
[
i
]
=
∑
j
=
1
i
v
a
l
[
j
]
=
∑
j
=
1
i
(
i
−
j
+
1
)
a
n
s
[
j
]
val[i]=\sum_{j=1}^{i}val[j]=\sum_{j=1}^{i}(i-j+1)ans[j]
val[i]=j=1∑ival[j]=j=1∑i(i−j+1)ans[j]
但在树状数组中,当
a
n
s
[
i
]
ans[i]
ans[i]改变时,无法快速维护每一个
v
a
l
[
i
]
val[i]
val[i]的改变,因为这时
v
a
l
[
i
]
val[i]
val[i]的值不仅与
a
n
s
[
j
]
ans[j]
ans[j]有关,还与
i
和
j
i和j
i和j之间的距离有关。
因此,可以将上式转化为
v
a
l
[
i
]
=
n
∗
∑
j
=
1
i
a
n
s
[
j
]
−
∑
j
=
1
i
(
j
−
1
)
a
n
s
[
j
]
val[i]=n*\sum_{j=1}^{i}ans[j]-\sum_{j=1}^{i}(j-1)ans[j]
val[i]=n∗j=1∑ians[j]−j=1∑i(j−1)ans[j]
此时只需要多开一个 s u m sum sum数组, s u m [ i ] = ( i − 1 ) ∗ a n s [ i ] sum[i]=(i-1)*ans[i] sum[i]=(i−1)∗ans[i],在修改时,当 a n s [ j ] ans[j] ans[j]改变,只需要在树状数组维护好 s u m [ i ] 和 a n s [ i ] sum[i]和ans[i] sum[i]和ans[i]的前缀和即可。
例题:Loj132
#include<cstdio>
using namespace std;
typedef long long ll;
int q,n;
struct Binary_Inedxed_Tree{
ll ans[1000010],sum[1000010];
inline int lowbit(int x){
return x&(-x);
}
inline void add(int pos,ll val,ll x[]){
for(register int i=pos;i<=n;i+=lowbit(i)){
x[i]+=val;
}
}
inline ll query(int pos,ll x[]){
ll cnt=0;
for(register int i=pos;i;i-=lowbit(i)){
cnt+=x[i];
}
return cnt;
}
}a;
int main(){
scanf("%d%d",&n,&q);
for(register int i=1;i<=n;i++){
ll x;
scanf("%lld",&x);
a.add(i,x,a.ans);
a.add(i+1,-x,a.ans);
a.add(i,(i-1)*x,a.sum);
a.add(i+1,-i*x,a.sum);
}
for(register int i=1;i<=q;i++){
int opt;
scanf("%d",&opt);
if(opt==1){
int l,r;ll x;
scanf("%d%d%lld",&l,&r,&x);
a.add(l,x,a.ans);a.add(r+1,-x,a.ans);
a.add(l,x*(l-1),a.sum);a.add(r+1,-x*r,a.sum);
}
if(opt==2){
int l,r;
scanf("%d%d",&l,&r);
ll ans1=r*a.query(r,a.ans)-(l-1)*a.query(l-1,a.ans);
ll ans2=a.query(r,a.sum)-a.query(l-1,a.sum);
printf("%lld\n",ans1-ans2);
}
}
}