最近在知乎上看到一个帖子,总结了各种常见的排序算法,并用python一一实现了,不过归并排序的迭代写法,题主说他不会写,我就试了一下,其实很简单。下面会先分析递归的时候实际上做了哪些事,然后说明迭代如何重现这些事。先用C++写,因为估计看这篇博客的大部分人对C++比较熟,最后会分享python的版本,实现过程基本一模一样。
递归的时候做了什么?
先po一下递归的伪代码:
// Pseudocode. The ranges [head1, head2-1] and [head2, tail2] of arr are each
// already sorted; this routine merges them into one sorted range.
void mergeSorted(arr, head1, head2, tail2) {
// ... (merge body omitted in this pseudocode)
}
// Pseudocode: recursive (top-down) merge sort of arr over the inclusive
// range [left, right].
void mergeSort(arr, left, right) {
// A range holding at most one element needs no work.
if (left >= right)
return;
// Split at the midpoint and sort each half on its own.
mid = (left + right) >> 1;
mergeSort(arr, left, mid);
mergeSort(arr, mid+1, right);
// Both halves are sorted now: merge [left, mid] with [mid+1, right].
mergeSorted(arr, left, mid+1, right);
}
可以看出,递归时并没有做什么特别的事:只是把区间从中间分成两半,让每一半各自排好序,最后再把两半合并起来。这相当于对递归树做后序遍历,从叶子节点往回看:
1. 区间的长度都为1,直接返回,不用合并;
2. 区间的长度为2,两个子区间都排好序了,将它们合并起来;
3. 区间的长度为4,两个子区间都排好序了,将它们合并起来;
4. ……
迭代怎么写?
从上面的分析可以看出,其实只需要依次枚举步长1,2,4,……,对每个步长,把数组按该步长切分成若干个区间,再将相邻的两个区间两两合并即可。
比如,一开始数组为[8 7 6 5 4 3 2 1]。
第一遍,步长为1,将相邻的两个区间合并(注意加粗黑体):
**7 8** 6 5 4 3 2 1
7 8 **5 6** 4 3 2 1
7 8 5 6 **3 4** 2 1
7 8 5 6 3 4 **1 2**
第二遍,步长为2,将相邻的两个区间合并(注意加粗黑体):
**5 6 7 8** 3 4 1 2
5 6 7 8 **1 2 3 4**
第三遍,步长为4,将相邻的两个区间合并(注意加粗黑体):
**1 2 3 4 5 6 7 8**
应该很容易就能写出来吧?注意处理好边界即可:
// Merges two adjacent, already-sorted runs of v: the left run
// [head1, head2-1] and the right run [head2, tail2] (all bounds inclusive).
// The merged sequence is built in a temporary buffer and then copied back
// over [head1, tail2]. On equal values the element from the left run is
// taken first.
void mergeSortHelper(vector<int>& v, int head1, int head2, int tail2) {
    const int leftEnd = head2 - 1;
    int left = head1;
    int right = head2;
    vector<int> merged(tail2 - head1 + 1);
    int out = 0;

    // Both runs still have elements: take the smaller front element.
    while (left <= leftEnd && right <= tail2) {
        if (v[right] < v[left])
            merged[out++] = v[right++];
        else
            merged[out++] = v[left++];
    }
    // At most one of the two runs still has elements; append the rest of it.
    while (left <= leftEnd)
        merged[out++] = v[left++];
    while (right <= tail2)
        merged[out++] = v[right++];

    // Write the merged result back over the original span.
    for (int k = 0; k < out; ++k)
        v[head1 + k] = merged[k];
}
// Bottom-up (iterative) merge sort: sorts v in place in non-decreasing order.
//
// Run widths 1, 2, 4, ... are enumerated. For each width, every pair of
// adjacent runs [lo, lo+width-1] and [lo+width, hi] is merged with
// mergeSortHelper. Because mergeSortHelper takes int indices, v.size() must
// fit in an int.
void mergeSort(vector<int>& v) {
    const size_t n = v.size();
    // A pass is only useful while at least two runs exist (width < n).
    // Unsigned counters avoid the signed overflow that doubling an int step
    // would hit for very large inputs.
    for (size_t width = 1; width < n; width *= 2) {
        // Stop once there is no right-hand run: a trailing run without a
        // partner is already sorted and is left untouched in this pass.
        for (size_t lo = 0; lo + width < n; lo += 2 * width) {
            // The last right-hand run may be shorter than width.
            const size_t hi = min(lo + 2 * width, n) - 1;
            mergeSortHelper(v, static_cast<int>(lo), static_cast<int>(lo + width), static_cast<int>(hi));
        }
    }
}
总体的测试代码:
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>
using namespace std;
// 注意被我注释掉的地方,解开来,很直观可以看到排序的过程是怎么做的!
// Prints the elements of v on one line, each followed by a single space,
// then terminates the line with endl.
void display(const vector<int>& v) {
    for (vector<int>::const_iterator it = v.begin(); it != v.end(); ++it)
        cout << *it << ' ';
    cout << endl;
}
// Returns true iff v is in non-decreasing order.
//
// A single scan over adjacent pairs is enough: v equals its sorted copy
// exactly when no adjacent pair is out of order, so no copy and no second
// sort are needed. Empty and one-element vectors count as sorted.
bool isSorted(const vector<int>& v) {
    for (size_t i = 1; i < v.size(); ++i)
        if (v[i - 1] > v[i])
            return false;
    return true;
}
// Merges two adjacent, already-sorted runs of v: the left run
// [head1, head2-1] and the right run [head2, tail2] (all bounds inclusive).
// The merged sequence is built in a temporary buffer and then copied back
// over [head1, tail2]. On equal values the element from the left run is
// taken first. Uncomment the trace lines to watch each merge happen.
void mergeSortHelper(vector<int>& v, int head1, int head2, int tail2) {
    const int leftEnd = head2 - 1;
    int left = head1;
    int right = head2;
    // cout << "Before " << head1 << ' ' << leftEnd << ' ' << head2 << ' ' << tail2 << endl;
    // display(v);
    vector<int> merged(tail2 - head1 + 1);
    int out = 0;

    // Both runs still have elements: take the smaller front element.
    while (left <= leftEnd && right <= tail2) {
        if (v[right] < v[left])
            merged[out++] = v[right++];
        else
            merged[out++] = v[left++];
    }
    // At most one of the two runs still has elements; append the rest of it.
    while (left <= leftEnd)
        merged[out++] = v[left++];
    while (right <= tail2)
        merged[out++] = v[right++];

    // Write the merged result back over the original span.
    for (int k = 0; k < out; ++k)
        v[head1 + k] = merged[k];
    // cout << "After ";
    // display(v);
    // cout << endl;
}
// Bottom-up (iterative) merge sort: sorts v in place in non-decreasing order.
//
// Run widths 1, 2, 4, ... are enumerated. For each width, every pair of
// adjacent runs [lo, lo+width-1] and [lo+width, hi] is merged with
// mergeSortHelper. Because mergeSortHelper takes int indices, v.size() must
// fit in an int.
void mergeSort(vector<int>& v) {
    const size_t n = v.size();
    // A pass is only useful while at least two runs exist (width < n).
    // Unsigned counters avoid the signed overflow that doubling an int step
    // would hit for very large inputs.
    for (size_t width = 1; width < n; width *= 2) {
        // Stop once there is no right-hand run: a trailing run without a
        // partner is already sorted and is left untouched in this pass.
        for (size_t lo = 0; lo + width < n; lo += 2 * width) {
            // The last right-hand run may be shorter than width.
            const size_t hi = min(lo + 2 * width, n) - 1;
            mergeSortHelper(v, static_cast<int>(lo), static_cast<int>(lo + width), static_cast<int>(hi));
        }
    }
}
// Replaces the contents of v with `size` pseudo-random integers in [0, MAX).
// Values come from rand(), so the sequence depends on the seed most recently
// passed to srand().
void gen(vector<int>& v, size_t size) {
    static const int MAX = 99997;
    v = vector<int>(size);
    // Index with size_t so the comparison against `size` is not a
    // signed/unsigned mix and cannot overflow an int for very large sizes.
    for (size_t i = 0; i < size; ++i)
        v[i] = rand() % MAX;
}
// Randomized self-test: for every size in [0, 10000) a random vector is
// generated, sorted with mergeSort and checked with isSorted. One line is
// printed per size; the run stops at the first failure.
int main() {
    // Small fixed example, kept for manual experiments:
    // vector<int> v;
    // for (int i = 0; i < 10; ++i)
    //     v.push_back(10-i);
    // mergeSort(v);
    srand(time(0));
    size_t size = 0;
    while (size < 10000) {
        vector<int> data;
        gen(data, size);
        mergeSort(data);
        const bool ok = isSorted(data);
        cout << (ok ? "GOOD with size = " : "FAIL with size = ") << size << endl;
        if (!ok)
            break;
        ++size;
    }
    return 0;
}
用python来实现
实现原理跟上面说的一样,直接po代码了:
# -*- coding:utf-8 -*-
import random
# Merge two adjacent, already-sorted runs: [head1, tail1] and [head2, tail2],
# where tail1 = head2 - 1.
def mergeSortHelper(v, head1, head2, tail2):
    """Merge the sorted runs v[head1..head2-1] and v[head2..tail2] in place.

    All bounds are inclusive indices into v. The merged sequence is built in
    a temporary list and then written back over v[head1..tail2]. On equal
    values the element from the left run is taken first.

    The cursor that advances is always the one whose element was just taken;
    the decision never relies on comparing values for equality, so elements
    that do not compare equal to themselves (e.g. float NaN) cannot make the
    merge run past the end of a run.
    """
    tail1 = head2 - 1
    start = head1
    tmp = []
    # Both runs still have elements: take the smaller front element.
    while head1 <= tail1 and head2 <= tail2:
        if v[head1] <= v[head2]:
            tmp.append(v[head1])
            head1 += 1
        else:
            tmp.append(v[head2])
            head2 += 1
    # At most one run still has elements; append the rest of it.
    while head1 <= tail1:
        tmp.append(v[head1])
        head1 += 1
    while head2 <= tail2:
        tmp.append(v[head2])
        head2 += 1
    # Write the merged result back element by element.
    for offset, value in enumerate(tmp):
        v[start + offset] = value
def mergeSort(v):
    """Sort v in place (non-decreasing) with a bottom-up, iterative merge sort.

    Run widths 1, 2, 4, 8, ... are enumerated; for each width, every pair of
    adjacent runs is merged with mergeSortHelper.
    """
    length = len(v)
    step = 1
    # A pass is only useful while at least two runs exist (step < length).
    while step < length:
        offset = step << 1
        # Only start a merge where a right-hand run exists (index + step <
        # length); a trailing run without a partner is already sorted.
        for index in range(0, length - step, offset):
            # The last right-hand run may be shorter than step.
            mergeSortHelper(v, index, index + step, min(index + offset, length) - 1)
        step = offset
# Build a list of `size` pseudo-random integers.
def genData(size):
    """Return a new list of `size` integers drawn with random.randrange(0, MAX)."""
    MAX = 99997
    return [random.randrange(0, MAX) for _ in range(size)]
# Check that v really is sorted.
def isSorted(v):
    """Return True iff v is in non-decreasing order.

    One scan over adjacent pairs is enough: v equals its sorted copy exactly
    when no adjacent pair is out of order, so no second sort is needed.
    Empty and one-element sequences count as sorted.
    """
    for i in range(1, len(v)):
        if v[i - 1] > v[i]:
            return False
    return True
# Randomized self-test: sort a random list of every size in [0, 10000) and
# report, one line per size, whether the result is sorted.
if __name__ == '__main__':
    for size in range(10000):
        data = genData(size)
        mergeSort(data)
        report = 'Good at size = {0}' if isSorted(data) else 'Fail at size = {0}'
        print(report.format(size))