1. Segment Tree简介
Segment Tree,中文名分段树,也算是一个经典的编程算法了,它最常用的场景就是在一些固定但是经常进行值更新的array当中进行区间范围内的某些特征的查询,比如最大值,最小值,区间和等等。
具体来说,对于一个长度固定的array,我们可以对其内部的值进行更新操作,然后也可以对其中某一范围区间的值进行query,比如求和或者求最大值或者最小值。
我们以求和为例,单独对值进行更新或者对范围内求和,其实都可以在 O ( 1 ) O(1) O(1)的复杂度内实现,但是,两者不可兼容。如果我们想要在 O ( 1 ) O(1) O(1)的时间复杂度内更新值,那么范围内求和就会是一个 O ( N ) O(N) O(N)时间复杂度的操作;反之,如果我们使用累积数组,那么我们可以在 O ( 1 ) O(1) O(1)时间复杂度内求取范围和,但是此时更新每一个值就必须同步更新累积数组,就会变成一个 O ( N ) O(N) O(N)复杂度的操作。
因此,如果我们需要同样频繁地更新值以及对范围特征进行query,那么上述方法在时间复杂度上就不可接受了。
而Segment Tree的一个核心思路就是说借用二叉树的结构,将array的每一段区间都保存到二叉树的某一段节点当中,此时,我们无论是对于值的更新还是对于范围的query都可以在 O ( l o g N ) O(logN) O(logN)的时间复杂度内实现,从而达到整体计算效率上的提升。
下图就是网上找的一张典型的最小值求解的Segment Tree的示例图。
下面,我们就来具体看一下Segment Tree的具体算法实现,即其究竟是如何在 O ( l o g N ) O(logN) O(logN)的时间复杂度内实现值的更新以及范围内特征的query。
2. Segment Tree算法实现
1. 原理说明
如上图所示,Segment Tree的主体就是一个二叉树,其每一个节点都代表着某一个区间范围内的元素性质,然后其左右节点都是对父节点所表示的区间的一个对分,而树的叶子节点就是具体的array当中的某一个具体的值。
因此,每一次对array当中的某一个具体位置上的元素的更新就是一个二分搜索,因此对某一个值的检索的时间复杂度就是 O ( l o g N ) O(logN) O(logN)。
但是,每一次对某一个具体的值进行改动,就会相应地影响到其所在的区间的特征值,因此我们需要同步地修改其所在的区间的特征值,我们可以从叶子节点向上追寻到根节点,因此也是一个 O ( l o g N ) O(logN) O(logN)时间复杂度的操作。
最后,我们考察如何在一个范围内求取某一个特征。由于任何范围都可以拆分上述segment tree当中某几个节点的组合,因此,我们只需要迭代找到这些节点然后组合在一起就能够获得我们所需的答案了。
2. vanilla代码实现
我们给出python的segment tree的伪代码实现如下:
class SegmentTreeNode:
def __init__(self, val, lbound, rbound, lchild=None, rchild=None):
self.val = val
self.lbound = lbound
self.rbound = rbound
self.lchild = lchild
self.rchild = rchild
class SegmentTree:
def __init__(self, arr):
self.length = len(arr)
self.root = self.build(0, self.length-1, arr)
self.vals = arr
def feature_func(self, lval, rval):
# get the target feature, such as sum, min or max.
raise NotImplementError()
def build(self, lbound, rbound, arr):
if lbound == rbound:
root = SegmentTreeNode(arr[lbound], lbound, rbound)
else:
mid = (lbound+rbound) // 2
lchild = self.build(lbound, mid, arr)
rchild = self.build(mid+1, rbound, arr)
val = self.feature_func(lchild.val, rchild.val)
root = SegmentTreeNode(val, lbound, rbound, lchild, rchild)
return root
def update(self, idx, val):
self.vals[idx] = val
self._update(idx, val, self.root)
return
def _update(self, idx, val, root):
if root.lbound == root.rbound:
assert(root.lbound == idx)
root.val = val
return
mid = (root.lbound + root.rbound) // 2
if idx <= mid:
self._update(idx, val, root.lchild)
else:
self._update(idx, val, root.rchild)
root.val = self.feature_func(root.lchild.val, root.rchild.val)
return
def query(self, lb, rb):
return self._query(lb, rb, self.root)
def _query(self, lb, rb, root):
if lb == root.lbound and rb == root.rbound:
return root.val
mid = (root.lbound+root.rbound) // 2
if rb <= mid:
return self._query(lb, rb, root.lchild)
elif lb > mid:
return self._query(lb, rb, root.rchild)
else:
lval = self._query(lb, mid, root.lchild)
rval = self._query(mid+1, rb, root.rchild)
return self.feature_func(lval, rval)
对于不同的任务,我们只需要相应地修改对应的feature_func
即可。
一些典型的case如下:
-
求范围内最大值
def feature_func(self, lval, rval): return max(lval, rval)
-
求范围内最小值
def feature_func(self, lval, rval): return min(lval, rval)
-
求范围内元素之和
def feature_func(self, lval, rval): return lval + rval
3. 优化设计(一)
另一方面,又因为事实上任何二叉树都可以用数组进行表达,因此,事实上我们也可以对上述代码实现进行优化。
class SegmentTree:
def __init__(self, arr):
self.length = len(arr)
self.tree = [0 for _ in range(4 * self.length)]
self.vals = deepcopy(arr)
self.build(1, arr, 0, self.length-1)
def feature_func(self, lval, rval):
return lval + rval
def build(self, node, arr, lb, rb):
if lb == rb:
self.tree[node] = arr[lb]
else:
mid = (lb + rb) // 2
lval = self.build(2*node, arr, lb, mid)
rval = self.build(2*node+1, arr, mid+1, rb)
self.tree[node] = self.feature_func(lval, rval)
return self.tree[node]
def _update(self, idx, val, node, lb, rb):
if lb == rb:
assert(lb == idx)
self.tree[node] = val
else:
mid = (lb + rb) // 2
if idx <= mid:
self._update(idx, val, 2*node, lb, mid)
else:
self._update(idx, val, 2*node+1, mid+1, rb)
self.tree[node] = self.feature_func(self.tree[2*node], self.tree[2*node+1])
return
def update(self, idx, val):
self.vals[idx] = val
self._update(idx, val, 1, 0, self.length-1)
return
def _query(self, left, right, node, lb, rb):
if left == lb and right == rb:
return self.tree[node]
mid = (lb + rb) // 2
if right <= mid:
return self._query(left, right, 2*node, lb, mid)
elif left > mid:
return self._query(left, right, 2*node+1, mid+1, rb)
else:
lval = self._query(left, mid, 2*node, lb, mid)
rval = self._query(mid+1, right, 2*node+1, mid+1, rb)
return self.feature_func(lval, rval)
def query(self, lb, rb):
return self._query(lb, rb, 1, 0, self.length-1)
同样的,给出一些典型的segment tree特征函数如下:
-
求范围内最大值
def feature_func(self, lval, rval): return max(lval, rval)
-
求范围内最小值
def feature_func(self, lval, rval): return min(lval, rval)
-
求范围内元素之和
def feature_func(self, lval, rval): return lval + rval
但是需要注意的是,虽然原则上任意一棵包含 n n n个叶子节点二叉树事实上只需要 2 n − 1 2n-1 2n−1个节点即可表达,但是由于这里的二叉树并不总是完全和二叉树,因此事实上我们需要一些冗余节点来确保所有的节点都能被存储下来,我们事实上需要至多 4 n 4n 4n个节点来进行树节点的存储。
这会导致一部分的性能损失和空间浪费,因此,我们可以更进一步地对上述代码进行优化。
4. 优化设计(二)
给出优化后的python代码实现如下:
class SegmentTree:
def __init__(self, arr):
self.length = len(arr)
self.tree = self.build(arr)
def feature_func(self, *args):
# get the target feature, such as sum, min or max.
raise NotImplementError()
def build(self, arr):
n = len(arr)
tree = [0 for _ in range(2*n)]
for i in range(n):
tree[i+n] = arr[i]
for i in range(n-1, 0, -1):
tree[i] = self.feature_func(tree[2*i], tree[2*i+1])
return tree
def update(self, idx, val):
idx = idx + self.length
self.tree[idx] = val
while idx > 1:
self.tree[idx // 2] = self.feature_func(self.tree[idx], self.tree[idx ^ 1])
idx = idx // 2
return
def query(self, lb, rb):
lb += self.length
rb += self.length
nodes = []
while lb < rb:
if lb % 2 == 1:
nodes.append(self.tree[lb])
lb += 1
if rb % 2 == 0:
nodes.append(self.tree[rb])
rb -= 1
lb = lb // 2
rb = rb // 2
if lb == rb:
nodes.append(self.tree[rb])
return self.feature_func(*nodes)
同样的,给出一些典型的segment tree特征函数如下:
-
求范围内最大值
def feature_func(self, *args): return max(args)
-
求范围内最小值
def feature_func(self, *args): return min(args)
-
求范围内元素之和
def feature_func(self, *args): return sum(args)
当然,网上更为常见的实现方式是使用位运算的方式进行实现,具体来说:
class SegmentTree:
def __init__(self, arr):
self.length = len(arr)
self.tree = self.build(arr)
def feature_func(self, *args):
# get the target feature, such as sum, min or max.
raise NotImplementError()
def build(self, arr):
n = len(arr)
tree = [0 for _ in range(2*n)]
for i in range(n):
tree[i+n] = arr[i]
for i in range(n-1, 0, -1):
tree[i] = self.feature_func(tree[i<<1], tree[(i<<1) | 1])
return tree
def update(self, idx, val):
idx = idx + self.length
self.tree[idx] = val
while idx > 1:
self.tree[idx>>1] = self.feature_func(self.tree[idx], self.tree[idx ^ 1])
idx = idx>>1
return
def query(self, lb, rb):
lb += self.length
rb += self.length
nodes = []
while lb < rb:
if lb & 1 == 1:
nodes.append(self.tree[lb])
lb += 1
if rb & 1 == 0:
nodes.append(self.tree[rb])
rb -= 1
lb = lb >> 1
rb = rb >> 1
if lb == rb:
nodes.append(self.tree[rb])
return self.feature_func(*nodes)
3. 例题考察
1. Leetcode 2659
题目链接:
1. 解题思路
这一题思路上其实还行,我们总是依次删除元素的,因此只需要对元素排个序然后取出对应元素的index就能知道每次删除元素时需要移动的index距离,而真实的移动次数就是这两个index之间当前剩余的元素个数减一。
因此,我们只需要用一个segment tree来进行范围求和处理即可。
2. 代码实现
给出python代码实现如下:
class SegmentTree:
def __init__(self, arr):
self.length = len(arr)
self.tree = self.build(arr)
def feature_func(self, *args):
return sum(args)
def build(self, arr):
n = len(arr)
tree = [0 for _ in range(2*n)]
for i in range(n):
tree[i+n] = arr[i]
for i in range(n-1, 0, -1):
tree[i] = self.feature_func(tree[i<<1], tree[(i<<1) | 1])
return tree
def update(self, idx, val):
idx = idx + self.length
self.tree[idx] = val
while idx > 1:
self.tree[idx>>1] = self.feature_func(self.tree[idx], self.tree[idx ^ 1])
idx = idx>>1
return
def query(self, lb, rb):
lb += self.length
rb += self.length
nodes = []
while lb < rb:
if lb & 1 == 1:
nodes.append(self.tree[lb])
lb += 1
if rb & 1 == 0:
nodes.append(self.tree[rb])
rb -= 1
lb = lb >> 1
rb = rb >> 1
if lb == rb:
nodes.append(self.tree[rb])
return self.feature_func(*nodes)
class Solution:
def countOperationsToEmptyArray(self, nums: List[int]) -> int:
n = len(nums)
index = [i for i in range(n)]
index = sorted(index, key=lambda x: nums[x])
status = [1 for _ in range(n)]
segment_tree = SegmentTree(status)
prev = 0
res = 0
for idx in index:
if idx >= prev:
res += segment_tree.query(prev, idx) - 1
else:
res += segment_tree.query(0, idx) + segment_tree.query(prev, n-1) - 1
prev = idx
segment_tree.update(idx, 0)
return res + n
提交代码评测得到:耗时6159ms,占用内存31.2MB。