最近想把优先队列的实现方式都撸一遍,因此出个系列整理。
[算法]优先队列之配对堆
[算法]优先队列之左偏树
各种堆的复杂度分析
方法\堆类型 | 二叉 | 配对 | 斐波那契 | 左偏 | 二顶 |
---|---|---|---|---|---|
查找最小 | 1 | 1 | 1 | 1 | log(n) |
弹出最小 | log(n) | log(n) | log(n) | log(n) | log(n) |
修改(降key) | log(n) | log(n)- | 1 | log(n) | log(n) |
插入 | log(n) | 1 | 1 | log(n) | 1 |
归并 | n | 1 | 1 | log(n) | log(n) |
可持久化 | No | No |
左偏树
配对堆的核心要义在于合并, 核心API只有下文中的:_merge
,关键在于保证树的偏向性。
参考此篇讲解:https://oi-wiki.org/ds/leftist-tree/ (PS:原文中的代码有较多bug,基本可将其当成伪码处理)
Python实现
class Item:
def __init__(self, priority=0, value=None):
self.priority = priority
self.value = value
self.left = self.right = self.fa = None
self.dist = 1 # 到外节点的距离
def clean(self):
self.left = self.right = self.fa = None
self.dist = 1
class PriorityQueue: # 最小优先队列 | 左偏树
def __init__(self):
self.root = None
self.nodes_num = 0
def push(self, item: Item): # O(logN)
if self.root is None:
self.root = item
else:
self.root = self._merge(self.root, item)
self.nodes_num += 1
def poll(self) -> Item:
tmp = self.root
if self.size() > 0:
self.root = self._merge(self.root.left, self.root.right) # 合并根的左右儿子即可
self.nodes_num -= 1
tmp.clean() # 清除弹出节点的父亲和左右指针和dist(非必要,仅在某些[引用]场景下是必要的)
return tmp
def remove(self, target: Item) -> Item: # 左偏树特有操作
def push_up(x: Item):
if x is not None:
if self._cal_dist(x.right) > self._cal_dist(x.left): # 保证左偏性质,即左儿子的dist要大于右儿子的dist
x.left, x.right = x.right, x.left
if self._cal_dist(x) != self._cal_dist(x.right) + 1:
x.dist = self._cal_dist(x.right) + 1 # 显然外节点会在右儿子处
push_up(x.fa)
if target:
chs = self._merge(target.left, target.right) # 基于_push_up即可满足要求
if target.fa.left == target: # 是左儿子
target.fa.left = chs
else:
target.fa.right = chs
self._build_fa(chs, target.fa) # 重构父节点
push_up(chs)
target.clean()
return target
def decrease_priority(self, offset, target: Item): # 要减少的优先值和对应的元素 O(logN)-
assert offset > 0 and target
target.priority -= offset
if target != self.root: # 非根节点
self.remove(target)
self.root = self._merge(self.root, target) # target弹出并与根合并
def merge(self, new: 'PriorityQueue'): # 合并、建树 O(logN) + O(logM)
self.root = self._merge(self.root, new.root)
self.nodes_num += new.nodes_num
def min(self) -> Item:
return self.root
def size(self):
return self.nodes_num
def _merge(self, x: Item, y: Item): # 核心API:O(logN) + O(logM), 每次merge根dist都会-1
if x is not None and y is not None:
if x.priority > y.priority: # 保证x的优先级较小
x, y = y, x
x.right = self._merge(x.right, y) # 右儿子进行合并,保证堆的性质
self._build_fa(x.right, x) # ☆ 仅对 decrease_priority 有用,若不需要可以删除
# 保证左偏性质,即左儿子的dist要大于右儿子的dist
if self._cal_dist(x.right) > self._cal_dist(x.left): # PS:0.5概率决定左右儿子也是O(logN)
x.left, x.right = x.right, x.left
x.dist = self._cal_dist(x.right) + 1 # 显然外节点会在右儿子处,基于右节点更新
return x
else:
return x if x is not None else y
@staticmethod
def _cal_dist(item: Item):
return item.dist if item is not None else 0 # 空节点的dist返回0
@staticmethod
def _build_fa(ch: Item, fa: Item):
if ch is not None:
ch.fa = fa
测试模块
def test():
data = [Item(i, str(i)) for i in range(5)]
pq = PriorityQueue()
for i in range(3):
pq.push(data[i])
assert pq.poll().value == '0'
assert pq.poll().value == '1'
assert pq.min().value == '2'
pq.push(data[0])
for i in range(3, 5):
pq.push(data[i])
assert pq.min().value == '0'
ans = []
while pq.size():
ans.append(pq.poll().value)
assert ''.join(ans) == '0234'
def test_merge():
data = [Item(i, str(i)) for i in range(5)]
pq = PriorityQueue()
pq2 = PriorityQueue()
pq2.push(data[3])
pq2.push(data[4])
pq.push(data[2])
pq.push(data[0])
pq.push(data[1])
pq2.push(data[0])
assert pq.poll().value == '0'
pq.merge(pq2)
ans = []
while pq.size():
ans.append(pq.poll().value)
assert ''.join(ans) == '01234'
def test_decrease():
for k in range(5):
data = [Item(i, str(i)) for i in range(5)]
pq = PriorityQueue()
for i in range(5):
pq.push(data[i])
assert pq.min().value == '0'
pq.decrease_priority(5, data[k])
ans = []
while pq.size():
ans.append(pq.poll().value)
ans_str = [str(i) for i in range(5)]
ans_str.remove(str(k))
assert ''.join(ans) == str(k) + ''.join(ans_str)