树状数组
树状数组是一个查询和修改复杂度都为
l
o
g
(
n
)
log(n)
log(n)的数据结构,而且是一个在线的数据结构,支持随时修改某个元素的值,复杂度也为
l
o
g
log
log级别。
不难发现,现在编号为
i
i
i的节点,
c
[
i
]
表
示
的
其
实
是
[
i
−
2
k
+
1
,
i
]
c[i]表示的其实是[i-2^k+1,i]
c[i]表示的其实是[i−2k+1,i]这个区间的和,即:
c
[
i
]
=
a
[
i
−
2
k
+
1
]
+
…
+
a
[
i
]
c[i]=a[i-2^k+1]+…+a[i]
c[i]=a[i−2k+1]+…+a[i]。其中
k
k
k是
i
i
i的二进制表示中末尾0的个数,同时也是这个节点在树中的高度。有了这个性质,
B
I
T
BIT
BIT的更新与求和就可以利用二进制运算非常简单的快速实现。
注意:当问题不满足减法原则(区间可以用后面减前面来表示,比如前缀和就满足,但最大最小值就不满足)时,只能用线段树不能用树状数组。
树状数组的修改其实也就是坐标在二进制下的修改,建议写个序列手动画几遍基本操作,好理解很多。
应用
- 单点更新,区间求和
个人喜欢封装写法
struct BIT{
int c[N];
void clear(){memset(c,0,sizeof(c));}
//计算2^二进制下x末尾0的个数
int lowbit(int x){return x&(-x);}
//单点修改
void add(int x,int v){for(;x<=n;x+=lowbit(x))c[x]+=v;}
//区间求和
int query(int x){int ans=0;for(;x;x-=lowbit(x)){ans+=c[x];}return ans;}
}bit;
例题:敌兵布阵(模板题)
#include <bits/stdc++.h>
using namespace std;
const int N=(int)1e5+50;
int n,a[N],t;
char c[N];
struct BIT{
int c[N];
void clear(){memset(c,0,sizeof(c));}
int lowbit(int x){return x&(-x);}
void add(int x,int v){for(;x<=n;x+=lowbit(x))c[x]+=v;}
int query(int x){int ans=0;for(;x;x-=lowbit(x)){ans+=c[x];}return ans;}
}bit;
inline int read(){
int cnt=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-')f=-f;c=getchar();}
while(isdigit(c)){cnt=(cnt<<1)+(cnt<<3)+(c^48);c=getchar();}
return cnt*f;
}
int main(){
t=read();
for(int cas=1;cas<=t;++cas){
n=read();bit.clear();printf("Case %d:\n",cas);
for(int i=1;i<=n;++i) a[i]=read(),bit.add(i,a[i]);
while(scanf("%s",c),strcmp(c,"End")){
int x=read(),y=read();
if(c[0]=='Q') cout<<bit.query(y)-bit.query(x-1)<<endl;
else if(c[0]=='S') bit.add(x,-y);
else bit.add(x,y);
}
}
return 0;
}
- 区间更新,单点求值
差分思想,即在左区间操作一次,在右区间的下一个操作一次,相当于撤销这个操作即可。
例题:Color the ball
注意这道题每行输出后不能有空格,要不然会 p r e s e n t a t i o n e r r o r presentation\ error presentation error
#include <bits/stdc++.h>
using namespace std;
const int N=(int)1e5+50;
int n;
struct BIT{
int c[N];
void clear(){memset(c,0,sizeof(c));}
inline int lowbit(int x){return x&(-x);}
inline void add(int x,int val){for(;x<=n;x+=lowbit(x))c[x]+=val;}
inline int query(int x){int ans=0;for(;x;x-=lowbit(x))ans+=c[x]; return ans;}
}bit;
inline int read(){
int cnt=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-')f=-f;c=getchar();}
while(isdigit(c)){cnt=(cnt<<1)+(cnt<<3)+(c^48);c=getchar();}
return cnt*f;
}
int main(){
while(n=read(),n){
bit.clear();
for(int i=1;i<=n;++i){int x=read(),y=read();bit.add(x,1),bit.add(y+1,-1);}
for(int i=1;i<n;++i){
cout<<bit.query(i)<<' ';
}
cout<<bit.query(n);
cout<<endl;
}
return 0;
}
- 逆序对
利用树状数组求逆序对,本质上还是属于“单点更新,区间求和”类型。
树状数组的 c [ x ] c[x] c[x]的下标 x x x看做是要读入的序列里面的一个值, c [ x ] c[x] c[x]表示 x x x出现的次数。
那么读入一个数 x x x就将 e [ x ] e[x] e[x]以及后面和 x x x相关联( l o w b i t lowbit lowbit计算)的每个数 c [ x + l o w b i t [ x ] ] c[x+lowbit[x]] c[x+lowbit[x]]
都要加一个 1 1 1,因为它比它后面的每一个数都要小。而查询操作 q u e r y ( x ) query(x) query(x)返回值的 S u m ( x ) Sum(x) Sum(x),就刚好表示前面小于 x x x的个数。
假设序列中第 i i i个元素的值为 x x x,则前i个元素中比 x x x大的元素的个数为 i − q u e r y ( x ) i-query(x) i−query(x),逆序对的求法就很明显了。
例题:光荣的梦想
离散化+开 l o n g l o n g long\ long long long
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N=(int)1e5+50;
int n,a[N],b[N],c[N],ans=0;
inline int read(){
int cnt=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-')f=-f;c=getchar();}
while(isdigit(c)){cnt=(cnt<<1)+(cnt<<3)+(c^48);c=getchar();}
return cnt*f;
}
struct BIT{
int c[N];
inline int lowbit(int x){return x&(-x);}
inline void clear(){memset(c,0,sizeof(c));}
inline void add(int x,int v){for(;x<=n;x+=lowbit(x))c[x]+=v;}
inline int query(int x){int ans=0;for(;x;x-=lowbit(x))ans+=c[x];return ans;}
}bit;
signed main(){
// freopen("1.in","r",stdin);
n=read();
for(int i=1;i<=n;++i){
a[i]=b[i]=read();
}
sort(b+1,b+1+n);
for(int i=1;i<=n;++i){
c[i]=lower_bound(b+1,b+1+n,a[i])-b;
}
for(int i=1;i<=n;++i){
bit.add(c[i],1);
ans=ans+(i-bit.query(c[i]));
}
cout<<ans;
return 0;
}
例题:
C
o
w
s
o
r
t
i
n
g
Cow\ sorting
Cow sorting
一句话题意:求一个序列里面所有逆序对的数字和。
在BIT里记录数字和与逆序对个数即可,注意开
l
o
n
g
l
o
n
g
long\ long
long long
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N=(int)1e6+50;
int n,a[N];
inline int read(){
int cnt=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-')f=-f;c=getchar();}
while(isdigit(c)){cnt=(cnt<<1)+(cnt<<3)+(c^48);c=getchar();}
return cnt*f;
}
struct BIT{
int sum[N],num[N];
void clear(){memset(sum,0,sizeof(sum));memset(num,0,sizeof(num));}
inline int lowbit(int x){return x&(-x);}
inline void add(int x,int v,int cnt){for(;x<=n;x+=lowbit(x))sum[x]+=v,num[x]+=cnt;}
inline int query_sum(int x){int ans=0;for(;x;x-=lowbit(x))ans+=sum[x];return ans;}
inline int query_num(int x){int ans=0;for(;x;x-=lowbit(x))ans+=num[x];return ans;}
}bit;int gu;
signed main(){
while(scanf("%d",&n)!=EOF){
int ans=0;
bit.clear();
for(int i=1;i<=n;++i){
a[i]=read();gu+=a[i];
bit.add(a[i],a[i],1);
int k1=i-bit.query_num(a[i]);
if(k1){
int k2=gu-bit.query_sum(a[i]);
ans+=k1*a[i]+k2;
}
}
cout<<ans<<endl;
}
return 0;
}