定义
- 堆是一个完全二叉树
- 堆中每一个节点的值都必须大于等于(或小于等于)其子树中每个节点的值
如何存储一个堆
完全二叉树比较适合用数组来存储
数组中下标为 i 的节点的
- 左子节点: i * 2的节点
- 右子节点: i * 2 + 1的节点
- 父节点: i / 2
重要操作
插入一个数据: 新插入数据放到数组最后,然后中从下往上堆化
删除堆顶元素:把数组最后一个元素放到堆顶,然后从上往下堆化
堆化的时间复杂度是O(logn)
应用场景
- 堆排序 O(nlogn)
- 建堆 O(n)
2.排序 O(nlogn)
- 建堆 O(n)
- 优先级队列
- 合并有序小文件
- 高性能定时器
- 赫夫曼编码、图最短路径、最小生成树
- 求topK问题
- 中位数
- 双堆实现求中位数或百分位数
为什么快速排序要比堆排序性能好?
- 堆排序数据访问的方式没有快速排序友好
快排是顺序访问的,而对于堆排序来说,是跳着访问的,这样对CPU缓存是不友好的 - 对于同样的数据,在排序过中,堆排序算法的数据交换次数要多于快速排序
实现
class Heap:
def __init__(self, capacity):
# 下标从1开始用
self._data = [0] * (capacity +1)
self._capacity = capacity
self._count = 0
def _parent(self,child_index):
return child_index // 2
def _left(self, parent_index):
return 2 * parent_index
def _right(self, parent_index):
return 2 * parent_index + 1
def _siftup(self):
i , parent = self._count, self._parent(self._count)
while parent and self._data[i] > self._data[parent]:
self._data[i], self._data[parent] = self._data[parent], self._data[i]
i , parent = parent, self._parent(parent)
def _siftdown(cls, a , count, root_index = 1):
i = larger_child_index = root_index
while True:
left, right = cls._left(i), cls._right(i)
if left <= count and a[i] < a[left]:
larger_child_index = left
if right <= count and a[larger_child_index] < a[right]:
larger_child_index = right
if larger_child_index == i:break
a[i], a[larger_child_index] = a[larger_child_index], a[i]
i = larger_child_index
def insert(self, value):
if self._count >=self._capacity:return
self._count += 1
self._data[self._count] = value
self._siftup()
def remove_max(self):
if self._count:
result = self._data[1]
self._data[1] = self._data[self._count]
self._count -= 1
self._siftdown(self._data, self._count)
return result
return None
def build_heap(self, a):
for i in range((len(a)-1)//2 , 0 , -1):
self._siftdown(a, len(a) - 1, i)
def sort(cls, a):
cls.build_heap(a)
k = len(a) - 1
while k> 1:
a[1], a[k] = a[k], a[1]
k -= 1
cls._siftdown(a,k)
def __repr__(self):
return self._data[1:self._count + 1].__repr__()
hp = Heap(10)
hp.insert(3)
hp.insert(9)
hp.insert(1)
hp.insert(8)
hp.insert(7)
hp.insert(3)
print(hp)
for _ in range(6):
print(hp.remove_max())
a = [ 3,2,1,5,6,4]
a = [0] + a
hp.sort(a)
print(a[-2])