目录
快速排序
版本一
特点:简单粗暴
要有一个预判,如果数量小于2,就默认为排好了,然后,再这个的前提下,选定一个基准数(就把序列的第一个数当做基准数吧)比这个数大的放到一个序列里面,比这个数小的放到另外一个序列里面,最后把结果合并起来。用递归实现的话,那么"return array"就是对应的递归出口了,当不断递归,知道数组的长度小于2的时候,停止递归,进行返回出栈。
def quicksort(array):
size = len(array)
if not array or size < 2: # NOTE: 递归出口,空数组或者只有一个元素的数组都是有序的
return array
pivot_idx = 0
pivot = array[pivot_idx]
less_part = [array[i] for i in range(size) if array[i] <= pivot and pivot_idx != i]
great_part = [array[i] for i in range(size) if array[i] > pivot and pivot_idx != i]
return quicksort(less_part) + [pivot] + quicksort(great_part)
缺陷:需要额外的空间去存储,空间复杂度高
版本二
不占用额外的空间,更优。
将过程分为两个函数来实现,设置首尾两个指针,如果左边指示位置的值比右边指示位置的值大的时候, 就递归结束。
def quicksort_inplace(array, beg, end): # 注意这里我们都用左闭右开区间
if beg < end: # beg == end 的时候递归出口
pivot = partition(array, beg, end)
quicksort_inplace(array, beg, pivot)
quicksort_inplace(array, pivot + 1, end)
就是把比pivot小的分配到左边,比他大的分配到右边:
简单叙述一下思路,就是先随机的设置基准值(默认为第一个数)设为P,然后,从第二个数(设位置为l)往后和倒数第一个数(设位置为r)往前开始遍历,当找到一个,第l个数比p大,第r个数比p小,把这俩数交换,然后,接着再重新按刚才的方式找,直到l的位置值比r大了位置停下来,然后,把这个位置的值作为新的P返回出去,作为下一次新的基准值继续查找。
直至函数
def partition(array, beg, end):
"""对给定数组执行 partition 操作,返回新的 pivot 位置"""
pivot_index = beg
pivot = array[pivot_index]
left = pivot_index + 1
right = end - 1 # 开区间,最后一个元素位置是 end-1 [0, end-1] or [0: end),括号表示开区间
while True:
# 从左边找到比 pivot 大的
while left <= right and array[left] < pivot:
left += 1
while right >= left and array[right] >= pivot:
right -= 1
if left > right:
break
else:
array[left], array[right] = array[right], array[left]
array[pivot_index], array[right] = array[right], array[pivot_index]
return right # 新的 pivot 位置
完整代码
# -*- coding: utf-8 -*-
def quicksort(array):
size = len(array)
if not array or size < 2: # NOTE: 递归出口,空数组或者只有一个元素的数组都是有序的
return array
pivot_idx = 0
pivot = array[pivot_idx]
less_part = [array[i] for i in range(size) if array[i] <= pivot and pivot_idx != i]
great_part = [array[i] for i in range(size) if array[i] > pivot and pivot_idx != i]
return quicksort(less_part) + [pivot] + quicksort(great_part)
def test_quicksort():
import random
seq = list(range(10))
random.shuffle(seq)
assert quicksort(seq) == sorted(seq) # 用内置的sorted 『对拍』
def quicksort_inplace(array, beg, end): # 注意这里我们都用左闭右开区间
if beg < end: # beg == end 的时候递归出口
pivot = partition(array, beg, end)
quicksort_inplace(array, beg, pivot)
quicksort_inplace(array, pivot + 1, end)
def partition(array, beg, end):
"""对给定数组执行 partition 操作,返回新的 pivot 位置"""
pivot_index = beg
pivot = array[pivot_index]
left = pivot_index + 1
right = end - 1 # 开区间,最后一个元素位置是 end-1 [0, end-1] or [0: end),括号表示开区间
while True:
# 从左边找到比 pivot 大的
while left <= right and array[left] < pivot:
left += 1
while right >= left and array[right] >= pivot:
right -= 1
if left > right:
break
else:
array[left], array[right] = array[right], array[left]
array[pivot_index], array[right] = array[right], array[pivot_index]
return right # 新的 pivot 位置
def test_partition():
l = [4, 1, 2, 8]
assert partition(l, 0, len(l)) == 2
l = [1, 2, 3, 4]
assert partition(l, 0, len(l)) == 0
l = [4, 3, 2, 1]
assert partition(l, 0, len(l)) == 3
l = [1]
assert partition(l, 0, len(l)) == 0
l = [2,1]
assert partition(l, 0, len(l)) == 1
def test_quicksort_inplace():
import random
seq = list(range(10))
random.shuffle(seq)
sorted_seq = sorted(seq)
quicksort_inplace(seq, 0, len(seq))
assert seq == sorted_seq
def nth_element(array, beg, end, nth):
"""查找一个数组第 n 大元素"""
if beg < end:
pivot_idx = partition(array, beg, end)
if pivot_idx == nth - 1: # 数组小标从 0 开始
return array[pivot_idx]
elif pivot_idx > nth - 1:
return nth_element(array, beg, pivot_idx, nth)
else:
return nth_element(array, pivot_idx + 1, end, nth)
def test_nth_element():
l1 = [3, 5, 4, 2, 1]
assert nth_element(l1, 0, len(l1), 3) == 3
assert nth_element(l1, 0, len(l1), 2) == 2
l = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
for i in l:
assert nth_element(l, 0, len(l), i) == i
for i in reversed(l):
assert nth_element(l, 0, len(l), i) == i
array = [3, 2, 1, 5, 6, 4]
assert nth_element(array, 0, len(array), 2) == 2
array = [2,1]
assert nth_element(array, 0, len(array), 1) == 1
assert nth_element(array, 0, len(array), 2) == 2
array = [3,3,3,3,3,3,3,3,3]
assert nth_element(array, 0, len(array), 1) == 3
if __name__ == '__main__':
test_nth_element()