今天介绍一种"高级"的数据结构——树状数组。
问题引出
如果我们需要对数组a进行以下两个操作,如何才能使得效率尽可能高?
- 单点修改:a[i]+=k
- 区间查询:求a[l:r]的和,即sum(a[l], a[l+1], … , a[r])
(1). 维护数组a
单点修改的复杂度为O(1),而区间查询的复杂度为O(n)。
(2). 维护数组a的前缀和
可以在O(1)的时间内完成区间查询,而单点修改却需要O(n),因为a[i]之后的所有数的前缀和都包含a[i],所以需要全部修改。
(3). 维护树状数组
考虑(1)和(2),前者每个位置只与原数组中的一个数字有关,因此可以很快进行修改,但区间操作则很慢;后者位置i存储的值与原数组中的前i个值均有关,因此可以很快完成区间查询,但由于每个位置牵涉众多,修改一个值需要动很多位置。
那么能否找到一种折中的办法?
我们希望维护的数组中每个位置存储的值与原数组中的一部分有关,这样可以加快区间查询,同时这"一部分"又不会太多,从而使得单点修改也不用改变太多的值。
树状数组的操作
树状数组很好的解决了这一问题,记原数组为a,我们需要维护的数组为b(为方便起见,我们假设数组是从1开始的):
b[1]=a[1]
b[2]=a[1]+a[2]
b[3]=a[3]
b[4]=a[1]+a[2]+a[3]+a[4]
b[5]=a[5]
b[6]=a[5]+a[6]
b[7]=a[7]
b[8]=a[1]+a[2]+a[3]+a[4]+a[5]+a[6]+a[7]+a[8]
可以看出,b中每个位置存储的值与a中的一部分数字有关,b[i]控制的最后一个数是a[i],且它的控制范围取决于i的二进制表示的最后一个1。比如6的二进制为110,最后一个1表示的是2,因此b[6]可控制两个数:a[5]、a[6],而8的二进制表示为1000,最后的一个1表示8,所以其控制8个数。
这样一来,我们便可以只修改部分值完成单点修改和区间查询:
- 如果我们需要修改a[3],那么我们需要修改与其相关的b[3]、b[4]、b[8]
- 如果我们需要求前7项和,实际需要计算b[7]+b[6]+b[4]
lowbit计算
为了方便得实现树状数组,我们需要获取一个二进制表示的最后一个1,这可以用x & (-x)完成。
假设x的末尾共有m个0,则x可表示为
y
⨁
1
⨁
m
∗
0
y\bigoplus 1\bigoplus m*0
y⨁1⨁m∗0,x的反码(将0与1调换)表示为
y
‾
⨁
0
⨁
m
∗
1
\overline{y}\bigoplus 0\bigoplus m*1
y⨁0⨁m∗1,则-x的二进制表示为(将x全部取反码后+1)
y
‾
⨁
1
⨁
m
∗
0
\overline{y}\bigoplus 1\bigoplus m*0
y⨁1⨁m∗0,因此
x
&
(
−
x
)
x \& (-x)
x&(−x) 的值为
1
⨁
m
∗
0
1\bigoplus m*0
1⨁m∗0,就是我们想要的二进制表示的最后一个1。
def lowbit(x):
return x & (-x)
复杂度分析
单点修改
对于a[i]的修改,我们需要找到b中哪些数值包含了它。
对于位置x,它的管辖范围是
x
&
(
−
x
)
x\& (-x)
x&(−x)到x,如果
x
&
(
−
x
)
<
i
<
x
x\& (-x) < i < x
x&(−x)<i<x,则b[x]中包含了a[i],即x比i大,但其去掉最后一位后比i小。所以x除了末尾的0与最后一个1以外,前面部分应与i的二进制表示相同,所以x的数量最多不会超过n的二进制长度。
区间查询
若要计算a的前i项和,只需要考虑i的二进制表示中有多少个1即可,复杂度也不会超过二进制中1的个数。
前1101010项和=b[1101010]+b[1101000]+b[1100000]+b[1000000]
因此使用树状数组完成这两部分的复杂度均不超过O(log n)
其他变形
区间修改+单点查询
可以将数组改变为差分数组,则原数组的区间修改可看作对差分数组的两次单点修改,而对原数组的单点查询可转化为对差分数组的区间查询,从而解决问题。
区间修改+区间查询
需要维护两个数组的前缀和,可参考“高级”数据结构——树状数组!
模板
class BIT:
def __init__(self, n):
self.n = n
self.a = [0] * (n + 1)
@staticmethod
def lowbit(x):
return x & (-x)
def query(self, x):
ret = 0
while x > 0:
ret += self.a[x]
x -= BIT.lowbit(x)
return ret
def update(self, x, k):
while x <= self.n:
self.a[x] += k
x += BIT.lowbit(x)