简述
我们常见的二路归并排序基本上分为两步,时间复杂度为 O ( l o g 2 n ) O(log_{2}n) O(log2n),空间复杂度为 O ( n ) O(n) O(n)。下面是我们常用的写法:
void merge(int *data, int low, int mid, int high)
{
int *tmp = new int[high + 1];
int i = low;
int j = mid + 1;
for (int k = low; k <= high; k++)
{
tmp[k] = data[k];
}
int k = low;
for (k = low; i <= mid && j <= high;)
{
if (tmp[i] < tmp[j])
{
data[k++] = tmp[i++];
}
else
{
data[k++] = tmp[j++];
}
}
while (i <= mid && k <= high)
{
data[k++] = tmp[i++];
}
while (j <= high && k <= high)
{
data[k++] = tmp[j++];
}
delete tmp;
}
void sort_merge(int *data, int low, int high)
{
if (low < high)
{
int mid = (low + high) / 2;
sort_merge(data, low, mid);
sort_merge(data, mid + 1, high);
merge(data, low, mid, high);
}
}
但是我们可以看到,从始至终我们都对每个子数组进行归并,但是我们知道对于小规模的数组一般的插入排序,选择排序或者希尔排序性能表现得更好,因此在小规模的数据上没有必要进行二路归并。第二,每次在merge中进行数据的拷贝本身就很大消耗。第三,无论数组是否有序我们都会二路归并到底,这对于整个算法来说不太好。因此,对应的改进办法如下。
改进办法
- 利用data[mid] < data[mid + 1]进行有序性判断。因为经过归并我们可以保证左右两个子数组都是有序的,只要第一个子数组的最后一个元素小于第二个字数组的第一个元素数组一定是有序的;
- 对小规模的数据使用简单的线性排序算法,比如希尔排序。这里选择希尔,选择排序,插入排序都可以(ps:冒泡不太建议,但是有兴趣的可以试下)。
- 在每次归并后交换辅助数组和数据数组的角色。也就是在merge中不需要将子数组的值复制回数据数组,直接将数据数组排序进辅助数组,在执行结束后交换二者的角色即可。顺便说下,我这个没有实现,我觉得作者的意思是使用系统调用类似于memcpy一次性复制元素来提高每次merge时单个单个复制的性能。
实现代码,我这里使用的是希尔排序进行小规模排序:
void merge_shell(int *data, int low, int high)
{
int gap = 1;
while (gap < (high - low + 1) / 3) gap = 3 * gap + 1;
while (gap >= 1)
{
for (int i = gap + low; i <= high; i++)
{
for (int j = i; j >= (gap + low) && data[j] < data[j - gap]; j -= gap)
{
swap(&data[j], &data[j - gap]);
}
}
gap = gap / 3;
}
}
void merge(int *data, int *tmp, int low, int mid, int high)
{
int i = low;
int j = mid + 1;
int k = low;
for (k = low; i <= mid && j <= high;)
{
if (data[i] < data[j])
{
tmp[k++] = data[i++];
}
else
{
tmp[k++] = data[j++];
}
}
while (i <= mid && k <= high)
{
tmp[k++] = data[i++];
}
while (j <= high && k <= high)
{
tmp[k++] = data[j++];
}
for (int i = low; i <= high; i++)
{
data[i] = tmp[i];
}
}
void merge_sort(int *data, int *tmp, int low, int high)
{
if (low < high)
{
if ((high - low + 1) <= 10)
{
merge_shell(data, low, high);
}
else
{
int mid = (low + high) / 2;
merge_sort(data, tmp, low, mid);
merge_sort(data, tmp, mid + 1, high);
if (data[mid] < data[mid + 1]) //the array has been sorted, we do not need to sort again
{
return;
}
merge(data, tmp, low, mid, high);
}
}
}