树状数组(Fenwick Tree)
树状数组是一个查询和修改复杂度都为
O
(
l
o
g
n
)
O(logn)
O(logn)的数据结构。主要用于数组的单点修改、区间求和。
C[1]=A[1]
C[2]=A[1]+A[2]
C[3]=A[3]
C[4]=A[1]+A[2]+A[3]+A[4]
C[5]=A[5]
C[6]=A[5]+A[6]
C[7]=A[7]
C[8]=A[1]+A[2]+A[3]+A[4]+A[5]+A[6]+A[7]+A[8]
定义:
C
[
i
]
=
A
[
i
−
2
k
+
1
]
+
A
[
i
−
2
k
+
2
]
+
⋯
+
A
[
i
]
C[i]=A[i-2^k+1]+A[i-2^k+2]+\dots +A[i]
C[i]=A[i−2k+1]+A[i−2k+2]+⋯+A[i]
其中k为 i 末尾0的个数,设
l
o
w
b
i
t
(
i
)
=
2
k
lowbit(i)=2^k
lowbit(i)=2k
int lowbit(int t)
{
return t&(-t);
}
求和
s
u
m
[
7
]
=
s
u
m
[
111
]
=
C
[
111
]
+
C
[
110
]
+
C
[
100
]
sum[7]=sum[111]=C[111]+C[110]+C[100]
sum[7]=sum[111]=C[111]+C[110]+C[100]
每次的i都不同,每次都去掉
l
o
w
b
i
t
(
i
)
lowbit(i)
lowbit(i),即每次都去掉最低位的1
即
s
u
m
[
7
]
=
C
[
7
]
+
C
[
6
]
+
C
[
4
]
sum[7]=C[7]+C[6]+C[4]
sum[7]=C[7]+C[6]+C[4]
int getsum(int i)
{
int ret=0;
while(i>0)
{
ret+=C[i];
i-=lowbit(i)
}
return ret;
}
单点更新
A[5]=A[101]包含在C[101],C[110],C[1000]中,每次加上最低位的1的大小
l
o
w
b
i
t
(
i
)
lowbit(i)
lowbit(i),每次的i都不一定相同,lowbit(i)也会相应变化
在 i 的位置上加上x
void add(int i,int x)
{
while(i<=n)
{
c[i]+=x;
i+=lowbit(i);
}
}
1、单点更新、区间查询
对原数组建立树状数组
int N,a[maxn],C[maxn];
int lowbit(int i)
{
return i&(-i);
}
int getsum(int i)
{
int ret=0;
while(i>0)
{
ret+=C[i];
i-=lowbit(i);
}
return ret;
}
void add(int i,int x)
{
while(i<=N)
{
C[i]+=x;
i+=lowbit(i);
}
}
int main()
{
scanf("%d",&N);
rep(i,1,N)
{
scanf("%d",&a[i]);
add(i,a[i]);
}
//查询区间[l,r]的和
int l,r;
scanf("%d %d",&l,&r);
int ans=getsum(r)-getsum(l-1);
printf("%d\n",ans);
return 0;
}
2、区间更新、单点查询
对差分数组建立树状数组
差分数组前缀和getsum(i)就是原数组更新之后的位置为i的值
对区间 [ l , r ] [l,r] [l,r] 的值加上 k k k,并查询位置p的值
int T,N,a[maxn],C[maxn];
int lowbit(int i)
{
return i&(-i);
}
int getsum(int i)
{
int ret=0;
while(i>0)
{
ret+=C[i];
i-=lowbit(i);
}
return ret;
}
void add(int i,int x)
{
while(i<=N)
{
C[i]+=x;
i+=lowbit(i);
}
}
int main()
{
scanf("%d",&N);
rep(i,1,N)
{
scanf("%d",&a[i]);
add(i,a[i]-a[i-1]);
}
//对差分数组做更新
int l,r,k;
scanf("%d%d%d",&l,&r,&k);
add(l,k);
add(r+1,-k);
//查询单点p
int p;
scanf("%d",&p);
int ans=getsum(p);
printf("%d\n",ans);
return 0;
}
3、区间更新、区间查询
A
[
1
]
+
A
[
2
]
+
⋯
+
A
[
n
]
A[1]+A[2]+\dots+A[n]
A[1]+A[2]+⋯+A[n]
=
(
D
[
1
]
)
+
(
D
[
1
]
+
D
[
2
]
)
+
⋯
+
(
D
[
1
]
+
D
[
2
]
+
⋯
+
D
[
n
]
)
=(D[1])+(D[1]+D[2])+\dots+(D[1]+D[2]+\dots+D[n])
=(D[1])+(D[1]+D[2])+⋯+(D[1]+D[2]+⋯+D[n])
=
n
D
[
1
]
+
(
n
−
1
)
D
[
2
]
+
⋯
+
2
D
[
n
−
1
]
+
D
[
n
]
=nD[1]+(n-1)D[2]+\dots+2D[n-1]+D[n]
=nD[1]+(n−1)D[2]+⋯+2D[n−1]+D[n]
=
n
(
D
[
1
]
+
D
[
2
]
+
⋯
+
D
[
n
]
)
−
(
D
[
2
]
+
2
D
[
3
]
+
⋯
+
(
n
−
1
)
D
[
n
]
)
=n(D[1]+D[2]+\dots+D[n])-(D[2]+2D[3]+\dots+(n-1)D[n])
=n(D[1]+D[2]+⋯+D[n])−(D[2]+2D[3]+⋯+(n−1)D[n])
化简得
∑
i
=
1
n
A
[
i
]
=
n
∑
i
=
1
n
D
[
i
]
−
∑
i
=
1
n
D
[
i
]
×
(
i
−
1
)
\sum_{i=1}^{n}A[i]=n\sum_{i=1}^nD[i]-\sum_{i=1}^nD[i]\times (i-1)
i=1∑nA[i]=ni=1∑nD[i]−i=1∑nD[i]×(i−1)
因此,维护两个树状数组
第一个维护
D
[
1
]
+
D
[
2
]
+
⋯
+
D
[
n
]
D[1]+D[2]+\dots+D[n]
D[1]+D[2]+⋯+D[n]
第二个维护
0
×
D
[
1
]
+
1
×
D
[
2
]
+
2
×
D
[
3
]
+
⋯
+
(
n
−
1
)
×
D
[
n
]
0\times D[1]+1\times D[2]+2\times D[3]+\dots+(n-1)\times D[n]
0×D[1]+1×D[2]+2×D[3]+⋯+(n−1)×D[n]
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <queue>
#include <vector>
#include <set>
#include <map>
#include <cstring>
#include <string>
#include <cmath>
#define rep(i,a,b) for (int i=a; i<=b; ++i)
#define per(i,b,a) for (int i=b; i>=a; --i)
#define mes(a,b) memset(a,b,sizeof(a))
#define mp make_pair
#define ll long long
#define pb push_back
#define pii pair<int,int>
#define pll pair<ll,ll>
#define ls (rt<<1)
#define rs ((rt<<1)|1)
#define isZero(d) (abs(d) < 1e-8)
using namespace std;
const int maxn=5e4+5,INF=0x3f3f3f3f;
const int mod=1e9+7;
int T,N,a[maxn],C1[maxn],C2[maxn];
int lowbit(int i)
{
return i&(-i);
}
/*
我们要求的是A[1-->n]的和,即nD[1-->i]-(D[2]+2D[3]+...+(n-1)D[n])
这个n只有我们在求的时候才能确定,所以树状数组只能维护D[1-->n],
而不是nD[1],nD[2],...nD[n],因为我们一开始并不知道这个n
我们只有在计算的时候才能乘上这个n
*/
int getsum(int i)
{
int ret=0,i2=i;
while(i>0)
{
ret+=i2*C1[i]-C2[i];
i-=lowbit(i);
}
return ret;
}
/*在原数组(i-1)*D[i]这个位置加上x,则树状数组向上更新时每个含有该元素的
树状数组中,都要加上(i-1)*x
*/
void add(int i,int x)
{
int i2=i;
while(i<=N)
{
C1[i]+=x;
C2[i]+=x*(i2-1);
i+=lowbit(i);
}
}
int main()
{
scanf("%d",&N);
rep(i,1,N)
{
scanf("%d",&a[i]);
add(i,a[i]-a[i-1]);
}
int l,r,k;
//区间[l,r]的值加上k
scanf("%d%d%d",&l,&r,&k);
add(l,k);
add(r+1,-k);
//求区间[l,r]的和
while(~scanf("%d%d",&l,&r))
{
int ans=getsum(r)-getsum(l-1);
printf("%d\n",ans);
}
return 0;
}
/*
5
1 2 3 4 5
2 4 10
1 5
*/
模板题
敌兵布阵 HDU - 1166
题型:单点更新、区间查询
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <queue>
#include <vector>
#include <set>
#include <map>
#include <cstring>
#include <string>
#include <cmath>
#define rep(i,a,b) for (int i=a; i<=b; ++i)
#define per(i,b,a) for (int i=b; i>=a; --i)
#define mes(a,b) memset(a,b,sizeof(a))
#define mp make_pair
#define ll long long
#define pb push_back
#define pii pair<int,int>
#define pll pair<ll,ll>
#define ls (rt<<1)
#define rs ((rt<<1)|1)
#define isZero(d) (abs(d) < 1e-8)
using namespace std;
const int maxn=5e4+5,INF=0x3f3f3f3f;
const int mod=1e9+7;
int T,N,a[maxn],C[maxn];
int lowbit(int i)
{
return i&(-i);
}
int getsum(int i)
{
int ret=0;
while(i>0)
{
ret+=C[i];
i-=lowbit(i);
}
return ret;
}
void add(int i,int x)
{
while(i<=N)
{
C[i]+=x;
i+=lowbit(i);
}
}
int main()
{
scanf("%d",&T);
int Cas=0;
while(T--)
{
scanf("%d",&N);
rep(i,1,N)
a[i]=C[i]=0;
rep(i,1,N)
{
scanf("%d",&a[i]);
add(i,a[i]);
}
char op[10];
printf("Case %d:\n",++Cas);
while(~scanf("%s",op))
{
if(op[0]=='E')
break;
else if(op[0]=='Q')
{
int l,r;
scanf("%d %d",&l,&r);
int ans=getsum(r)-getsum(l-1);
printf("%d\n",ans);
}
else if(op[0]=='A')
{
int x,y;
scanf("%d %d",&x,&y);
add(x,y);
}
else if(op[0]=='S')
{
int x,y;
scanf("%d %d",&x,&y);
add(x,-y);
}
}
}
return 0;
}
A Simple Problem with Integers POJ - 3468
题型:区间更新、区间查询
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <queue>
#include <vector>
#include <set>
#include <map>
#include <cstring>
#include <string>
#include <cmath>
#define rep(i,a,b) for (int i=a; i<=b; ++i)
#define per(i,b,a) for (int i=b; i>=a; --i)
#define mes(a,b) memset(a,b,sizeof(a))
#define mp make_pair
#define ll long long
#define pb push_back
#define pii pair<int,int>
#define pll pair<ll,ll>
#define ls (rt<<1)
#define rs ((rt<<1)|1)
#define isZero(d) (abs(d) < 1e-8)
using namespace std;
const int maxn=1e5+5,INF=0x3f3f3f3f;
const int mod=1e9+7;
ll N,Q,a[maxn],C1[maxn],C2[maxn];
ll lowbit(ll i)
{
return i&(-i);
}
void add(ll i,ll x)
{
ll i2=i;
while(i<=N)
{
C1[i]+=x;
C2[i]+=x*(i2-1);
i+=lowbit(i);
}
}
ll getsum(ll i)
{
ll ret=0,i2=i;
while(i>0)
{
ret+=i2*C1[i]-C2[i];
i-=lowbit(i);
}
return ret;
}
int main()
{
while(~scanf("%d %d",&N,&Q))
{
rep(i,1,N)
{
scanf("%lld",&a[i]);
add(i,a[i]-a[i-1]);
}
while(Q--)
{
char op[10];
int l,r;
scanf("%s %d %d",op,&l,&r);
if(op[0]=='Q')
{
ll ans=getsum(r)-getsum(l-1);
printf("%lld\n",ans);
}
else if(op[0]=='C')
{
int k;
scanf("%d",&k);
add(l,k);
add(r+1,-k);
}
}
}
return 0;
}