算法原理
定义merge_sort
函数,将数组分为两部分A[left..mid]
和A[mid+1..right]
,分别对这两部分排序(递归调用merge_sort
函数),然后调用merge
函数合并这两部分
void merge_sort( vector<int>& A, int left, int right )
{
// 如果left == right,说明数组只包含1个元素,不必排序,直接返回
if( left >= right ) // 似乎left == right也没问题
return;
int mid = left + ( right - left ) / 2;
merge_sort( A, left, mid ); // 递归调用
merge_sort( A, mid + 1, right ); // 递归调用
merge( A, left, mid, right );
}
定义merge
函数,将排好序的两部分A[left..mid]
和A[mid+1..right]
合并起来
void merge( vector<int>& A, int left, int mid, int right )
{
// 复制A[left..mid]和A[mid+1..right],暂存
vector<int> arr1( A.begin() + left, A.begin() + mid + 1 );
vector<int> arr2( A.begin() + mid + 1, A.begin() + right + 1 );
arr1.push_back( INT_MAX ); // 末尾加上INT_MAX,作为哨兵
arr2.push_back( INT_MAX );
int p1 = 0, p2 = 0;
// 遍历A[left..right]的每个位置,依次填写arr1[p1]或arr2[p2]
for( int i = left; i <= right; i++ )
if( arr1[p1] <= arr2[p2] )
{
A[i] = arr1[p1];
p1++;
}
else
{
A[i] = arr2[p2];
p2++;
}
}
完整代码如下
#include <stdio.h>
#include <time.h>
#include <vector>
#include <algorithm>
#define MAXN 50
using namespace std;
void display( vector<int>& s )
{
for( int i = 0; i < s.size(); i++ )
printf( "%d ", s[i] );
printf( "\n" );
}
bool check( vector<int>& s1, vector<int>& s2 )
{
for( int i = 0; i < s1.size(); i++ )
if( s1[i] != s2[i] )
return false;
return true;
}
void merge( vector<int>& A, int left, int mid, int right )
{
// 复制A[left..mid]和A[mid+1..right],暂存
vector<int> arr1( A.begin() + left, A.begin() + mid + 1 );
vector<int> arr2( A.begin() + mid + 1, A.begin() + right + 1 );
arr1.push_back( INT_MAX ); // 末尾加上INT_MAX,作为哨兵
arr2.push_back( INT_MAX );
int p1 = 0, p2 = 0;
// 遍历A[left..right]的每个位置,依次填写arr1[p1]或arr2[p2]
for( int i = left; i <= right; i++ )
if( arr1[p1] <= arr2[p2] )
{
A[i] = arr1[p1];
p1++;
}
else
{
A[i] = arr2[p2];
p2++;
}
}
void merge_sort( vector<int>& A, int left, int right )
{
// 如果left == right,说明数组只包含1个元素,不必排序,直接返回
if( left >= right ) // 似乎left == right也没问题
return;
int mid = left + ( right - left ) / 2;
merge_sort( A, left, mid ); // 递归调用
merge_sort( A, mid + 1, right ); // 递归调用
merge( A, left, mid, right );
}
int main()
{
srand( (unsigned)time(NULL) );
int nums[MAXN];
for( int i = 0; i < MAXN; i++ )
nums[i] = rand() % 100;
vector<int> s1( nums, nums + MAXN );
vector<int> s2( nums, nums + MAXN );
merge_sort( s1, 0, s1.size() - 1 );
sort( s2.begin(), s2.end() );
printf( "s1: " );
display(s1);
printf( "s2: " );
display(s2);
printf( "Result: %s\n", check( s1, s2 ) ? "Accepted" : "Wrong Answer" );
return 0;
}
算法分析
时间复杂度:merge_sort函数O(logn),merge函数O(n),故时间复杂度为O(nlogn)
空间复杂度:merge_sort函数O(logn),merge函数O(n),故O(logn) + O(n) = O(n)
扩展:对链表使用Merge Sort
class Solution {
public:
/*
链表高级排序,采用归并排序,借用了21. Merge Two Sorted Lists的代码
*/
ListNode* sortList(ListNode* head) {
if( !head || !head->next )
return head;
// 利用快慢指针,找到中间节点
ListNode* prev = NULL, *p = head, *q = head;
while(q)
{
prev = p;
p = p->next;
for( int i = 0; i < 2; i++ )
{
q = q->next;
if( !q )
break;
}
}
prev->next = NULL; // 切成两段
ListNode* l1 = sortList( head );
ListNode* l2 = sortList( p );
return mergeTwoLists( l1, l2 );
}
ListNode* mergeTwoLists( ListNode* l1, ListNode* l2 )
{
ListNode* head = new ListNode(-1);
ListNode* p = l1, *q = l2, *cur = head;
while( p && q )
{
if( p->val <= q->val )
{
cur->next = p;
p = p->next;
}
else
{
cur->next = q;
q = q->next;
}
cur = cur->next;
}
if( p )
cur->next = p;
else if( q )
cur->next = q;
return head->next;
}
};
应用:求数组中的逆序对
简单版本:nowcoder 数组中的逆序对
在Merge Sort中(具体是在merge
函数中)完成逆序对的统计,Merge Sort仍然对整个数组进行递增排序,因此破坏了原数组
class Solution {
public:
void merge( vector<int>& s, int left, int mid, int right, int& result )
{
vector<int> arr1( s.begin() + left, s.begin() + mid + 1 );
vector<int> arr2( s.begin() + mid + 1, s.begin() + right + 1 );
// 因为merge操作是从右向左进行的,因此在最左边插入INT_MIN作为哨兵
arr1.insert( arr1.begin(), INT_MIN );
arr2.insert( arr2.begin(), INT_MIN );
int p1 = arr1.size() - 1;
int p2 = arr2.size() - 1;
for( int i = right; i >= left; i-- )
if( arr1[p1] > arr2[p2] )
{
result = ( result + p2 ) % 1000000007; // 产生p2对逆序对
s[i] = arr1[p1];
p1--;
}
else
{
s[i] = arr2[p2];
p2--;
}
}
void merge_sort( vector<int>& s, int left, int right, int& result )
{
if( left >= right )
return;
int mid = left + ( right - left ) / 2;
merge_sort( s, left, mid, result );
merge_sort( s, mid + 1, right, result );
merge( s, left, mid, right, result );
}
int InversePairs(vector<int> data) {
int result = 0;
merge_sort( data, 0, data.size() - 1, result );
return result;
}
};