抱佛脚Day05
树状数组(Binary Indexed Tree, BIT)是能够完成下述操作的数据结构。
给一个初始值全为0的数列
a
1
,
a
2
,
⋯
,
a
i
a_1, a_2,\cdots,a_i
a1,a2,⋯,ai
- 给定 i i i,计算 a 1 + a 2 + ⋯ + a i a_1+a_2+\cdots+a_i a1+a2+⋯+ai
- 给定 i i i和 x x x,计算 a i = a i + x a_i = a_i + x ai=ai+x
BIT的结构
BIT使用数组维护下图所示的部分和
也就是把线段树中不需要的节点去掉之后,再把剩下的节点对应到数组中。让我们对比每个节点对应的区间的长度和节点编号的二进制表示。以1结尾的1,3,5,7的长度是1,最后有1个0的2,6的长度是2,最后有2个0的4的长度是4⋯⋯这样,编号的二进制表示就能够和区间非常容易地对应起来。利用这个性质,BIT可以通过非常简单的位运算实现。
BIT的求和
计算前 i i i项的和需要从 i i i开始,不断把当前位置 i i i的值加到结果中,并从 i i i中减去 i i i的二进制最低非0位对应的幂直到 i i i变成0为止。二进制的最后一个1可以通过 i & ( − i ) i \&(-i) i&(−i)到。
BIT的值的更新
使第 i i i项的值增加 x x x要从 i i i开始,不断把当前位置 i i i的值增加 x x x,并把 i i i的二进制最低非0位对应的幂加到 i i i上。
BIT的复杂度
O ( log n ) O(\log n) O(logn)
BIT的实现
注意: i & ( − i ) = i & ( i − 1 ) i\&(-i) = i \&(i - 1) i&(−i)=i&(i−1)
class BinaryIndexedTree(object):
def __init__(self, nums):
self.nums = nums
self.n = len(nums)
self.tree = [0] * (self.n + 1)
for i in range(1, len(self.tree)):
self.tree[i] = nums[i - 1]
index = i - 1
while index > i - self.lowbit(i):
self.tree[i] += self.tree[index]
index -= self.lowbit(index)
def lowbit(self, x):
return x & (-x)
def update(self, index, val):
diff = val - self.nums[index]
self.nums[index] = val
index += 1
while index < self.n + 1:
self.tree[index] += diff
index += self.lowbit(index)
def getSum(self, index):
ans = 0
index += 1
while index >= 1:
ans += self.tree[index]
index -= self.lowbit(index)
return ans
def sumRange(self, left, right):
return self.getSum(right) - self.getSum(left - 1)
if __name__ == "__main__":
nums = [1, 7, 3, 0, 5, 8, 3, 2, 6, 2, 1, 1, 4, 5]
bit = BinaryIndexedTree(nums)
print(bit.tree)
print(bit.getSum(12))
bit.update(4, 2)
print(bit.tree)
print(bit.getSum(12))
print(bit.sumRange(1, 12))
另一种写法(更好理解)
class BinaryIndexedTree(object):
def __init__(self, nums):
self.nums = nums
self.n = len(nums) + 1
self.tree = [0] * self.n
for i in range(1, self.n):
self.add(i, nums[i - 1])
def lowbit(self, x):
return x & (-x)
def add(self, index, val):
while index < self.n:
self.tree[index] += val
index += self.lowbit(index)
def update(self, index, val):
self.add(index + 1, val - self.nums[index])
self.nums[index] = val
def getSum(self, index):
ans = 0
index += 1
while index >= 1:
ans += self.tree[index]
index -= self.lowbit(index)
return ans
def sumRange(self, left, right):
return self.getSum(right) - self.getSum(left - 1)
if __name__ == "__main__":
nums = [1, 7, 3, 0, 5, 8, 3, 2, 6, 2, 1, 1, 4, 5]
bit = BinaryIndexedTree(nums)
print(bit.getSum(12)) # 43
print(bit.sumRange(3, 11)) # 31
bit.update(4, 2)
print(bit.getSum(12)) # 40
print(bit.sumRange(1, 12)) # 39