归并排序详解及应用
什么是归并排序
1.归并排序明明就是一个数组算法,和二叉树有什么关系?
首先我们要知道的是所有递归的算法,你甭管它是干什么的,本质上都是在遍历一棵(递归)树,然后在节点(前中后序位置)上执行代码,你要写递归算法,本质上就是要告诉每个节点需要做什么。
然后看看归并排序的代码框架:
// 定义:排序 nums[lo..hi]
void sort(int[] nums, int lo, int hi) {
if (lo == hi) {
return;
}
int mid = (lo + hi) / 2;
// 利用定义,排序 nums[lo..mid]
sort(nums, lo, mid);
// 利用定义,排序 nums[mid+1..hi]
sort(nums, mid + 1, hi);
/****** 后序位置 ******/
// 此时两部分子数组已经被排好序
// 合并两个有序数组,使 nums[lo..hi] 有序
merge(nums, lo, mid, hi);
/*********************/
}
// 将有序数组 nums[lo..mid] 和有序数组 nums[mid+1..hi]
// 合并为有序数组 nums[lo..hi]
void merge(int[] nums, int lo, int mid, int hi);
显然,归并排序就是先把左半边数组排好序,再把右半边数组排好序,然后把两半数组合并。
到这里我们可以想二叉树的后序遍历
/* 二叉树遍历框架 */
void traverse(TreeNode root) {
if (root == null) {
return;
}
traverse(root.left);
traverse(root.right);
/****** 后序位置 ******/
print(root.val);
/*********************/
}
再进一步联想求二叉树的最大深度
// 定义:输入根节点,返回这棵二叉树的最大深度
int maxDepth(TreeNode root) {
if (root == null) {
return 0;
}
// 利用定义,计算左右子树的最大深度
int leftMax = maxDepth(root.left);
int rightMax = maxDepth(root.right);
// 整棵树的最大深度等于左右子树的最大深度取最大值,
// 然后再加上根节点自己
int res = Math.max(leftMax, rightMax) + 1;
return res;
}
可以看出这三个代码框架都很像,都是先处理左右子问题在合并处理根问题。
因此得出结论:归并排序的过程可以在逻辑上抽象成一棵二叉树,树上的每个节点的值可以认为是 nums[lo..hi]
,叶子节点的值就是数组中的单个元素
然后,在每个节点的后序位置(左右子节点已经被排好序)的时候执行 merge
函数,合并两个子节点上的子数组:
这个 merge
操作会在二叉树的每个节点上都执行一遍,执行顺序是二叉树后序遍历的顺序。nums[lo…hi] 理解成二叉树的节点
, sort
函数理解成二叉树的遍历函数
所以完整代码如下:
//归并排序(二叉树的后序遍历思想)
class Merge_Sort{
//声明辅助数组,用于装nums,使得nums可以原地排序
private static int[] temp;//用来装合并后的
//开始排序
public static void sort(int[] nums){
//为辅助数组开辟空间
temp=new int[nums.length];
//实际开始排序(原地修改)
sort(nums,0,nums.length-1);
}
public static void sort(int[] nums,int low,int high){
int mid=low+(high-low)/2;
//如果是单个数说明不用排序
if(low==high) return;
//对左子树进行排序
sort(nums,low,mid);
//对右子树进行排序
sort(nums,mid+1,high);
//对左右子树进行合并
merge(nums,low,mid,high);
}
//对左右子树进行合并
public static void merge(int[] nums,int low,int mid,int high){
int i=low;//用来标记左子树
int j=mid+1;//用来标记右子树
//先进行迁移,使得可以原地排序
for(int t=low;t<=high;t++){
temp[t]=nums[t];
//System.out.println(nums[t]);
}
// 数组双指针技巧,合并两个有序数组
for(int k=low;k<=high;k++){
if(i==mid+1){
//说明左子树已经全部合并完了
nums[k]=temp[j++];
}else if(j==high+1){
//说明右子树已经全部合并完了
nums[k]=temp[i++];
}else if(temp[i]>temp[j]){
//如果都没有合并完,那就比较大小来合并
nums[k]=temp[j++];
}else{
nums[k]=temp[i++];
}
}
}
}
应用:以下题目的原理都是通过给归并排序的 merge
函数加一些私货完成目标。
一、315. 计算右侧小于当前元素的个数
分析:为什么会用到归并排序?
主要是用到了归并排序的合并,我们在使用 merge
函数合并两个有序数组的时候,其实是可以知道一个元素 nums[i]
后边有多少个元素比 nums[i]
小的。
这时候我们应该把 temp[i]
放到 nums[p]
上,因为 temp[i] < temp[j]
。
但就在这个场景下,我们还可以知道一个信息:5 后面比 5 小的元素个数就是 左闭右开区间 [mid + 1, j)
中的元素个数,即 2 和 4 这两个元素:
换句话说,在对 nums[lo..hi]
合并的过程中,每当执行 nums[p] = temp[i]
时,就可以确定 temp[i]
这个元素后面比它小的元素个数为 j - mid - 1
。
这道题的解题步骤是:
- 声明一个新的对象(类型),用于记录原数组的元素值和在数组中的原始索引。因为在排序过程中,每个元素的索引位置会不断改变,所以我们用一个
Pair
类封装每个元素及其在原始数组nums
中的索引,以便count
数组记录每个元素之后小于它的元素个数。
归并排序所用到的辅助数组和新的数组都是该类型的数组
2.开始归并排序
3.合并两个有序数组,以下两种情况更新count
- 当右子树遍历完成后
- 当右指针所指的数大于左指针的数时(即右子树大于左子树时)
class Solution {
private class Pair {
int val, id;
Pair(int val, int id) {
// 记录数组的元素值
this.val = val;
// 记录元素在数组中的原始索引
this.id = id;
}
}
// 归并排序所用的辅助数组
private Pair[] temp;
// 记录每个元素后面比自己小的元素个数
private int[] count;
// 主函数
public List<Integer> countSmaller(int[] nums) {
int n = nums.length;
count = new int[n];
temp = new Pair[n];
Pair[] arr = new Pair[n];
// 记录元素原始的索引位置,以便在 count 数组中更新结果
for (int i = 0; i < n; i++)
arr[i] = new Pair(nums[i], i);
// 执行归并排序,本题结果被记录在 count 数组中
sort(arr, 0, n - 1);
List<Integer> res = new LinkedList<>();
for (int c : count) res.add(c);
return res;
}
// 归并排序
private void sort(Pair[] arr, int lo, int hi) {
if (lo == hi) return;
int mid = lo + (hi - lo) / 2;
sort(arr, lo, mid);
sort(arr, mid + 1, hi);
merge(arr, lo, mid, hi);
}
// 合并两个有序数组
private void merge(Pair[] arr, int lo, int mid, int hi) {
for (int i = lo; i <= hi; i++) {
temp[i] = arr[i];
}
int i = lo, j = mid + 1;
for (int p = lo; p <= hi; p++) {
if (i == mid + 1) {
arr[p] = temp[j++];
} else if (j == hi + 1) {
arr[p] = temp[i++];
// 更新 count 数组
count[arr[p].id] += j - mid - 1;
} else if (temp[i].val > temp[j].val) {
arr[p] = temp[j++];
} else {
arr[p] = temp[i++];
// 更新 count 数组
count[arr[p].id] += j - mid - 1;
}
}
}
}
二、493. 翻转对
int count=0;
int[] temp;
public int reversePairs(int[] nums) {
temp=new int[nums.length];
sort(nums,0,nums.length-1);
return count;
}
public void sort(int[] nums,int low,int high){
if(low==high) return;
int mid=low+(high-low)/2;
sort(nums,low,mid);
sort(nums,mid+1,high);
merge(nums,low,mid,high);
}
public void merge(int[] nums,int low,int mid,int high){
//if(low==high) return;
for(int k=low;k<=high;k++){
temp[k]=nums[k];
}
int i=0,j=0;
//在合并前夹带私货
for(i=low;i<=mid;i++){
for(j=mid+1;j<=high;j++){
if(temp[i]>temp[j]*2){
count++;
}
}
}
i=low;
j=mid+1;
for(int p=low;p<=high;p++){
if(i==mid+1){
nums[p]=temp[j++];
}else if(j==high+1){
nums[p]=temp[i++];
}else if(temp[i]>temp[j]){
nums[p]=temp[j++];
}else{
nums[p]=temp[i++];
}
}
}
但是如果代码是这样写的话会出现下面截图的这种错误,因为nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kYn0RLge-1680231094015)(D:\Development\Typora\img\image-20230329112234000.png)]
因此修改为下面的代码
int count=0;
int[] temp;
public int reversePairs(int[] nums) {
temp=new int[nums.length];
sort(nums,0,nums.length-1);
return count;
}
public void sort(int[] nums,int low,int high){
if(low==high) return;
int mid=low+(high-low)/2;
sort(nums,low,mid);
sort(nums,mid+1,high);
merge(nums,low,mid,high);
}
public void merge(int[] nums,int low,int mid,int high){
//if(low==high) return;
for(int k=low;k<=high;k++){
temp[k]=nums[k];
}
int i=0,j=0;
//在合并前夹带私货
for(i=low;i<=mid;i++){
for(j=mid+1;j<=high;j++){
// nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
if((long)temp[i]>(long)temp[j]*2){
count++;
}
}
}
i=low;
j=mid+1;
for(int p=low;p<=high;p++){
if(i==mid+1){
nums[p]=temp[j++];
}else if(j==high+1){
nums[p]=temp[i++];
}else if(temp[i]>temp[j]){
nums[p]=temp[j++];
}else{
nums[p]=temp[i++];
}
}
}
但是这个解法会出现超时问题,因为额外添加了一个嵌套for循环
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8H0AX7V6-1680231094017)(D:\Development\Typora\img\image-20230329112359188.png)]
那如何进行优化呢?
注意子数组 nums[low..mid]
是排好序的,也就是 nums[i] <= nums[i+1]
所以对于对于 nums[i], low <= i <= mid
,我们在找到的符合 nums[i] > 2*nums[j]
的 nums[j], mid+1 <= j <= high
,也必然也符合 nums[i+1] > 2*nums[j]
。
也就是说,我们不用每次都傻乎乎地去遍历整个 nums[mid+1..high]
,只要维护一个开区间边界 end
,维护 nums[mid+1..end-1]
是符合条件的元素即可。
如何理解上面的那句话呢?
就是遍历到i时,end为end1,此时统计次数
下一次遍历到i+1时,end1若还满足nums[i+1]>nums[end1]*2,则继续end++,直到找到新的end2,退出循环统计次数
至于为什么在end1的基础上++,是因为nums[i] <= nums[i+1],这样就达到了不用重复遍历整个 nums[mid+1…high]
所以代码优化如下:
int count=0;
int[] temp;
public int reversePairs(int[] nums) {
temp=new int[nums.length];
sort(nums,0,nums.length-1);
return count;
}
public void sort(int[] nums,int low,int high){
if(low==high) return;
int mid=low+(high-low)/2;
sort(nums,low,mid);
sort(nums,mid+1,high);
merge(nums,low,mid,high);
}
public void merge(int[] nums,int low,int mid,int high){
//if(low==high) return;
for(int k=low;k<=high;k++){
temp[k]=nums[k];
}
int i=0,j=0;
//在合并前夹带私货
// 进行效率优化,维护左闭右开区间 [mid+1, end) 中的元素乘 2 小于 nums[i]
// 为什么 end 是开区间?因为这样的话可以保证初始区间 [mid+1, mid+1) 是一个空区间
//也就是说只要找到一个
int end = mid + 1;
for (i = low; i <= mid; i++) {
// nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
while (end <= high && (long)nums[i] > (long)nums[end] * 2) {
end++;
}
count += end - (mid + 1);
}
i=low;
j=mid+1;
for(int p=low;p<=high;p++){
if(i==mid+1){
nums[p]=temp[j++];
}else if(j==high+1){
nums[p]=temp[i++];
}else if(temp[i]>temp[j]){
nums[p]=temp[j++];
}else{
nums[p]=temp[i++];
}
}
}
三、327. 区间和的个数
提到区间和我们就要想到用前缀和数组,前缀和数组两个元素之差就是区间和
所以这道题就是构建一个前缀和数组,对前缀和数组进行归并排序,在合并之前对前缀和数组求区间和,这里的寻找区间有点像滑动窗口(让窗口中的元素和 nums[i]
的差落在 [lower, upper]
中)
不知道这段代码为什么不能得到正确答案?
class Solution {
int count=0;
int lower;
int upper;
long[] temp;
public int countRangeSum(int[] nums, int lower, int upper) {
//这里借助了前缀和数组,注意前缀和数组两个元素之差就是区间和
//因此我们要构建出来一个前缀和,对前缀和数组进行归并排序,在合并之前
//对前缀和数组求区间和,这里的寻找有点像滑动窗口
this.lower=lower;
this.upper=upper;
//1.构建前缀和数组
int n=nums.length;
long[] preSum=new long[n+1];
for(int i=0;i<n;i++){
preSum[i+1]=preSum[i]+(long)nums[i];
}
//2.对前缀和数组进行归并排序
//2.1声明辅助数组
temp=new long[n];
sort(preSum,0,n-1);
return count;
}
//3.完成归并排序
public void sort(long[] preSum,int low,int high){
if(low==high) return;
int mid=low+(high-low)/2;
sort(preSum,low,mid);
sort(preSum,mid+1,high);
merge(preSum,low,mid,high);
}
//4.合并
public void merge(long[] preSum,int low,int mid,int high){
for(int j=low;j<=high;j++){
temp[j]=preSum[j];
}
//5.在合并之前找到合适的区间和并统计
int start=mid+1;//start表示第一个满足区间和范围的前缀和到preSum[i]
int end=mid+1;//end表示最后一个满足的
//那他们之间的差就是个数
// 维护左闭右开区间 [start, end) 中的元素落在 [lower, upper] 中
for(int i=low;i<=mid;i++){
while(start<=high&&preSum[start]-preSum[i]<lower){
//不在区间范围内则继续往后找
start++;
}
while(end<=high&&preSum[end]-preSum[i]<=upper){
//在范围内则继续找还有没有更大范围
end++;
}
count+=end-start;
}
//6.开始正式合并
int i=low,j=mid+1;
for(int p=low;p<=high;p++){
if(i==mid+1){
preSum[p]=temp[j++];
}else if(j==high+1){
preSum[p]=temp[i++];
}else if(temp[i]>temp[j]){
preSum[p]=temp[j++];
}else{
preSum[p]=temp[i++];
}
}
}
}
正确答案:
class Solution {
int lower, upper;
public int countRangeSum(int[] nums, int lower, int upper) {
this.lower = lower;
this.upper = upper;
long[] preSum = new long[nums.length + 1];
for (int i = 0; i < nums.length; i++) {
preSum[i + 1] = (long) nums[i] + preSum[i];
}
sort(preSum);
return count;
}
// 用于辅助合并有序数组
private long[] temp;
private int count = 0;
public void sort(long[] nums) {
// 先给辅助数组开辟内存空间
temp = new long[nums.length];
// 排序整个数组(原地修改)
sort(nums, 0, nums.length - 1);
}
// 定义:将子数组 nums[lo..hi] 进行排序
private void sort(long[] nums, int lo, int hi) {
if (lo == hi) {
// 单个元素不用排序
return;
}
// 这样写是为了防止溢出,效果等同于 (hi + lo) / 2
int mid = lo + (hi - lo) / 2;
// 先对左半部分数组 nums[lo..mid] 排序
sort(nums, lo, mid);
// 再对右半部分数组 nums[mid+1..hi] 排序
sort(nums, mid + 1, hi);
// 将两部分有序数组合并成一个有序数组
merge(nums, lo, mid, hi);
}
// 将 nums[lo..mid] 和 nums[mid+1..hi] 这两个有序数组合并成一个有序数组
private void merge(long[] nums, int lo, int mid, int hi) {
// 先把 nums[lo..hi] 复制到辅助数组中
// 以便合并后的结果能够直接存入 nums
for (int i = lo; i <= hi; i++) {
temp[i] = nums[i];
}
// 这段代码会超时
// for (int i = lo; i <= mid; i++) {
// // 在区间 [mid + 1, hi] 中寻找 lower <= delta <= upper 的元素
// for (int k = mid + 1; k <= hi; k++) {
// long delta = nums[k] - nums[i];
// if (delta <= upper && delta >= lower) {
// count++;
// }
// }
// }
// 进行效率优化
// 维护左闭右开区间 [start, end) 中的元素落在 [lower, upper] 中
int start = mid + 1, end = mid + 1;
for (int i = lo; i <= mid; i++) {
while (start <= hi && nums[start] - nums[i] < lower) {
start++;
}
while (end <= hi && nums[end] - nums[i] <= upper) {
end++;
}
count += end - start;
}
// 数组双指针技巧,合并两个有序数组
int i = lo, j = mid + 1;
for (int p = lo; p <= hi; p++) {
if (i == mid + 1) {
// 左半边数组已全部被合并
nums[p] = temp[j++];
} else if (j == hi + 1) {
// 右半边数组已全部被合并
nums[p] = temp[i++];
} else if (temp[i] > temp[j]) {
nums[p] = temp[j++];
} else {
nums[p] = temp[i++];
}
}
}
}
总结
所有递归的算法,本质上都是在遍历一棵(递归)树,然后在节点(前中后序位置)上执行代码。你要写递归算法,本质上就是要告诉每个节点需要做什么。
对于这个专题的归并排序算法,递归的 sort
函数就是二叉树的遍历函数,而 merge
函数就是在每个节点上做的事情