假定:有 1 个乱序的数列 nums ,其中有 n 个数。
要求:排好序之后是 从小到大 的顺序。
堆排序算法
原理
-
先将原始的堆,调整为最大堆:
从倒数第 1 个有子结点的结点(下标为 index = n//2 - 1)开始,将以结点 index 为根结点的子堆调整为最大堆;
index 范围是 n//2 - 1(含) 到 0(含);
最后,原始的堆调整为最大堆。 -
将最大堆调整为(从小到大排列)有序的最小堆:
a. 将最大堆的堆顶第 0 项与第 n-1 项交换位置,这样,第 1 大值排在堆的尾部;
b. 这样,除开第 n-1 项的剩余项组成的子堆,不是最大堆也不是最小堆,将这个子堆调整为最大堆;
c. 重复上述过程 a、b;
最终得到 1 个从小到大排列的最小堆。
关键代码的我的理解:
filter_down 函数:
特别要说明其中的 if-else。如果 根结点 的值 rootVal 较大,则不将其往下过滤(break);如果 rootVal 比子结点小,则将较大值往上移动,为 rootVal 腾出空间,最终将 rootVal 下滤。
值得注意的是,filter_down 函数本身,只是将堆中的最大项调整到顶部,较小项调整到尾部,并不会将 子堆 完全 调整为 最大堆,(结合测试用例中的 nums6、nums7 ,画图辅助理解。)
只有在 heap_sort 中,从后往前调用 filter_down 函数(第 1 个 for 循环),这样的方式,才能得到完整的最大堆。
代码
# coding:utf-8
from swap import swap
"""
将二叉树中的,以 nums[p] 为根的子堆,
调整为最大堆。
(下滤 根节点 nums[p])
p 是根节点所在的下标;
n 是当前堆一共有多少个元素。
"""
def filter_down(nums, p, n):
''' 取出根节点存放的值 '''
rootVal = nums[p]
parentIdx = p
while 2*parentIdx+1 <= n-1:
''' kidIdx 暂时指向左儿子 '''
kidIdx = 2*parentIdx+1
''' 左儿子 kidIdx 不等于 n-1,即左儿子不是最后 1 个结点,
也就是说,还有右儿子 '''
if kidIdx != n-1 and nums[kidIdx] < nums[kidIdx+1]:
kidIdx += 1
''' kidIdx 指向较大的子结点 '''
''' rootVal 找到了合适的位置 '''
if rootVal >= nums[kidIdx]:
break
''' 如果 rootVal 比子结点小,
则下滤 rootVal。
(将较大值往上移动)'''
else:
nums[parentIdx] = nums[kidIdx]
''' 将父节点的指针往下移动 '''
parentIdx = kidIdx
nums[parentIdx] = rootVal
def heap_sort(nums):
n = len(nums)
''' 建立最大堆 '''
for index in range(n//2-1, -1, -1):
filter_down(nums, index, n)
''' 删除最大堆的顶部的最大项
将最大堆的最大项与最后 1 项交换位置,
然后将除最后 1 项的剩余部分,调整为最大堆;
重复上面的操作。
'''
for index in range(n-1, 0, -1):
swap(nums, 0, index)
filter_down(nums, 0, index)
def test():
nums0 = [1,2,3]
nums1 = [1,3,2]
nums2 = [2,1,3]
nums3 = [2,3,1]
nums4 = [3,1,2]
nums5 = [3,2,1]
nums6 = [5,1,6,3,4,8]
nums7 = [6,1,8,3,4,5]
for nums in (nums0, nums1, nums2, nums3, nums4, nums5, nums6, nums7):
filter_down(nums, 0, len(nums))
assert nums0 == [3,2,1]
assert nums1 == [3,1,2]
assert nums2 == [3,1,2]
assert nums3 == [3,2,1]
assert nums4 == [3,1,2]
assert nums5 == [3,2,1]
assert nums6 == [6,1,8,3,4,5]
assert nums7 == [8,1,6,3,4,5]
nums = [7,4,5,3,8,9]
heap_sort(nums)
assert nums == [3,4,5,7,8,9]
print('Pass!')
算法复杂度
时间复杂度:
最坏情况下 ;
最好情况下 ;
平均情况 。
空间复杂度:
。
稳定性
。
参考文献
- 《数据结构(第 2 版)》 - 浙江大学 - P148——P151、P265——P267;
- 《数据结构与算法 Python 语言描述》 - 裘宗燕 - 北京大学 - ;
- https://github.com/henry199101/sort/blob/master/heap_sort.py。