《Python3
标准库》笔记:heapq
堆排序算法
堆的概念
堆(heapq)
是一个树形数据结构,其中子节点与父节点有一种有序关系。二叉堆(binary heap)
可以使用一个有组织的列表或数组表示,其中元素N的子元素位于2*
N+1和2*
N+2(索引从0开始)。这种布局允许原地重新组织堆,从而不必在增加或者删除元素时重新分配大量内存。
最大堆(max-heap)
确保父节点大于或等于其两个子节点。最小堆(min-heap)
要求父节点小于或等于其子节点。Python的heapq
模块实现了一个最小堆。
堆的创建
创建堆有两种基本方式:heappush()
和heapify()
。
import heapq
import math
from io import StringIO
data = [19, 9, 4, 10, 11]
def show_tree(tree, total_width=36, fill=' '):
"""Pretty-print a tree"""
output = StringIO()
last_row = -1
for i, n in enumerate(tree):
if i:
row = int(math.floor(math.log(i + 1, 2)))
else:
row = 0
if row != last_row:
output.write('\n')
columns = 2 ** row
col_width = int(math.floor(total_width / columns))
output.write(str(n).center(col_width, fill))
last_row = row
print(output.getvalue())
print('-' * total_width)
print()
heap = []
print('random:', data)
print()
for n in data:
print('add {:>3}:'.format(n))
heapq.heappush(heap, n)
show_tree(heap)
heapq.heapify(data)
print('heapified:')
show_tree(data)
使用heappush()
,从数据源增加新元素时会保持元素的堆排序顺序。如果数据已经在内存中,那么使用heapify()
原地重新组织列表中的元素会更高效。
如果按照堆顺序一次一个元素地构建列表,那么结果与构建一个无序列表再调用heapify()
是一样的。
访问堆的元素
一旦堆被正确地组织,则可以使用heappop()
删除有最小值的元素。
for i in range(2):
smallest = heapq.heappop(data)
print('pop {:>3}:'.format(smallest))
show_tree(data)
如果希望在一个操作中删除现有元素并替换为新值,则可以使用heapreplace()
。
for n in [0, 13]:
smallest = heapq.heapreplace(data, n)
print('replace {:>2} with {:>2}:'.format(smallest, n))
show_tree(data)
通过原地替换元素,这样可以维持一个固定大小的堆,如按优先级排序的作业队列。
堆的数据极值
heapq
还包括两个检查可迭代对象(iterable)
的函数,可以查找其中包含最大或最小值的范围。
import heapq
from heapq_heapdata import data
data = [19, 9, 4, 10, 11]
print('all :', data)
print('3 largest :', heapq.nlargest(3, data))
print('from sort :', list(reversed(sorted(data)[-3:])))
print('3 smallest:', heapq.nsmallest(3, data))
print('from sort :', sorted(data)[:3])
只有当n值(n>1)相对小时使用nlargest()
和nsmallest()
才算高效,不过有些情况下这两个函数会很方便。
高效合并有序序列
对于小数据集,将多个有序序列合并到一个新序列很容易。
list(sorted(itertools.chain(*data)))
对于较大的数据集,这个技术可能会占用大量内存。merge()
不是对整个合并后的序列排序,而是使用一个堆一次一个元素地生成一个新序列,利用固定大小的内存确定下一个元素。
import heapq
import random
random.seed(2016)
data = []
for i in range(4):
new_data = list(random.sample(range(1, 101), 5))
new_data.sort()
data.append(new_data)
for i, d in enumerate(data):
print('{}: {}'.format(i, d))
print('\nMerged:')
for i in heapq.merge(*data):
print(i, end='')
print()
由于merge()
的实现使用一个堆,所以它会根据所合并的序列个数消费内存,而不是根据这些序列中的元素个数。