这个涉及到原地归并排序和block swap两个算法,所以如果有兴趣的读者先读我空间的 Parallel In-Place Merge 和 block swapping两篇文章。
首先给出原地归并排序的基本原理图:
首先把给定的数组分成两部分,前一部分包括[l,m]中的元素,而后者则包括(m,r]中的元素。然后找到第一个数组中的中间值,也就是q1=(l+m)/2,q1位置的元素就是我们要找的元素,大家可以举个例子自己算一下,当找到q1的时候,q1前面的元素有 q1 个,这对接下来的编程很有影响,所以这点一定要弄清楚。这样第一个数组我们就分为了两部分。接下来我们用q1去划分(m,r]这段元素,也是分成两部分。建议在取元素范围的时候,(m,r]这段的前一部分个数为(q2-1-m)个,后一部分为(r - q2 +1)个,这样约定后思路会清晰。接下来划分好了之后就是交换,有一种叫做block swapping的算法,这个算法能在O(1)的空间复杂度下交换两个长度不相同的相邻存储区的数组。
个人觉得能力有限,不能讲的很清楚,希望真的有兴趣的人可以看下我推荐的两篇博文,老外讲的很经典,其他都是细节的问题。
#include<stdio.h>
#include<assert.h>
void swap(int *a, int low, int high) {
while(low < high) {
int temp = *(a + low);
*(a + low) = *(a + high);
*(a + high) = temp;
low++;
high--;
}
}
void block_exchange(int *a,int low, int mid, int high) {
swap(a, low, mid);
swap(a, mid + 1, high);
swap(a, low, high);
}
int binary_search(int value, int *a,int low, int high) {
/*
如果数组中存在要查找的元素,那么返回high的位置的前后都有可能等于value的值:
a: 1,2,4,4,7,9 value=4 mid=2返回的high位置的元素后有值等于value
b: 1,2,4,4,7,9,10 value=4 mid=3,返回的high位置的元素前有值等于value
c: 1,2,7,9 value=4 high=3,返回的high值之前的元素都小于value,包括high在内的以后得元素都大于value
*/
assert(a != NULL);
while(low < high) {
int mid = low + (high - low) / 2;
if(value <= a[mid])
high = mid;
else
low = mid + 1;
}
return high;
}
void merge_in_place(int *a, int low, int mid, int high) {
int length1 = mid - low +1;
int length2 = high - mid;
if(!(length1 >= 0 && length2 >= 0))
return ;
if(length1 >= length2) {
if(length2 <= 0)
return;
int q1 = (low + mid) / 2;
int q2 = binary_search(a[q1], a, mid + 1 , high);
int q3 = q1 + (q2 - 1 - mid);
block_exchange(a, q1, mid, q2 - 1);
merge_in_place(a, low, q1 - 1, q3 - 1);
merge_in_place(a, q3 + 1, q2 - 1, high);
} else {
if(length1 <= 0)
return;
int q1 = (mid + 1 + high) /2;
int q2 = binary_search(a[q1], a, low, mid);
int q3 = q2 + (q1 - 1 - mid);
block_exchange(a, q2, mid, q1);
merge_in_place(a, low, q2 - 1, q3 -1);
merge_in_place(a, q3 + 1, q1, high);
}
}
void main() {
int a[]={1,3,5,7,9,2,4,6,8,10};
int len = sizeof(a) /sizeof(int);
int mid = len / 2 - 1;
merge_in_place(a, 0, mid, len - 1);
for(int i = 0; i < len; i++)
printf("%d\t",a[i]);
printf("\n");
}