分治法
归并排序算法是采用分治法的一个典型应用。
分治法是设计算法的一种策略,包含三步:
1、分 把原问题分解成若干子问题
2、治 (递归地)解决子问题
3、合 合并子问题的解,得到原问题的解
归并排序算法工作原理
不断将数组一分为二,递归成一个个的子数组,得到最小子数组的排序,并合并:
1. 申请两个空间存放子序列数据,子序列是排好序的
2. 设定两个指针,分别指向两个序列的起始位置
3. 比较两个指针所指向的元素,选择相对小的元素放入到合并空间,并移动指针到下一位置
4. 重复步骤3直到某一指针超出序列尾,将另一序列剩下的所有元素直接复制到合并序列尾。
合并的过程其实就是将两个有序数组合并为一个有序数组。
C语言版本
#include<stdio.h>
#include<stdlib.h>
void merge(int a[], int start, int mid, int end)
{
int i, j, k, n1, n2;
int * front, * back; //变量用于申请两个内存空间
n1 = mid - start + 1; //前一序列长度
n2 = end - mid; //后一序列长度
//申请两个空间存放两个序列值
front = (int *) malloc (n1 * sizeof(int)) ;
back = (int *) malloc (n2 * sizeof(int));
/*将值设置到新序列中*/
for (i = 0; i < n1; i++)
{
front[i] = a[start + i];
}
for (i = 0; i < n2; i++)
{
back[i] = a[mid + i + 1];
}
/*将元素合并*/
i = 0;
j = 0;
k = start;
// i j 指针分别指向两个新序列的起始位置
while (i < n1 && j < n2)
{
if (front[i] < back[j])
{
a[k++] = front[i++];
}
else
{
a[k++] = back[j++];
}
}
/* 合并剩余元素 */
while (i < n1) // 有点巧妙
{
a[k++] = front[i++];
}
while (j < n2)
{
a[k++] = back[j++];
}
}
void sort(int a[], int start, int end)
{
int mid;
if (start < end)
{
mid = (start + end) / 2; // 分解成两个序列
sort(a, start, mid); // 递归将左边序列分解到子问题
sort(a, mid + 1, end); // 递归将右边序列分解到子问题
merge(a, start, mid, end); // 合并子问题的解
}
}
int main()
{
int a[8] = {1, 4, 3, 9, 6, 5, 8, 7};
int i;
sort(a, 0, 7);
for (i = 0; i < 8; i++)
{
printf("%d\t", a[i]);
}
printf("\n");
return 0;
}
归并结果示意图:
光看代码没准想不明白,上图模拟了一下计算机的执行,把递归的过程都画出来了,一直到递归结束:start < end
。
有数组A[1…n],递归分解子问题A[1…n/2]和A[n/2+1…n]一直到最小子问题A[1](只有一个元素的数组),然后最小子问题A[1]本身就是解,并合并子问题的解向上一直到合并完成。
Java版本
先看一个合并两个有序数组到目标数组的程序:
import java.util.Arrays;
public class MergeSort {
/**
* 将两个有序数组合并到目标数组中
* @param a 目标数组
* @param start 目标数组起始位置,从起始位置开始填入元素
* @param one
* @param two
*/
private void mergeArray(int[] a, int start, int[] one, int[] two) {
int oneLen = one.length;
int twoLen = two.length;
// m 指针指向 one数组, n 指针指向two数组
int m = 0, n = 0;
// i 指针指向 a数组
int i = start;
while (m < oneLen && n < twoLen) {
if (one[m] < two[n]) {
a[i] = one[m];
m++;
} else {
a[i] = two[n];
n++;
}
i++;
}
while (m < oneLen) {
a[i] = one[m];
i++;
m++;
}
while (n < twoLen) {
a[i] = two[n];
i++;
n++;
}
}
public static void main(String[] args) {
MergeSort mergeSort = new MergeSort();
int a[] = new int[12];
int[] one = new int[]{1, 16, 80, 200, 201};
int[] two = new int[]{10, 32, 45, 79, 90, 100, 101};
mergeSort.mergeArray(a, 0, one, two);
System.out.println(Arrays.toString(a));
}
}
更符合人类思维习惯的Java版本:
import java.util.Arrays;
public class MergeSort {
/**
* 将两个有序数组合并到目标数组中
* @param a 目标数组
* @param start 目标数组起始位置,从起始位置开始填入元素
* @param one
* @param two
*/
private void mergeArray(int[] a, int start, int[] one, int[] two) {
int oneLen = one.length;
int twoLen = two.length;
// m 指针指向 one数组, n 指针指向two数组
int m = 0, n = 0;
// i 指针指向 a数组
int i = start;
while (m < oneLen && n < twoLen) {
if (one[m] < two[n]) {
a[i] = one[m];
m++;
} else {
a[i] = two[n];
n++;
}
i++;
}
while (m < oneLen) {
a[i] = one[m];
i++;
m++;
}
while (n < twoLen) {
a[i] = two[n];
i++;
n++;
}
}
private void merge(int[] a, int start, int mid, int end) {
int frontLen = mid - start + 1;
int backLen = end - mid;
int[] front = new int[frontLen];
int[] back = new int[backLen];
// 初始化front和back数组
for (int i = 0; i < frontLen; i++) {
front[i] = a[start + i];
}
for (int i = 0; i < backLen; i++) {
back[i] = a[mid + 1 + i];
}
mergeArray(a, start, front, back);
}
public void sort(int[] a, int start, int end) {
if (start < end) {
int mid = (start + end) / 2;
sort(a, start, mid);
sort(a, mid + 1, end);
merge(a, start, mid, end);
}
}
public static void main(String[] args) {
MergeSort mergeSort = new MergeSort();
int[] a = {1, 4, 3, 9, 6, 5, 8, 7};
mergeSort.sort(a, 0, 7);
System.out.println(Arrays.toString(a));
}
}
时间复杂度
归并排序算法写成递归式:
总的时间复杂度描述:T(n) = 2T(n / 2) + O(n),
2表示子问题的数目,n / 2是每个子问题的规模,O(n)是分治所用的时间,
假设O(n) = cn,表示整个问题的规模为cn。
使用递归树来计算时间复杂度,递归式:T(n) = 2T(n / 2) + cn,得到递归树:
![归并排序递归树](https://i-blog.csdnimg.cn/blog_migrate/8d800af940c539c1f94a40d05ca6826f.png)
这课树的高度是lgn,每层所耗的时间是cn,总的时间就是:cn * lgn + θ(n),计算时间复杂度,其实就是在做渐进分析,这样将c和θ(n)忽略,我们的时间复杂度就是O(nlgn)。
或者通过分治的思想,需要O(lgn)层的分治操作,每层的归并时间复杂度是O(n), 所以这个时间复杂度就是O(nlgn)。
参考网易公开课算法课程: https://open.163.com/movie/2010/12/G/F/M6UTT5U0I_M6V2T1JGF.html