前言
本篇归纳线段树
线段树
几篇博文归纳的很好
模板
class SegmentTree:
def __init__(self, data, merge):
'''
data:传入的数组
merge:处理的业务逻辑,例如求和/最大值/最小值,lambda表达式
'''
self.data = data
self.n = len(data)
# 申请4倍data长度的空间来存线段树节点
self.tree = [None] * (4 * self.n) # 索引i的左孩子索引为2i+1,右孩子为2i+2
self._merge = merge
if self.n:
self._build(0, 0, self.n-1)
def query(self, ql, qr):
'''
返回区间[ql,..,qr]的值
'''
return self._query(0, 0, self.n-1, ql, qr)
def update(self, index, value):
# 将data数组index位置的值更新为value,然后递归更新线段树中被影响的各节点的值
self.data[index] = value
self._update(0, 0, self.n-1, index)
def _build(self, tree_index, l, r):
'''
递归创建线段树
tree_index : 线段树节点在数组中位置
l, r : 该节点表示的区间的左,右边界
'''
if l == r:
self.tree[tree_index] = self.data[l]
return
mid = (l+r) // 2 # 区间中点,对应左孩子区间结束,右孩子区间开头
left, right = 2 * tree_index + 1, 2 * tree_index + 2 # tree_index的左右子树索引
self._build(left, l, mid)
self._build(right, mid+1, r)
self.tree[tree_index] = self._merge(self.tree[left], self.tree[right])
def _query(self, tree_index, l, r, ql, qr):
'''
递归查询区间[ql,..,qr]的值
tree_index : 某个根节点的索引
l, r : 该节点表示的区间的左右边界
ql, qr: 待查询区间的左右边界
'''
if l == ql and r == qr:
return self.tree[tree_index]
mid = (l+r) // 2 # 区间中点,对应左孩子区间结束,右孩子区间开头
left, right = tree_index * 2 + 1, tree_index * 2 + 2
if qr <= mid:
# 查询区间全在左子树
return self._query(left, l, mid, ql, qr)
elif ql > mid:
# 查询区间全在右子树
return self._query(right, mid+1, r, ql, qr)
# 查询区间一部分在左子树一部分在右子树
return self._merge(self._query(left, l, mid, ql, mid),
self._query(right, mid+1, r, mid+1, qr))
def _update(self, tree_index, l, r, index):
'''
tree_index:某个根节点索引
l, r : 此根节点代表区间的左右边界
index : 更新的值的索引
'''
if l == r == index:
self.tree[tree_index] = self.data[index]
return
mid = (l+r)//2
left, right = 2 * tree_index + 1, 2 * tree_index + 2
if index > mid:
# 要更新的区间在右子树
self._update(right, mid+1, r, index)
else:
# 要更新的区间在左子树index<=mid
self._update(left, l, mid, index)
# 里面的小区间变化了,包裹的大区间也要更新
self.tree[tree_index] = self._merge(self.tree[left], self.tree[right])
例子
区域和检索 - 数组可修改
leet上307题
给你一个数组 nums ,请你完成两类查询,其中一类查询要求更新数组下标对应的值,另一类查询要求返回数组中某个范围内元素的总和。
实现 NumArray 类:
NumArray(int[] nums) 用整数数组 nums 初始化对象
void update(int index, int val) 将 nums[index] 的值更新为 val
int sumRange(int left, int right) 返回子数组 nums[left, right] 的总和(即,nums[left] + nums[left + 1], ..., nums[right])
class NumArray:
def __init__(self, nums: List[int]):
self.segment_tree = SegmentTree(nums, lambda x, y : x + y)
def update(self, i: int, val: int) -> None:
self.segment_tree.update(i, val)
def sumRange(self, i: int, j: int) -> int:
return self.segment_tree.query(i, j)
class SegmentTree:
def __init__(self, data, merge):
'''
data:传入的数组
merge:处理的业务逻辑,例如求和/最大值/最小值,lambda表达式
'''
self.data = data
self.n = len(data)
# 申请4倍data长度的空间来存线段树节点
self.tree = [None] * (4 * self.n) # 索引i的左孩子索引为2i+1,右孩子为2i+2
self._merge = merge
if self.n:
self._build(0, 0, self.n-1)
def query(self, ql, qr):
'''
返回区间[ql,..,qr]的值
'''
return self._query(0, 0, self.n-1, ql, qr)
def update(self, index, value):
# 将data数组index位置的值更新为value,然后递归更新线段树中被影响的各节点的值
self.data[index] = value
self._update(0, 0, self.n-1, index)
def _build(self, tree_index, l, r):
'''
递归创建线段树
tree_index : 线段树节点在数组中位置
l, r : 该节点表示的区间的左,右边界
'''
if l == r:
self.tree[tree_index] = self.data[l]
return
mid = (l+r) // 2 # 区间中点,对应左孩子区间结束,右孩子区间开头
left, right = 2 * tree_index + 1, 2 * tree_index + 2 # tree_index的左右子树索引
self._build(left, l, mid)
self._build(right, mid+1, r)
self.tree[tree_index] = self._merge(self.tree[left], self.tree[right])
def _query(self, tree_index, l, r, ql, qr):
'''
递归查询区间[ql,..,qr]的值
tree_index : 某个根节点的索引
l, r : 该节点表示的区间的左右边界
ql, qr: 待查询区间的左右边界
'''
if l == ql and r == qr:
return self.tree[tree_index]
mid = (l+r) // 2 # 区间中点,对应左孩子区间结束,右孩子区间开头
left, right = tree_index * 2 + 1, tree_index * 2 + 2
if qr <= mid:
# 查询区间全在左子树
return self._query(left, l, mid, ql, qr)
elif ql > mid:
# 查询区间全在右子树
return self._query(right, mid+1, r, ql, qr)
# 查询区间一部分在左子树一部分在右子树
return self._merge(self._query(left, l, mid, ql, mid),
self._query(right, mid+1, r, mid+1, qr))
def _update(self, tree_index, l, r, index):
'''
tree_index:某个根节点索引
l, r : 此根节点代表区间的左右边界
index : 更新的值的索引
'''
if l == r == index:
self.tree[tree_index] = self.data[index]
return
mid = (l+r)//2
left, right = 2 * tree_index + 1, 2 * tree_index + 2
if index > mid:
# 要更新的区间在右子树
self._update(right, mid+1, r, index)
else:
# 要更新的区间在左子树index<=mid
self._update(left, l, mid, index)
# 里面的小区间变化了,包裹的大区间也要更新
self.tree[tree_index] = self._merge(self.tree[left], self.tree[right])
结语
线段树作为一种高级数据结构
现在简单记一笔,回头再看看