归并排序简介
算法思想:
归并排序的核心思想是采用分治策略,将整个数组的排序任务分类为两个子问题,前一半排序和后一半排序,然后整合两个有序部分完成整体排序。即把数组分为若干个子序列,直到单个元素组成一个序列,然后将各阶段得到的序列组合在一起得到最终完整排序序列。
归并排序任务可以如下分治完成:
1. 把前一半排序
2. 把后一半排序
3. 把两半归并到一个新的有序数组,然后再拷贝回原数组,排序完成。
执行样例:
输入:[29,10,14,37,14,25,10]
算法实现:
// 将数组 a 的局部 a[s,m] 和 a[m+1,e] 合并到 tmp, 并保证 tmp 有序,然后再拷贝回 a[s,m]
void merge(vector<int>& arr, int start, int mid, int end, vector<int> tmp){
int pTmp = 0;
int pLeft = start; int pRight = mid+1;
while(pLeft<=mid&&pRight<=end){
if(arr[pLeft] < arr[pRight]){
tmp[pTmp++] = arr[pLeft++];
}else{
tmp[pTmp++] = arr[pRight++];
}
}
while(pLeft<=mid){
tmp[pTmp++] = arr[pLeft++];
}
while (pRight<=end)
{
tmp[pTmp++] = arr[pRight++];
}
for(int i=0;i<pTmp;i++){
arr[start+i] = tmp[i];
}
}
// 归并排序递归调用,先排前半部分,在排后半部分,最后将两部分结果合并
void mergeSort(vector<int>& arr, int start, int end, vector<int> tmp){
if(start < end){
int mid = start + (end-start)/2;
mergeSort(arr,start,mid,tmp);
mergeSort(arr,mid+1,end,tmp);
merge(arr,start,mid,end,tmp);
}
}
493 翻转对
给定一个数组 nums ,如果 i < j
且 nums[i] > 2*nums[j]
就将 (i,j)
称作一个翻转对,返回给定数组中翻转对的数量。
输入一个数组,输出一个整数表示数组中翻转对的个数
输入: [2,4,3,5,1]
输出: 3
解释: (4,1),(3,1),(5,1)三个翻转数对
解析:
本题的实质是求数组中的逆序数,但是本题中对逆序数有新的定义。
我们基于归并排序算法采用分治策略:将排列分为两部分,先求左半部分的翻转对,然后求右半部分的翻转对;最后算两部分之间存在的翻转对,时间复杂度 O(nlogn)
。
左半边和右半边都是排好序的。例如,都是从大到小排序的,左右半边只需要从头到尾各扫一遍,就可以找出由两边各取一个数构成的翻转对数。
下面给出了一种较为容易理解的实现方法,在归并排序代码的基础上做了很小的修改,就是当左侧元素大于右侧时开始寻找翻转对。但是,本题的数据规模最大可达到50000,如果使用这种简单循环遍历将导致超时。
class Solution {
public:
void mergeAndCount(vector<int>& nums, int start, int mid, int end, vector<int> tmp, int& count){
if(start>end) return;
int pTmp = 0;
int pLeft = start, pRight = mid+1;
while(pLeft<=mid && pRight<=end){
if(nums[pLeft] < nums[pRight]){
tmp[pTmp++] = nums[pLeft++];
}else{
// 这种循环遍历方法将导致超时
for(int i=pLeft;i<=mid;++i){
if(nums[i] > (long long)2*nums[pRight]){
++count;
}
}
tmp[pTmp++] = nums[pRight++];
}
}
while(pLeft<=mid){
tmp[pTmp++] = nums[pLeft++];
}
while(pRight<=end){
tmp[pTmp++] = nums[pRight++];
}
for(int i=0;i<pTmp;++i){
nums[start+i] = tmp[i];
}
}
void mergeSort(vector<int>& nums, int start, int end, vector<int> tmp, int& count){
if(start<end){
int mid = start + (end-start)/2;
mergeSort(nums,start,mid,tmp,count);
mergeSort(nums,mid+1,end,tmp,count);
mergeAndCount(nums,start,mid,end,tmp,count);
}
}
int reversePairs(vector<int>& nums) {
int ans = 0;
vector<int> tmp(nums.size());
mergeSort(nums,0,nums.size()-1,tmp,ans);
return ans;
}
};
为了降低时间复杂度,我们采用如下策略:如果左半边 A1 和右半边 A2 都是排好序的,我们就可以在线性时间内解决这个问题了。当然,也可以用二分查找来解决,但是时间复杂度就是线性对数的了。
-
初始化两个指针i,j分别指向A1,A2的头部
-
如果
A1[i] > 2*A2[j]
,那么A1[i]
及其后面的所有元素都符合要求,更新答案并后移j
,否则,后移i
-
最后,合并
A1, A2
以备解决后面更大的子问题使用,并返回前结果
class Solution {
public:
int find_reversed_pairs(vector<int>& nums, int start, int end){
int res = 0,mid = start + (end-start)/2;
int i = start,j = mid+1;
for(;i <= mid;i++){
while(j <= end && (long)nums[i] > 2*(long)nums[j]) {
res += mid - i + 1;
j++;
}
}
return res;
}
int merge_sort(vector<int>& nums, int start, int end, vector<int> tmp){
if(start >= end) return 0;
int mid = start + (end-start) / 2;
int res = merge_sort(nums,tmp,start,mid) +
merge_sort(nums,tmp,mid+1,end) +
find_reversed_pairs(nums,start,end);
int i = start,j = mid+1,ind = start;
while(i <= mid && j <= end){
if(nums[i] <= nums[j]) tmp[ind++] = nums[i++];
else tmp[ind++] = nums[j++];
}
while(i <= mid) tmp[ind++] = nums[i++];
while(j <= end) tmp[ind++] = nums[j++];
for(int ind = start;ind <= end;ind++) nums[ind] = tmp[ind];
return res;
}
int reversePairs(vector<int>& nums) {
if(nums.empty()) return 0;
vector<int>& tmp(nums.size());
return merge_sort(nums,0,nums.size()-1,tmp);
}
};
148 排序链表
给定链表的头结点 head
,请将其按 升序 排列并返回 排序后的链表 ,要求在 O(nlogn)
的时间复杂度内完成。
输入一个链表,输出一个按升序排列的链表。
输入:head = [4,2,1,3]
输出:[1,2,3,4]
解析:
在完成本题之前可以先尝试 21 合并两个有序链表、23 合并K个升序链表 两题,也许对你解决本题有帮助。
本题我们可以使用归并排序的思想完成。采用分治策略,先将链表划分为两个部分,然后在每个部分中进行排序。所以,我们需要完成如下两个任务:
- 寻找链表中点划分链表
- 排序局部链表,并合并结果
第一个任务我们可以使用快慢指针完成,快指针一次走两步,慢指针一次走一步,这样当快指针到达链表末尾时,慢指针刚好指向链表中点。然后,我们递归地采用上诉方法将链表不断划分为更小的子问题,直到不可再分。这里需要注意一个常犯的错误:在寻找中点时不能直接以 nullptr 判断快指针是否达到终点,而是要将其与链表尾节点 tail 进行比较。
第二个任务升序合并链表:分割完成之后,迭代合并两部分链表,将链表按照升序合并,最终形成完整的升序链表。
class Solution {
public:
// 合并两个链表
ListNode* merge(ListNode *l1, ListNode *l2){
if(!l1 || !l2) return l1?l1:l2;
ListNode dummy;
ListNode *tail = &dummy;
ListNode *p1 = l1, *p2 = l2;
while(p1 && p2){
if(p1->val < p2->val){
tail->next = p1;
p1 = p1->next;
}else{
tail->next = p2;
p2 = p2->next;
}
tail = tail->next;
}
tail->next = p1?p1:p2;
tail = nullptr;
return dummy.next;
}
// 寻找链表终点并递归划分链表
ListNode* mergeSort(ListNode* head, ListNode* tail){
if(!head) return nullptr;
if(head->next == tail){
head->next = nullptr;
return head;
}
ListNode *slow = head, *fast = head;
// // 错误示范,会导致越界
// while(fast&&fast->next){
// fast = fast->next->next;
// slow = slow->next;
// }
while(fast != tail){
slow = slow->next;
fast = fast->next;
if(fast != tail){
fast = fast->next;
}
}
ListNode* mid = slow;
return merge(mergeSort(head,mid),mergeSort(mid,tail));
}
ListNode* sortList(ListNode* head) {
return mergeSort(head,nullptr);
}
};