树状数组
十年岐路,空负曲江花
介绍
树状数组就是以数组的形式来模拟树,在代码很简洁的情况下,能起到部分代替线段树的效果.而且此算法常数较小,空间开销也很小,是一个非常轻量级的数据结构.
思想
图中黑色的是原数组,红色的是树状数组.我们发现树状数组的空间开销是O(n)
的.我们以八个元素(从1开始记数会比较方便)求和为例.假设原数组是a[]
,树状数组是c[]
.
我们不难发现
- C[1] = A[1];
- C[2] = A[1] + A[2];
- C[3] = A[3];
- C[4] = A[1] + A[2] + A[3] + A[4];
- C[5] = A[5];
- C[6] = A[5] + A[6];
- C[7] = A[7];
- C[8] = A[1] + A[2] + A[3] + A[4] + A[5] + A[6] + A[7] + A[8];
即满足C[i] = A[i - 2^k +1] + A[i - 2^k +2] + ... + A[i] //k为i的二进制中从最低位到高位连续零的长度
的规律(我一上来还真没看出来,实际上这个规律就是c[i]
和 a[i- 2^k +1]
之间的规律)
其实问题最重要的是2^k
,我们记i
所对应的2^k
记为lowbit(i)
.
我们发现如果i+lowbit(i)==j
,则刚才的图中i
,j
之间一定有连线(比如2和4,3和4),如果2这个点有更新,则可以顺着这个边一直向右上方更新到8,也就是单点更新是O(logn)
级别的时间复杂度
而如果i-lowbit(i)==j
,那么Sum[i]=C[i]+Sum[j]
(比如图中的sum[6] == c[6]+sum[4]
),这样就可以用树状数组以O(logn)
的复杂度求出 1到i
的前缀和了.求出了前缀和我们就能在O(1)
时间内求出区间和 .
基础模板
刚才的分析已经很明显能够发现,树状数组是一个可以快速单点更新,区间求和的数据结构.以下代码中的tree[]
就是上文提到的c[]
.
lowbit()
int lowbit(int x) {
return x&(-x);
}
这个函数我们在数论里面遇到过,不再过多阐述,就是求从最低位到第一个1之间连续0的个数.
单点更新
void add(int x,int y) {
while(x<=n) {
tree[x]+=y;
x+=lowbit(x); // 沿着向上的线不断更新
}
}
区间求和
ll qzsum(int p) { // [1,p]的前缀和
int i=p;
ll ans=0ll;
while(i>0) {
ans+=tree[i];
i-=lowbit(i);
}
return ans;
}
ll sum(int l,int r) { // [l,r]区间的和,此处体现以1开头的优越性
return qzsum(r)-qzsum(l-1);
}
初始化
初始化其实很简单,就是先把树状数组清零,然后用单点更新即可.
memset(tree,0,sizeof(tree));
for(int i=1; i<=n; ++i) {
scanf("%d",&a[i]);
add(i,a[i]);
}
HDU 1166
// 直接模板套上即可
#include<bits/stdc++.h>
#define sc scanf
#define pr printf
#define pb push_back
#define ll long long
using namespace std;
int n;
int a[50000+100],tree[50000+100];
char cmd[20];
int lowbit(int x) {
return x&(-x);
}
void add(int x,int y) {
while(x<=n) {
tree[x]+=y;
x+=lowbit(x);
}
}
ll qzsum(int p) {
int i=p;
ll ans=0ll;
while(i>0) {
ans+=tree[i];
i-=lowbit(i);
}
return ans;
}
ll sum(int l,int r) {
return qzsum(r)-qzsum(l-1);
}
int main() {
int T;
sc("%d",&T);
int cnt=0;
while(T--) {
pr("Case %d:\n",++cnt);
sc("%d",&n);
memset(tree,0,sizeof(tree));
for(int i=1; i<=n; ++i) {
sc("%d",&a[i]);
add(i,a[i]);
}
while(sc("%s",cmd)) {
if(cmd[0]=='Q') {
int l,r;
sc("%d%d",&l,&r);
pr("%lld\n",sum(l,r));
} else if(cmd[0]=='A') {
int x,y;
sc("%d%d",&x,&y);
add(x,y);
} else if(cmd[0]=='S') {
int x,y;
sc("%d%d",&x,&y);
add(x,-1*y);
} else if(cmd[0]=='E') {
break;
}
}
}
return 0;
}
进阶问题
由于单点更新,单点询问可以用数组直接暴力,所以略过.
区间更新,单点询问
(此处的区间更新指的是区间加常数,这个问题的具体原理我们在差分数组中讲过了)
这个问题的关键在于如何降低区间更新的复杂度,如果直接朴素的O(n)
处理的话,显然超时,我们发现,这是对原数组的 区间更新,单点询问 ,对于其差分数组而言,其实是单点更新,区间询问 .在原数组区间[l,r]
上加常数K,相当于其差分数组p[l]-=k , p[r+1]-=k
,而单点询问就是求差分数组的前缀和.
所以我们干脆直接维护差分数组就好了!
差分数组我们命名为P[]
,定义a[0]=0 p[i]=a[i]-a[i-1]
以上模板**只需要改初始化即可,**如果询问a[i]
,则返回sum(i,i)
即可
初始化
memset(tree,0,sizeof(tree));
a[0]=0;
for(int i=1; i<=n; ++i) {
sc("%d",&a[i]);
add(i,a[i]-a[i-1]);
}
洛谷3368
#include<bits/stdc++.h>
#define sc scanf
#define pr printf
#define pb push_back
#define ll long long
using namespace std;
int n,m;
int a[500000+100],P[500000+100];
int lowbit(int x) {
return x&(-x);
}
void add(int x,int y) {
while(x<=n) {
P[x]+=y;
x+=lowbit(x);
}
}
ll qzsum(int x) {
ll ans=0ll;
while(x>0) {
ans+=P[x];
x-=lowbit(x);
}
return ans;
}
int main() {
int T;
//sc("%d",&T);
T=1;
while(T--) {
sc("%d%d",&n,&m);
memset(P,0,sizeof(P));
a[0]=0;
for(int i=1; i<=n; ++i) {
sc("%d",&a[i]);
add(i,a[i]-a[i-1]);
}
for(int i=0; i<m; ++i) {
int tp=0;
sc("%d",&tp);
if(tp==1) {
int x,y,k;
sc("%d%d%d",&x,&y,&k);
add(x,k); // 更新差分数组
add(y+1,-k);
} else {
int x;
sc("%d",&x);
pr("%lld\n",qzsum(x)); // 单点询问,其实这里也可以用上文中的sum(x,1)
}
}
}
return 0;
}
区间更新,区间询问
我们发现,区间求和,转化成对差分数组的操作是这样的.
∑
i
=
1
n
A
[
i
]
=
∑
i
=
1
n
∑
j
=
1
i
D
[
j
]
\sum_{i=1}^{n}A[i]=\sum_{i=1}^{n}\sum_{j=1}^{i}D[j]
i=1∑nA[i]=i=1∑nj=1∑iD[j]
但是直接算还是复杂度太大
根据参考资料(文首已给出链接)的推导过程
A[1]+A[2]+...+A[n]
= (D[1]) + (D[1]+D[2]) + ... + (D[1]+D[2]+...+D[n])
= n*D[1] + (n-1)*D[2] +... +D[n]
= n * (D[1]+D[2]+...+D[n]) - (0*D[1]+1*D[2]+...+(n-1)*D[n])
综上所述,
∑
i
=
1
n
A
[
i
]
=
n
∗
∑
i
=
1
n
D
[
j
]
−
∑
i
=
1
n
[
D
[
i
]
∗
(
i
−
1
)
]
\sum_{i=1}^{n}A[i]=n*\sum_{i=1}^{n}D[j] - \sum_{i=1}^{n}[D[i]*(i-1)]
i=1∑nA[i]=n∗i=1∑nD[j]−i=1∑n[D[i]∗(i−1)]
如果我们新声明一个数组Q[]
, Q[i] = D[i]*(i-1)
,则我们只需要同时维护D[]和Q[]
就好了.
单点更新
void add(int x,ll y) {
int tmp=x;
while(x<=n) {
P[x]+=y;
P2[x]+=(tmp-1)*y; // 注意此处很容易写错,增量在向右上传递过程中应该是不变的,都应该是最初的(x-1)*y
x+=lowbit(x);
}
}
区间询问
ll qzsum(int x) {
int tmp=x;
ll res=0ll;
while(x>0) {
res+=P[x]*tmp-P2[x]; // 这里当然也可以先算出两个和之后再作差,不过由于担心爆ll,此处就先作差再求和
x-=lowbit(x);
}
return res;
}
POJ3468
// #include<bits/stdc++.h> poj 不支持
#include<iostream>
#include<cstdio>
#include<cstring>
#define sc scanf
#define pr printf
#define pb push_back
#define ll long long
using namespace std;
int n,m;
ll a[500000+100],P[500000+100],P2[500000+100];
int lowbit(int x) {
return x&(-x);
}
void add(int x,ll y) {
int tmp=x;
while(x<=n) {
P[x]+=y;
P2[x]+=(tmp-1)*y;
x+=lowbit(x);
}
}
ll qzsum(int x) {
int tmp=x;
ll res=0ll;
while(x>0) {
res+=P[x]*tmp-P2[x];
x-=lowbit(x);
}
return res;
}
int main() {
while(sc("%d%d",&n,&m)!=EOF) {
memset(P,0,sizeof(P));
memset(P2,0,sizeof(P2));
a[0]=0;
for(int i=1; i<=n; ++i) {
sc("%lld",&a[i]);
add(i,a[i]-a[i-1]);
}
for(int i=0; i<m; ++i) {
char cmd[30];
sc("%s",cmd);
if(cmd[0]=='C') {
int x,y,k;
sc("%d%d%d",&x,&y,&k);
add(x,k);
add(y+1,-k);
} else {
int l,r;
sc("%d%d",&l,&r);
pr("%lld\n",qzsum(r)-qzsum(l-1));
}
}
}
return 0;
}
求逆序对
逆序对有两种求法:树状数组和归并排序,归并排序之后会再列出笔记说明.
思想
当数据范围不是很大的时候(1e8以下),其实很简单,我们只需要让之前代码中的a[]
数组变成类似桶排序的数组那样.a[i]
的值表示数值为i
的数在序列中出现了几次.此算法是在线算法,每次读入一个数存到tmp中,就调用add(tmp,1)
,然后此时调用qzsum(tmp-1)
,就得到当前数插入之前,比tmp
小的数有多少个,如果tmp
是第 i
个 插入的,那么前面一定有i-1-qzsum(tmp-1)
个 比 tmp大的,他们都会分别和tmp构成逆序对,每一次读入,都将这些数累积起来,最后就得到了所有的逆序数.
离散化
上一段中的思想是正确的,但是如果输入的数据范围稍大一些(超过1e8), 那么就会导致数组开不下,所以我们进行一下离散化处理.
我们发现,如果给的序列非常的稀疏(共5e5个数,范围是-1e9到1e9),那么我们就无需在意其绝对大小,而是更在乎其相对大小.
我们采取这样的策略:首先每个元素记录下其自身的值以及在原数组中的下标,然后按照它们自己的值进行排序(此处排序应当调用stable_sort(),或者在自定义的排序函数里面强调,如果值相同,下标小的排在前面),然后根据排序后的数组,将其下标依次插入树状数组即可.这种做法可行的原因是在根据值排序完成之后,其下标就能反应此元素插入的先后顺序,如果其前面有比自己下标大的,说明这就是一对逆序对.
当然,也可以将这些数按照值排序之后,再原路覆盖回去(以下标覆盖原值),不过比较麻烦,运用下标的方法更简单一些
洛谷1908
// 离散化
// 树状数组求逆序对
#include<bits/stdc++.h>
#define sc scanf
#define pr printf
#define pb push_back
#define ll long long
#define fi first
#define se second
#define pii pair<long long ,long long >
using namespace std;
#define MAXN 600000+100
int c[MAXN];
pii a[MAXN];
int n,m;
bool cmp(pii a,pii b) { // 值相同,按照下标排
if(a.fi==b.fi) {
return a.se<b.se;
}
if(a.fi < b.fi) {
return 1;
}
return 0;
}
int lowbit(int x) {
return x&(-x);
}
void add(int x,int y) {
while(x<=n) {
c[x]+=y;
x+=lowbit(x);
}
}
ll sum(int x) {
ll ans=0;
while(x>0) {
ans+=c[x];
x-=lowbit(x);
}
return ans;
}
int main() {
sc("%d",&n);
for(int i=1; i<=n; ++i) {
sc("%lld",&a[i].first); // first 是其值
a[i].second=i; // second 是它插入的下标(或者说是时间戳也行)
}
sort(a+1,a+n+1,cmp);
ll ans=0;
for(int i=1; i<=n; ++i) {
add(a[i].se,1); // 将下标插入
ans+= i-1-sum(a[i].se-1);
}
pr("%lld\n",ans);
return 0;
}
求最长上升子序列
最长上升子序列虽然有个dp算法,但是用树状数组也是个很经典的算法
dp算法朴素是O(n^2)算法,但是可以用栈(不像是单调栈)和二分优化到O(nlogn),
树状数组求LIS的思想也是基于O(n^2)算法的一个优化
以i号元素结尾的最长上升子序列的长度 == max(以j号元素结尾的最长上升子序列的长度) (j是满足先于i插入,而且元素大小比i号元素小)
问题就在于给出了i,如何实现后面的操作,当然,区间最值也是可以用树状数组维护的.我们采取这样一种策略:先用另一个数组a[]
给val[]
数组排序(如果要求严格递增,则需要去重),然后遍历val[]
,对于每一个元素,都寻找其在a[]
中的下标id
,然后从树状数组中找dp[1 ~ id-1]
中的最大值,dp[i]表示以 a[i]元素结尾的最长上升子序列的长度
,dp数组相当于上文中的C数组.
找到之后,更新一下最长上升子序列的长度,然后更新dp[]
模板
#include<cstdio>
#include<algorithm>
#include<string.h>
using namespace std;
const int maxn=1e5+7;
int a[maxn],val[maxn],dp[maxn],n;
int query(int i) {
int s=0;
while(i>0) {
s=max(dp[i],s);
i-=i&-i;
}
return s;
}
void add(int i,int x) {
while(i<=n) {
dp[i]=max(dp[i],x);
i+=i&-i;
}
}
int main() {
while(~scanf("%d",&n)) {
for(int i=0; i<n; i++)
scanf("%d",&a[i]),val[i]=a[i];
sort(a,a+n);
int len=unique(a,a+n)-a;
memset(dp,0,sizeof(dp));
int ans=0,tmp;
for(int i=0; i<n; i++) {
int id=lower_bound(a,a+len,val[i])-a+1;
tmp=query(id-1)+1;
ans=max(ans,tmp);
add(id,tmp);
}
printf("%d\n",ans);
}
return 0;
}