2路归并排序的算法描述
1. 归并过程
1.1. 若有2个有序的序列a和b,我们要做的是将序列a和序列b合并,存放到一个新数组c,使c的元素有序。为了简单起见,我们这里仅考虑升序排列,即序列a和序列b是非递减序列。c中的元素也以非递减方式排列。
1.2. 有2个指针pa和pb分别指向a和b的第一个元素。指针pc指向新数组c。
1.3. 当a和b都没有处理完时,检查a头部的元素(pa指向的元素)x, 和b头部的元素(pb指向的元素)y, 若x<=y, 将x放入新数组,同时将pa向右移动一个位置。否则将y放入新数组,同时将pb向右移动一个位置。将x或者y放入新数组后c,将pc移动一个位置。
1.4. 不断重复步骤1.3,直到a或者b处理完毕。
1.5. 如果a还有未处理的元素,将a中的剩余元素添加到c。
1.6. 如果b还有未处理的元素,将b中的剩余元素添加到c。
下面的代码中,merge_v1 和merge_v2 和 merge_v2_asm 实现相同的功能,将arr1,arr2合并(len1和len2为arr1和arr2的长度),存放到数组target。
2. 函数merge_sort
函数merge_sort实现归并排序,归并排序需要一块临时的buff,我们取tlen=(len+1)/2,len为原始序列长度,分配一块长度为tlen缓冲区tbuff。然后调用merge_sort_core做归并排序。
3. 函数merge_sort_core的流程
3.1. 首先检查数组长度, 若数组长度小于某个值INSERT_SORT_THRESHOLD,做直接插入排序并返回。
3.2. 将序列arr分为几乎相等的2部分,若len是偶数,则左半部分长度和右半部分长度是一样的,否则左半部分长度比右半部分长度多1个元素。
3.3. 首先递归调用自己,对左半部分进行排序,然后再将右半部分进行排序。
3.4. 将左半部分复制到临时缓冲区tbuff,将tbuff作为数组a,将右半部分作为数组b, 将原始数组arr作为目标数组c。调用程序merge_v1进行做归并操作。
以下是所有源代码:
#include <stdio.h>
#include <stdlib.h>
#include "sorts.h"
void merge_v1(ELE_TYPE arr1[], int len1, ELE_TYPE arr2[], int len2, ELE_TYPE *target)
{
ELE_PTR p1,p2;
ELE_PTR p1End,p2End;
p1=arr1; p1End=arr1+len1;
p2=arr2; p2End=arr2+len2;
while ( p1<p1End && p2<p2End)
{
if ( *p1 <= *p2 )
*target++ = *p1++;
else
*target++ = *p2++;
}
while ( p1< p1End )
*target++ = *p1++;
while ( p2< p2End )
*target++ = *p2++;
}
void merge_v2(ELE_TYPE arr1[], int len1, ELE_TYPE arr2[], int len2, ELE_TYPE *target)
{
ELE_PTR p1, p2, p1End, p2End;
ELE_TYPE a, b, m1, m2;
p1 = arr1; p1End = arr1 + len1;
p2 = arr2; p2End = arr2 + len2;
while ( p1<p1End && p2<p2End)
{
a = *p1;
b = *p2;
// if ( a>b) then m1=1, m2=0 else m1=0, m2=1
m1 = (a < b);
m2 = (m1 ^ 1);
p1 += m1;
p2 += m2;
*target++ = (m1 ? a : b);
}
while (p1< p1End)
*target++ = *p1++;
while (p2< p2End)
*target++ = *p2++;
}
_declspec(naked)
void merge_v2_asm(ELE_TYPE arr1[], int len1, ELE_TYPE arr2[], int len2, ELE_TYPE *target)
{
#define _arr1 4
#define _len1 8
#define _arr2 12
#define _len2 16
#define _target 20
#define _OFS_P1END 0
#define _OFS_P2END 4
#define _ST_SIZE 24
#define REG_p1 esi
#define REG_p2 edi
#define REG_TGT ebp
#define REG_a eax
#define REG_b ebx
#define REG_M1 ecx
#define REG_M1L cl
#define REG_M2 edx
__asm
{
push esi ; save registers
push edi
push ebx
push ebp
sub esp,(_ST_SIZE-16)
mov REG_p1, dword ptr [esp+_ST_SIZE+_arr1]
mov eax, dword ptr [esp+_ST_SIZE+_len1]
lea edx, [REG_p1+eax*4]
mov dword ptr [esp+_OFS_P1END], edx
mov REG_p2, dword ptr [esp+_ST_SIZE+_arr2]
mov eax, dword ptr [esp+_ST_SIZE+_len2]
lea edx, [REG_p2+eax*4]
mov dword ptr [esp+_OFS_P2END], edx
mov REG_TGT, dword ptr [esp+_ST_SIZE+_target]
jmp merge_v2_cmp
merge_v2_loop_start:
xor REG_M1, REG_M1 ; REG_M1 = 0
mov REG_a, dword ptr [REG_p1]
mov REG_b, dword ptr [REG_p2]
cmp REG_a, REG_b
setl REG_M1L ; if a<b, then REG_M1 = 1
cmovge REG_a, REG_b ; if a >= b, then a = b
mov REG_M2, REG_M1
lea REG_p1, DWORD PTR[REG_p1 + REG_M1 * 4] ; if a<b, REG_p1++
xor REG_M2, 1 ; if a>=b, then REG_M2=1, else REG_M2=0
mov DWORD PTR[REG_TGT], REG_a ; *target = a
lea REG_p2, DWORD PTR[REG_p2 + REG_M2 * 4] ; if a>=b, REG_p2++
add REG_TGT, 4 ; target++
merge_v2_cmp:
cmp REG_p1, DWORD PTR [esp+_OFS_P1END]
jae SHORT merge_v2_p1_tail_cmp
cmp REG_p2, DWORD PTR [esp+_OFS_P2END]
jb merge_v2_loop_start
jmp merge_v2_p1_tail_cmp
merge_v2_p1_tail_loop:
mov eax, DWORD PTR [REG_p1]
mov [REG_TGT], eax
add REG_p1, 4
add REG_TGT, 4
merge_v2_p1_tail_cmp:
cmp REG_p1, DWORD PTR [esp+_OFS_P1END]
jb merge_v2_p1_tail_loop
jmp merge_v2_p2_tail_cmp
merge_v2_p2_tail_loop:
mov eax, DWORD PTR [REG_p2]
mov [REG_TGT], eax
add REG_p2, 4
add REG_TGT, 4
merge_v2_p2_tail_cmp:
cmp REG_p2, DWORD PTR [esp+_OFS_P2END]
jb merge_v2_p2_tail_loop
merge_v2_exit:
add esp, (_ST_SIZE-16)
pop ebp ; restore registers
pop ebx
pop edi
pop esi
ret
}
}
void merge_sort_core(ELE_TYPE arr[], int len, ELE_TYPE *tBuff)
{
int left_half;
int right_half;
#if 0
if ( len<=1)
return ;
#else
if (len <= INSERT_SORT_THRESHOLD)
{
insert_sort(arr, len);
return ;
}
#endif
left_half = (len+1)/2;
right_half = len - left_half;
merge_sort_core(arr,left_half, tBuff);
merge_sort_core(arr+left_half, right_half, tBuff);
memcpy(tBuff,arr,sizeof(ELE_TYPE)*left_half);
merge_v1(tBuff, left_half, arr+left_half, right_half, arr);
//merge_v2(tBuff, left_half, arr + left_half, right_half, arr);
//merge_v2_asm(tBuff, left_half, arr+left_half, right_half, arr);
}
void merge_sort(ELE_TYPE arr[], int len)
{
ELE_TYPE *tBuff= (ELE_TYPE*)malloc( sizeof(ELE_TYPE)*(len+1)/2);;
merge_sort_core(arr,len, tBuff);
free(tBuff);
}
void test_merge_sort()
{
ELE_TYPE arr[] = { 61, 17, 29, 22, 34, 60, 72, 21, 50, 1, 62 };
int len = (int) sizeof(arr) / sizeof(arr[0]);
printf("original data are:");
print_array(arr, len);
merge_sort(arr, len);
printf("The data after sorted are:");
print_array(arr, len);
}
归并排序的核心部分为归并操作,其热点是下面语句的分支语句。我们知道,分支语句的执行涉及到CPU的分支预测功能。如果分支预测的准确度较低,会降低程序的执行速度。在下面的场景中,如果数据分布是随机的,两个分支执行的概率各占50%,故我们可考虑使用消除分支技术来提高CPU的执行性能。
if ( *p1 <= *p2 )
*target++ = *p1++;
else
*target++ = *p2++;
函数merge_v2是一个消除分支的版本,其核心部分如下。
a = *p1;
b = *p2;
// if ( a>b) then m1=1, m2=0 else m1=0, m2=1
m1 = (a < b);
m2 = (m1 ^ 1);
p1 += m1;
p2 += m2;
*target++ = (m1 ? a : b);
函数 merge_v2_asm是函数merge_v2的一个汇编语言实现,相对于C编译器,这个汇编版本减少了4条指令。在I7-4700HQ的测试结果显示,merge_v2_asm是最快的版本。对2百万个整数进行排序。merge_v1需要159毫秒,merge_v2需要181毫秒,merge_v2_asm需要142毫秒。当然,这并不说明,merge_v2一定慢于merge_v1,在某些CPU,merge_v2反超merge_v1是可能的。读者如果感兴趣,可在自己的电脑上测试一下这三个版本的性能。另外,我的测试结果显示,在某些情况下,归并排序的性能可超越快速排序。具体测试数据将在后续的文章中给出。