文章目录
算法简介
树状数组 ( B i n a r y I n d e x e d T r e e ( B . I . T ) , F e n w i c k T r e e ) (Binary Indexed Tree(B.I.T), Fenwick Tree) (BinaryIndexedTree(B.I.T),FenwickTree)是一个查询和修改复杂度都为 l o g ( n ) log(n) log(n)的数据结构,主要用于查询任意两位之间的所有元素之和,经过修改也可以用于其他满足可加性的操作(比方说异或和)。树状数组和线段树很像,但能用树状数组解决的问题,基本上都能用线段树解决,而线段树能解决的树状数组不一定能解决。相比较而言,树状数组效率要高很多,代码长度也要短很多。
算法实现
1. 单点修改 + 区间查询
2. 区间修改 + 单点查询
通过 “差分”(就是记录数组中每个元素与前一个元素的差),可以把这个问题转化为问题 1 1 1。
查询
设原数组为
a
[
i
]
a[i]
a[i], 设数组
d
[
i
]
=
a
[
i
]
−
a
[
i
−
1
]
(
a
[
0
]
=
0
)
d[i]=a[i]−a[i−1](a[0]=0)
d[i]=a[i]−a[i−1](a[0]=0),则
a
[
i
]
=
∑
j
=
1
i
=
d
[
j
]
a[i]=\sum_{j=1}^i=d[j]
a[i]=∑j=1i=d[j],可以通过求
d
[
i
]
d[i]
d[i]的前缀和查询。
修改
当给区间
[
l
,
r
]
[l,r]
[l,r]加上
x
x
x的时候,
a
[
l
]
a[l]
a[l] 与前一个元素
a
[
l
−
1
]
a[l−1]
a[l−1] 的差增加了
x
x
x,
a
[
r
+
1
]
a[r+1]
a[r+1] 与
a
[
r
]
a[r]
a[r]的差减少了
x
x
x。根据
d
[
i
]
d[i]
d[i]数组的定义,只需给
d
[
l
]
d[l]
d[l] 加上
x
x
x, 给
d
[
r
+
1
]
d[r+1]
d[r+1]减去
x
x
x 即可。
代码实现如下:
3. 区间修改 + 区间查询
这是最常用的部分,也是用线段树写着最麻烦的部分——但是现在我们有了树状数组!
怎么求呢?我们基于问题2的 “差分” 思路,考虑一下如何在问题 2 2 2构建的树状数组中求前缀和:
位置 p p p的前缀和 = ∑ i = 1 p a [ i ] = ∑ i = 1 p ∑ j = 1 i d [ j ] \sum_{i=1}^p a[i]=\sum_{i=1}^p\sum_{j=1}^id[j] i=1∑pa[i]=i=1∑pj=1∑id[j]
在等式最右侧的式子 ∑ i = 1 p ∑ j = 1 i d [ j ] \sum_{i=1}^p\sum_{j=1}^id[j] ∑i=1p∑j=1id[j]中, d [ 1 ] d[1] d[1] 被用了p次, d [ 2 ] d[2] d[2]被用了 p − 1 p−1 p−1次……那么我们可以写出:
位置
p
p
p的前缀和 =
∑
i
=
1
p
∑
j
=
1
i
d
[
j
]
=
∑
i
=
1
p
d
[
i
]
∗
(
p
−
i
+
1
)
=
(
p
+
1
)
×
∑
i
=
1
p
d
[
i
]
−
∑
i
=
1
p
(
d
[
i
]
×
i
)
\sum_{i=1}^p\sum_{j=1}^i d[j]=\sum_{i=1}^pd[i]∗(p−i+1)=(p+1)\times \sum_{i=1}^pd[i]−\sum_{i=1}^p(d[i]\times i)
i=1∑pj=1∑id[j]=i=1∑pd[i]∗(p−i+1)=(p+1)×i=1∑pd[i]−i=1∑p(d[i]×i)
那么我们可以维护两个数组的前缀和:
一个数组是
s
u
m
1
[
i
]
=
∑
j
=
1
i
d
[
j
]
sum1[i]=\sum_{j=1}^id[j]
sum1[i]=∑j=1id[j],
另一个数组是
s
u
m
2
[
i
]
=
∑
j
=
1
i
(
d
[
j
]
×
j
)
sum2[i]=\sum_{j=1}^i(d[j]\times j)
sum2[i]=∑j=1i(d[j]×j)。
查询
位置
p
p
p的前缀和即:
(
p
+
1
)
×
s
u
m
1
(p + 1) \times sum1
(p+1)×sum1数组中p的前缀和
−
s
u
m
2
-\ sum2
− sum2数组中
p
p
p的前缀和。
区间
[
l
,
r
]
[l, r]
[l,r]的和即:位置
r
r
r的前缀和 - 位置
l
l
l的前缀和。
修改
对于
s
u
m
1
sum1
sum1数组的修改同问题
2
2
2中对
d
d
d数组的修改。
对于
s
u
m
2
sum2
sum2数组的修改也类似,我们给
s
u
m
2
[
l
]
sum2[l]
sum2[l] 加上
l
×
x
l \times x
l×x,给
s
u
m
2
[
r
+
1
]
sum2[r + 1]
sum2[r+1] 减去
(
r
+
1
)
×
x
(r + 1) \times x
(r+1)×x。
代码实现如下:LOJ #132. 树状数组 3 :区间修改,区间查询
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int MAXN=1e6+5;
int n,m;
ll a[MAXN],sum1[MAXN],sum2[MAXN];
inline ll read()
{
ll X=0,f=1; char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-1; ch=getchar();}
while(ch>='0'&&ch<='9') {X=(X<<1)+(X<<3)+(ch^'0'); ch=getchar();}
return X*f;
}
inline ll lowbit(int x)
{
return x&(-x);
}
inline void add(int p,ll x)
{
for(int i=p;i<=n;i+=lowbit(i))
sum1[i]+=x,sum2[i]+=x*p;
}
inline void addRange(int l,int r,ll x)
{
add(l,x),add(r+1,-x);
}
inline ll ask(int p)
{
ll res=0;
for(int i=p;i;i-=lowbit(i))
res+=(p+1)*sum1[i]-sum2[i];
return res;
}
inline ll askRange(int l,int r)
{
return ask(r)-ask(l-1);
}
int main()
{
// freopen("input.txt","r",stdin);
n=read(); m=read();
for(int i=1;i<=n;++i)
a[i]=read(),add(i,a[i]-a[i-1]);
for(int i=1;i<=m;++i)
{
int op=read(),l=read(),r=read(); ll x;
if(op==1) x=read(),addRange(l,r,x);
else if(op==2) printf("%lld\n",askRange(l,r));
}
return 0;
}
4. 二维树状数组
我们已经学会了对于序列的常用操作,那么我们不由得想到(谁会想到啊喂)……能不能把类似的操作应用到矩阵上呢?这时候我们就要写二维树状数组了!
在一维树状数组中,
s
u
m
[
x
]
sum[x]
sum[x](树状数组中的那个“数组”)记录的是右端点为
x
x
x、长度为
l
o
w
b
i
t
(
x
)
lowbit(x)
lowbit(x)的区间的区间和。
那么在二维树状数组中,可以类似地定义
s
u
m
[
x
]
[
y
]
sum[x][y]
sum[x][y]记录的是右下角为
(
x
,
y
)
(x, y)
(x,y),高为
l
o
w
b
i
t
(
x
)
lowbit(x)
lowbit(x), 宽为
l
o
w
b
i
t
(
y
)
lowbit(y)
lowbit(y)的区间的区间和。
4.1 单点修改 + 区间查询
代码实现如下:LOJ #133. 二维树状数组 1:单点修改,区间查询
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int MAXN=(1<<12)+5;
int n,m,op;
ll sum[MAXN][MAXN];
inline ll read()
{
ll X=0,f=1; char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-1; ch=getchar();}
while(ch>='0'&&ch<='9') {X=(X<<1)+(X<<3)+(ch^'0'); ch=getchar();}
return X*f;
}
inline int lowbit(int x)
{
return x&(-x);
}
inline void add(int x,int y,ll k)
{
int tmp_y=y;
while(x<=n)
{
y=tmp_y;
while(y<=m)
sum[x][y]+=k,y+=lowbit(y);
x+=lowbit(x);
}
}
inline ll ask(int x,int y)
{
int tmp_y=y; ll res=0;
while(x)
{
y=tmp_y;
while(y)
res+=sum[x][y],y-=lowbit(y);
x-=lowbit(x);
}
return res;
}
inline ll askRange(int xa,int ya,int xb,int yb)
{
return ask(xb,yb)-ask(xa-1,yb)-ask(xb,ya-1)+ask(xa-1,ya-1);
}
int main()
{
freopen("input.txt","r",stdin);
n=read(); m=read();
while(scanf("%d",&op)!=EOF)
{
if(op==1)
{
int x=read(),y=read(); ll k=read();
add(x,y,k);
}
else if(op==2)
{
int xa=read(),ya=read(),xb=read(),yb=read();
printf("%lld\n",askRange(xa,ya,xb,yb));
}
}
return 0;
}
4.2 区间修改 + 单点查询
我们对于一维数组进行差分,是为了使差分数组前缀和等于原数组对应位置的元素。
那么如何对二维数组进行差分呢?可以针对二维前缀和的求法来设计方案。
二维前缀和:
s
u
m
[
i
]
[
j
]
=
s
u
m
[
i
−
1
]
[
j
]
+
s
u
m
[
i
]
[
j
−
1
]
−
s
u
m
[
i
−
1
]
[
j
−
1
]
+
a
[
i
]
[
j
]
sum[i][j]=sum[i−1][j]+sum[i][j−1]−sum[i−1][j−1]+a[i][j]
sum[i][j]=sum[i−1][j]+sum[i][j−1]−sum[i−1][j−1]+a[i][j]
那么我们可以令差分数组 d [ i ] [ j ] d[i][j] d[i][j] 表示 a [ i ] [ j ] a[i][j] a[i][j] 与 a [ i − 1 ] [ j ] + a [ i ] [ j − 1 ] − a [ i − 1 ] [ j − 1 ] a[i−1][j]+a[i][j−1]−a[i−1][j−1] a[i−1][j]+a[i][j−1]−a[i−1][j−1] 的差。
例如下面这个矩阵
1 1 1 | 4 4 4 | 8 8 8 |
---|---|---|
6 6 6 | 7 7 7 | 2 2 2 |
3 3 3 | 9 9 9 | 5 5 5 |
对应的差分数组就是
1 1 1 | 3 3 3 | 4 4 4 |
---|---|---|
5 5 5 | − 2 -2 −2 | − 9 -9 −9 |
− 3 -3 −3 | 5 5 5 | 1 1 1 |
当我们想要将一个矩阵加上
x
x
x时,怎么做呢?
下面是给最中间的
3
×
3
3\times 3
3×3矩阵加上
x
x
x时,差分数组的变化:
0 0 0 | 0 0 0 | 0 0 0 | 0 0 0 | 0 0 0 |
---|---|---|---|---|
0 0 0 | + x +x +x | 0 0 0 | 0 0 0 | − x -x −x |
0 0 0 | 0 0 0 | 0 0 0 | 0 0 0 | 0 0 0 |
0 0 0 | 0 0 0 | 0 0 0 | 0 0 0 | 0 0 0 |
0 0 0 | − x -x −x | 0 0 0 | 0 0 0 | − x -x −x |
这样给修改差分,造成的效果就是:
0 0 0 | 0 0 0 | 0 0 0 | 0 0 0 | 0 0 0 |
---|---|---|---|---|
0 0 0 | x x x | x x x | x x x | 0 0 0 |
0 0 0 | x x x | x x x | x x x | 0 0 0 |
0 0 0 | x x x | x x x | x x x | 0 0 0 |
0 0 0 | 0 0 0 | 0 0 0 | 0 0 0 | 0 0 0 |
代码实现如下:
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const ll MAXN=(1<<12)+5;
ll n,m,op;
ll sum[MAXN][MAXN];
inline ll read()
{
ll X=0,f=1; char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-1; ch=getchar();}
while(ch>='0'&&ch<='9') {X=(X<<1)+(X<<3)+(ch^'0'); ch=getchar();}
return X*f;
}
inline ll lowbit(ll x)
{
return x&(-x);
}
inline void add(ll x,ll y,ll k)
{
ll tmp_y=y;
while(x<=n)
{
y=tmp_y;
while(y<=m)
sum[x][y]+=k,y+=lowbit(y);
x+=lowbit(x);
}
}
inline void addRange(ll xa,ll ya,ll xb,ll yb,ll k)
{
add(xa,ya,k);
add(xa,yb+1,-k);
add(xb+1,ya,-k);
add(xb+1,yb+1,k);
}
inline ll ask(ll x,ll y)
{
ll tmp_y=y; ll res=0;
while(x)
{
y=tmp_y;
while(y)
res+=sum[x][y],y-=lowbit(y);
x-=lowbit(x);
}
return res;
}
int main()
{
freopen("input.txt","r",stdin);
freopen("myans.out","w",stdout);
n=read(); m=read();
while(scanf("%lld",&op)!=EOF)
{
if(op==1)
{
ll xa=read(),ya=read(),xb=read(),yb=read(),k=read();
addRange(xa,ya,xb,yb,k);
}
else if(op==2)
{
ll x=read(),y=read();
printf("%lld\n",ask(x,y));
}
}
return 0;
}
4.3 区间修改 + 区间查询
类比之前一维数组的区间修改区间查询,下面这个式子表示的是点 ( x , y ) (x, y) (x,y)的二维前缀和:
∑
i
=
1
x
∑
j
=
1
y
a
[
i
]
[
j
]
=
∑
i
=
1
x
∑
j
=
1
y
∑
k
=
1
i
∑
h
=
1
j
d
[
h
]
[
k
]
\sum_{i=1}^{x}\sum_{j=1}^{y}a[i][j]=\sum_{i=1}^{x}\sum_{j=1}^{y}\sum_{k=1}^{i}\sum_{h=1}^{j}d[h][k]
i=1∑xj=1∑ya[i][j]=i=1∑xj=1∑yk=1∑ih=1∑jd[h][k]
其中,
d
[
h
]
[
k
]
d[h][k]
d[h][k]为点
(
h
,
k
)
(h, k)
(h,k)对应的“二维差分”(同上题)
这个式子非常复杂( O ( n 4 ) O(n^4) O(n4)复杂度,我打暴力才 O ( n 2 ) O(n^2) O(n2)!),但利用树状数组,我们可以把它优化到 O ( ( l o g n ) 2 ) O((log\ n)^2) O((log n)2)!
首先,类比一维数组,统计一下每个 d [ h ] [ k ] d[h][k] d[h][k]出现过多少次。 d [ 1 ] [ 1 ] d[1][1] d[1][1]出现了 x × y x\times y x×y次, d [ 1 ] [ 2 ] d[1][2] d[1][2]出现了 x × ( y − 1 ) x\times (y−1) x×(y−1)次…… d [ h ] [ k ] d[h][k] d[h][k] 出现了 ( x − h + 1 ) × ( y − k + 1 ) (x−h+1)\times (y−k+1) (x−h+1)×(y−k+1)次。
那么这个式子就可以写成:
∑ i = 1 x ∑ j = 1 y d [ i ] [ j ] × ( x + 1 − i ) × ( y + 1 − j ) \sum_{i=1}^{x}\sum_{j=1}^{y}d[i][j]\times (x+1−i)\times (y+1−j) i=1∑xj=1∑yd[i][j]×(x+1−i)×(y+1−j)
把这个式子展开,就得到:
( x + 1 ) × ( y + 1 ) × ∑ i = 1 x ∑ j = 1 y d [ i ] [ j ] − ( y + 1 ) × ∑ i = 1 x ∑ j = 1 y d [ i ] [ j ] × i − ( x + 1 ) × ∑ i = 1 x ∑ j = 1 y d [ i ] [ j ] × j + ∑ i = 1 x ∑ j = 1 y d [ i ] [ j ] × i × j (x+1)\times (y+1)\times \sum_{i=1}^{x}\sum_{j=1}^{y}d[i][j]\\ −(y+1)\times \sum_{i=1}^{x}\sum_{j=1}^{y}d[i][j]\times i\\ −(x+1)\times \sum_{i=1}^{x}\sum_{j=1}^{y}d[i][j]\times j\\ +\sum_{i=1}^{x}\sum_{j=1}^{y}d[i][j]\times i\times j (x+1)×(y+1)×i=1∑xj=1∑yd[i][j]−(y+1)×i=1∑xj=1∑yd[i][j]×i−(x+1)×i=1∑xj=1∑yd[i][j]×j+i=1∑xj=1∑yd[i][j]×i×j
那么我们要开四个树状数组,分别维护:
d [ i ] [ j ] , d [ i ] [ j ] × i , d [ i ] [ j ] × j , d [ i ] [ j ] × i × j d[i][j],d[i][j]\times i,d[i][j]\times j,d[i][j]\times i\times j d[i][j],d[i][j]×i,d[i][j]×j,d[i][j]×i×j
这样就完成了!
代码实现如下:LOJ #135. 二维树状数组 3:区间修改,区间查询
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int MAXN=2050;
int n,m,op;
ll sum1[MAXN][MAXN],sum2[MAXN][MAXN],sum3[MAXN][MAXN],sum4[MAXN][MAXN];
inline ll read()
{
ll X=0,f=1; char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-1; ch=getchar();}
while(ch>='0'&&ch<='9') {X=(X<<1)+(X<<3)+(ch^'0'); ch=getchar();}
return X*f;
}
inline int lowbit(int x)
{
return x&(-x);
}
inline void add(int x,int y,ll k)
{
for(int i=x;i<=n;i+=lowbit(i))
for(int j=y;j<=m;j+=lowbit(j))
{
sum1[i][j]+=k;
sum2[i][j]+=k*x;
sum3[i][j]+=k*y;
sum4[i][j]+=k*x*y;
}
}
inline void addRange(int xa,int ya,int xb,int yb,ll k)
{
add(xa,ya,k);
add(xa,yb+1,-k);
add(xb+1,ya,-k);
add(xb+1,yb+1,k);
}
inline ll ask(int x,int y)
{
ll res=0;
for(int i=x;i;i-=lowbit(i))
for(int j=y;j;j-=lowbit(j))
{
res+=(x+1)*(y+1)*sum1[i][j];
res-=(y+1)*sum2[i][j];
res-=(x+1)*sum3[i][j];
res+=sum4[i][j];
}
return res;
}
inline ll askRange(int xa,int ya,int xb,int yb)
{
return ask(xb,yb)-ask(xb,ya-1)-ask(xa-1,yb)+ask(xa-1,ya-1);
}
int main()
{
freopen("input.txt","r",stdin);
freopen("output.out","w",stdout);
n=read(); m=read();
while(scanf("%d",&op)!=EOF)
{
int xa=read(),ya=read(),xb=read(),yb=read(); ll k;
if(op==1) k=read(),addRange(xa,ya,xb,yb,k);
else if(op==2) printf("%lld\n",askRange(xa,ya,xb,yb));
}
return 0;
}
例题讲解
例题1 洛谷 P4054 [JSOI2009]计数问题
容易发现这是个比较简单的二维树状数组。
c , n , m c,n,m c,n,m都不是很大,所以,我们可以对每一个 c c c开一个二维树状数组,每次修改时对该点原来的数字进行 − 1 -1 −1,再对现在的值进行 + 1 +1 +1。
#include <bits/stdc++.h>
using namespace std;
const int MAXN=305;
const int MAXM=105;
int n,m;
int sum[MAXN][MAXN][MAXM],a[MAXN][MAXN];
inline int read()
{
int X=0,f=1; char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-1; ch=getchar();}
while(ch>='0'&&ch<='9') {X=(X<<1)+(X<<3)+(ch^'0'); ch=getchar();}
return X*f;
}
inline int lowbit(int x)
{
return x&(-x);
}
inline void add(int x,int y,int z,int k)
{
int tmp_y=y;
while(x<=n)
{
y=tmp_y;
while(y<=m)
sum[x][y][z]+=k,y+=lowbit(y);
x+=lowbit(x);
}
}
inline int ask(int x,int y,int z)
{
int tmp_y=y,res=0;
while(x)
{
y=tmp_y;
while(y)
res+=sum[x][y][z],y-=lowbit(y);
x-=lowbit(x);
}
return res;
}
inline int askRange(int xa,int ya,int xb,int yb,int z)
{
return ask(xb,yb,z)-ask(xa-1,yb,z)-ask(xb,ya-1,z)+ask(xa-1,ya-1,z);
}
int main()
{
// freopen("input.txt","r",stdin);
n=read(); m=read();
for(int i=1;i<=n;++i)
for(int j=1;j<=m;++j)
a[i][j]=read(),add(i,j,a[i][j],1);
int q=read(); while(q--)
{
int op=read();
if(op==1)
{
int x=read(),y=read(),k=read();
add(x,y,a[x][y],-1);
a[x][y]=k;
add(x,y,a[x][y],1);
}
else if(op==2)
{
int xa=read(),xb=read(),ya=read(),yb=read(),z=read();
printf("%d\n",askRange(xa,ya,xb,yb,z));
}
}
return 0;
}
例题2 洛谷 P1972 [SDOI2009]
对于若干个询问的区间 [ l , r ] [l,r] [l,r],如果他们的 r r r都相等的话,那么项链中出现的同一个数字,一定是只关心出现在最右边的那一个的。
我们记录下每一个询问 q u e [ i ] que[i] que[i],按 r r r从小到大排序, u s e d [ i ] used[i] used[i]表示数字 i i i出现的最右边的位置,树状数组 s u m [ i ] sum[i] sum[i]维护从 1 1 1到 i i i不同数字的个数有多少个。从左到右扫描,对于第 i i i个数 a [ i ] a[i] a[i],如果 a [ i ] a[i] a[i]之前打过标记,在之前的位置 u s e d [ a [ i ] ] used[a[i]] used[a[i]]加上 − 1 -1 −1,保证无重复。然后在 i i i位置上加上 1 1 1。对于第 i i i个询问 q u e [ i ] que[i] que[i],我们直接查询区间 [ l , r ] [l,r] [l,r]的和即可。
代码实现如下:
#include <bits/stdc++.h>
using namespace std;
const int MAXN=1e6+5;
int n,m;
int a[MAXN],used[MAXN],sum[MAXN],ans[MAXN];
struct Query
{
int l,r;
int id;
bool operator <(const Query &cir) const
{
return r<cir.r;
}
}que[MAXN];
inline int read()
{
int X=0,f=1; char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-1; ch=getchar();}
while(ch>='0'&&ch<='9') {X=(X<<1)+(X<<3)+(ch^'0'); ch=getchar();}
return X*f;
}
inline int lowbit(int x)
{
return x&(-x);
}
inline void add(int p,int x)
{
while(p<=n) sum[p]+=x,p+=lowbit(p);
}
inline int ask(int p)
{
int res=0;
while(p)
res+=sum[p],p-=lowbit(p);
return res;
}
inline int askRange(int l,int r)
{
return ask(r)-ask(l-1);
}
void readdata()
{
n=read();
for(int i=1;i<=n;++i)
a[i]=read();
m=read();
for(int i=1;i<=m;++i)
que[i].l=read(),que[i].r=read(),que[i].id=i;
sort(que+1,que+m+1);
}
void work()
{
int pre=1;
for(int i=1;i<=m;++i)
{
for(int j=pre;j<=que[i].r;++j)
{
if(used[a[j]]) add(used[a[j]],-1);
add(j,1);
used[a[j]]=j;
}
pre=que[i].r+1;
ans[que[i].id]=askRange(que[i].l,que[i].r);
}
for(int i=1;i<=m;++i)
printf("%d\n",ans[i]);
}
int main()
{
//freopen("input.txt","r",stdin);
readdata();
work();
return 0;
}