归并排序的学习

文章详细阐述了归并排序的原理,将其与二叉树的后序遍历相联系,并通过三个具体的LeetCode题目(315.计算右侧小于当前元素的个数、493.翻转对、327.区间和的个数)说明如何利用归并排序的合并过程解决复杂问题。文章强调了在归并排序的merge函数中添加额外逻辑来实现特定目标的方法。
摘要由CSDN通过智能技术生成

归并排序详解及应用

什么是归并排序

1.归并排序明明就是一个数组算法,和二叉树有什么关系?

首先我们要知道的是所有递归的算法,你甭管它是干什么的,本质上都是在遍历一棵(递归)树,然后在节点(前中后序位置)上执行代码,你要写递归算法,本质上就是要告诉每个节点需要做什么。

然后看看归并排序的代码框架:

// 定义:排序 nums[lo..hi]
void sort(int[] nums, int lo, int hi) {
    if (lo == hi) {
        return;
    }
    int mid = (lo + hi) / 2;
    // 利用定义,排序 nums[lo..mid]
    sort(nums, lo, mid);
    // 利用定义,排序 nums[mid+1..hi]
    sort(nums, mid + 1, hi);

    /****** 后序位置 ******/
    // 此时两部分子数组已经被排好序
    // 合并两个有序数组,使 nums[lo..hi] 有序
    merge(nums, lo, mid, hi);
    /*********************/
}

// 将有序数组 nums[lo..mid] 和有序数组 nums[mid+1..hi]
// 合并为有序数组 nums[lo..hi]
void merge(int[] nums, int lo, int mid, int hi);

显然,归并排序就是先把左半边数组排好序,再把右半边数组排好序,然后把两半数组合并。

到这里我们可以想二叉树的后序遍历

/* 二叉树遍历框架 */
void traverse(TreeNode root) {
    if (root == null) {
        return;
    }
    traverse(root.left);
    traverse(root.right);
    /****** 后序位置 ******/
    print(root.val);
    /*********************/
}

再进一步联想求二叉树的最大深度

// 定义:输入根节点,返回这棵二叉树的最大深度
int maxDepth(TreeNode root) {
	if (root == null) {
		return 0;
	}
	// 利用定义,计算左右子树的最大深度
	int leftMax = maxDepth(root.left);
	int rightMax = maxDepth(root.right);
	// 整棵树的最大深度等于左右子树的最大深度取最大值,
    // 然后再加上根节点自己
	int res = Math.max(leftMax, rightMax) + 1;

	return res;
}

可以看出这三个代码框架都很像,都是先处理左右子问题在合并处理根问题。

因此得出结论:归并排序的过程可以在逻辑上抽象成一棵二叉树,树上的每个节点的值可以认为是 nums[lo..hi],叶子节点的值就是数组中的单个元素

在这里插入图片描述

然后,在每个节点的后序位置(左右子节点已经被排好序)的时候执行 merge 函数,合并两个子节点上的子数组:

在这里插入图片描述

这个 merge 操作会在二叉树的每个节点上都执行一遍,执行顺序是二叉树后序遍历的顺序。nums[lo…hi] 理解成二叉树的节点, sort 函数理解成二叉树的遍历函数

所以完整代码如下:

//归并排序(二叉树的后序遍历思想)
class Merge_Sort{
    //声明辅助数组,用于装nums,使得nums可以原地排序
    private static int[] temp;//用来装合并后的
    //开始排序
    public static void sort(int[] nums){
        //为辅助数组开辟空间
        temp=new int[nums.length];
        //实际开始排序(原地修改)
        sort(nums,0,nums.length-1);
    }
    public static void sort(int[] nums,int low,int high){
        int mid=low+(high-low)/2;
        //如果是单个数说明不用排序
        if(low==high) return;
        //对左子树进行排序
        sort(nums,low,mid);
        //对右子树进行排序
        sort(nums,mid+1,high);
        //对左右子树进行合并
        merge(nums,low,mid,high);
    }
    //对左右子树进行合并
    public static void merge(int[] nums,int low,int mid,int high){
        int i=low;//用来标记左子树
        int j=mid+1;//用来标记右子树
        //先进行迁移,使得可以原地排序
        for(int t=low;t<=high;t++){
            temp[t]=nums[t];
            //System.out.println(nums[t]);
        }
        // 数组双指针技巧,合并两个有序数组
        for(int k=low;k<=high;k++){
            if(i==mid+1){
                //说明左子树已经全部合并完了
                nums[k]=temp[j++];
            }else if(j==high+1){
                //说明右子树已经全部合并完了
                nums[k]=temp[i++];
            }else if(temp[i]>temp[j]){
                //如果都没有合并完,那就比较大小来合并
                nums[k]=temp[j++];
            }else{
                nums[k]=temp[i++];
            }
        }
    }
}

应用:以下题目的原理都是通过给归并排序的 merge 函数加一些私货完成目标。

一、315. 计算右侧小于当前元素的个数

分析:为什么会用到归并排序?

主要是用到了归并排序的合并,我们在使用 merge 函数合并两个有序数组的时候,其实是可以知道一个元素 nums[i] 后边有多少个元素比 nums[i] 小的

这时候我们应该把 temp[i] 放到 nums[p] 上,因为 temp[i] < temp[j]

在这里插入图片描述

但就在这个场景下,我们还可以知道一个信息:5 后面比 5 小的元素个数就是 左闭右开区间 [mid + 1, j) 中的元素个数,即 2 和 4 这两个元素:

在这里插入图片描述

换句话说,在对 nums[lo..hi] 合并的过程中,每当执行 nums[p] = temp[i] 时,就可以确定 temp[i] 这个元素后面比它小的元素个数为 j - mid - 1

这道题的解题步骤是:
  1. 声明一个新的对象(类型),用于记录原数组的元素值和在数组中的原始索引。因为在排序过程中,每个元素的索引位置会不断改变,所以我们用一个 Pair 类封装每个元素及其在原始数组 nums 中的索引,以便 count 数组记录每个元素之后小于它的元素个数。

​ 归并排序所用到的辅助数组和新的数组都是该类型的数组

2.开始归并排序

3.合并两个有序数组,以下两种情况更新count

  • 当右子树遍历完成后
  • 当右指针所指的数大于左指针的数时(即右子树大于左子树时)
class Solution {
    private class Pair {
        int val, id;
        Pair(int val, int id) {
            // 记录数组的元素值
            this.val = val;
            // 记录元素在数组中的原始索引
            this.id = id;
        }
    }
    
    // 归并排序所用的辅助数组
    private Pair[] temp;
    // 记录每个元素后面比自己小的元素个数
    private int[] count;
    
    // 主函数
    public List<Integer> countSmaller(int[] nums) {
        int n = nums.length;
        count = new int[n];
        temp = new Pair[n];
        Pair[] arr = new Pair[n];
        // 记录元素原始的索引位置,以便在 count 数组中更新结果
        for (int i = 0; i < n; i++)
            arr[i] = new Pair(nums[i], i);
        
        // 执行归并排序,本题结果被记录在 count 数组中
        sort(arr, 0, n - 1);
        
        List<Integer> res = new LinkedList<>();
        for (int c : count) res.add(c);
        return res;
    }
    
    // 归并排序
    private void sort(Pair[] arr, int lo, int hi) {
        if (lo == hi) return;
        int mid = lo + (hi - lo) / 2;
        sort(arr, lo, mid);
        sort(arr, mid + 1, hi);
        merge(arr, lo, mid, hi);
    }
    
    // 合并两个有序数组
    private void merge(Pair[] arr, int lo, int mid, int hi) {
        for (int i = lo; i <= hi; i++) {
            temp[i] = arr[i];
        }
        
        int i = lo, j = mid + 1;
        for (int p = lo; p <= hi; p++) {
            if (i == mid + 1) {
                arr[p] = temp[j++];
            } else if (j == hi + 1) {
                arr[p] = temp[i++];
                // 更新 count 数组
                count[arr[p].id] += j - mid - 1;
            } else if (temp[i].val > temp[j].val) {
                arr[p] = temp[j++];
            } else {
                arr[p] = temp[i++];
                // 更新 count 数组
                count[arr[p].id] += j - mid - 1;
            }
        }
    }
}

二、493. 翻转对

    int count=0;
    int[] temp;
    public int reversePairs(int[] nums) {
        temp=new int[nums.length];
        sort(nums,0,nums.length-1);
        return count;
    }
    public void sort(int[] nums,int low,int high){
        if(low==high) return;
        int mid=low+(high-low)/2;
        sort(nums,low,mid);
        sort(nums,mid+1,high);
        merge(nums,low,mid,high);
    }
    public void merge(int[] nums,int low,int mid,int high){
        //if(low==high) return;
        for(int k=low;k<=high;k++){
            temp[k]=nums[k];
        }
        int i=0,j=0;
        
        //在合并前夹带私货
        for(i=low;i<=mid;i++){
            for(j=mid+1;j<=high;j++){
                if(temp[i]>temp[j]*2){
                    count++;
                }
            }
        }
        i=low;
        j=mid+1;
        for(int p=low;p<=high;p++){
            if(i==mid+1){
                nums[p]=temp[j++];
            }else if(j==high+1){
                nums[p]=temp[i++];
            }else if(temp[i]>temp[j]){
                nums[p]=temp[j++];
            }else{
                nums[p]=temp[i++];
            }
        }

    }

但是如果代码是这样写的话会出现下面截图的这种错误,因为nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kYn0RLge-1680231094015)(D:\Development\Typora\img\image-20230329112234000.png)]

因此修改为下面的代码

int count=0;
    int[] temp;
    public int reversePairs(int[] nums) {
        temp=new int[nums.length];
        sort(nums,0,nums.length-1);
        return count;
    }
    public void sort(int[] nums,int low,int high){
        if(low==high) return;
        int mid=low+(high-low)/2;
        sort(nums,low,mid);
        sort(nums,mid+1,high);
        merge(nums,low,mid,high);
    }
    public void merge(int[] nums,int low,int mid,int high){
        //if(low==high) return;
        for(int k=low;k<=high;k++){
            temp[k]=nums[k];
        }
        int i=0,j=0;
        //在合并前夹带私货
        for(i=low;i<=mid;i++){
            for(j=mid+1;j<=high;j++){
                // nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
                if((long)temp[i]>(long)temp[j]*2){
                    count++;
                }
            }
        }
        i=low;
        j=mid+1;
        for(int p=low;p<=high;p++){
            if(i==mid+1){
                nums[p]=temp[j++];
            }else if(j==high+1){
                nums[p]=temp[i++];
            }else if(temp[i]>temp[j]){
                nums[p]=temp[j++];
            }else{
                nums[p]=temp[i++];
            }
        }

    }

但是这个解法会出现超时问题,因为额外添加了一个嵌套for循环

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8H0AX7V6-1680231094017)(D:\Development\Typora\img\image-20230329112359188.png)]

那如何进行优化呢?

注意子数组 nums[low..mid] 是排好序的,也就是 nums[i] <= nums[i+1]

所以对于对于 nums[i], low <= i <= mid,我们在找到的符合 nums[i] > 2*nums[j]nums[j], mid+1 <= j <= high,也必然也符合 nums[i+1] > 2*nums[j]

也就是说,我们不用每次都傻乎乎地去遍历整个 nums[mid+1..high],只要维护一个开区间边界 end,维护 nums[mid+1..end-1] 是符合条件的元素即可

如何理解上面的那句话呢?

就是遍历到i时,end为end1,此时统计次数

下一次遍历到i+1时,end1若还满足nums[i+1]>nums[end1]*2,则继续end++,直到找到新的end2,退出循环统计次数

至于为什么在end1的基础上++,是因为nums[i] <= nums[i+1],这样就达到了不用重复遍历整个 nums[mid+1…high]

所以代码优化如下:

int count=0;
    int[] temp;
    public int reversePairs(int[] nums) {
        temp=new int[nums.length];
        sort(nums,0,nums.length-1);
        return count;
    }
    public void sort(int[] nums,int low,int high){
        if(low==high) return;
        int mid=low+(high-low)/2;
        sort(nums,low,mid);
        sort(nums,mid+1,high);
        merge(nums,low,mid,high);
    }
    public void merge(int[] nums,int low,int mid,int high){
        //if(low==high) return;
        for(int k=low;k<=high;k++){
            temp[k]=nums[k];
        }
        int i=0,j=0;
        //在合并前夹带私货
        // 进行效率优化,维护左闭右开区间 [mid+1, end) 中的元素乘 2 小于 nums[i]
        // 为什么 end 是开区间?因为这样的话可以保证初始区间 [mid+1, mid+1) 是一个空区间
        //也就是说只要找到一个
        int end = mid + 1;
        for (i = low; i <= mid; i++) {
            // nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
            while (end <= high && (long)nums[i] > (long)nums[end] * 2) {
                end++;
            }
            count += end - (mid + 1);
        }
        i=low;
        j=mid+1;
        for(int p=low;p<=high;p++){
            if(i==mid+1){
                nums[p]=temp[j++];
            }else if(j==high+1){
                nums[p]=temp[i++];
            }else if(temp[i]>temp[j]){
                nums[p]=temp[j++];
            }else{
                nums[p]=temp[i++];
            }
        }

    }

三、327. 区间和的个数

提到区间和我们就要想到用前缀和数组,前缀和数组两个元素之差就是区间和

所以这道题就是构建一个前缀和数组,对前缀和数组进行归并排序,在合并之前对前缀和数组求区间和,这里的寻找区间有点像滑动窗口(让窗口中的元素和 nums[i] 的差落在 [lower, upper] 中)

不知道这段代码为什么不能得到正确答案?

class Solution {
    int count=0;
    int lower;
    int upper;
    long[] temp;
    public int countRangeSum(int[] nums, int lower, int upper) {
        //这里借助了前缀和数组,注意前缀和数组两个元素之差就是区间和
        //因此我们要构建出来一个前缀和,对前缀和数组进行归并排序,在合并之前
        //对前缀和数组求区间和,这里的寻找有点像滑动窗口
        this.lower=lower;
        this.upper=upper;
        //1.构建前缀和数组
        int n=nums.length;
        long[] preSum=new long[n+1];
        for(int i=0;i<n;i++){
            preSum[i+1]=preSum[i]+(long)nums[i];
        }
        //2.对前缀和数组进行归并排序
        //2.1声明辅助数组
        temp=new long[n];
        sort(preSum,0,n-1);
        return count;
    }
    //3.完成归并排序
    public void sort(long[] preSum,int low,int high){
        if(low==high) return;
        int mid=low+(high-low)/2;
        sort(preSum,low,mid);
        sort(preSum,mid+1,high);
        merge(preSum,low,mid,high);
    }
    //4.合并
    public void merge(long[] preSum,int low,int mid,int high){
        for(int j=low;j<=high;j++){
            temp[j]=preSum[j];
        }
        //5.在合并之前找到合适的区间和并统计
        int start=mid+1;//start表示第一个满足区间和范围的前缀和到preSum[i]
        int end=mid+1;//end表示最后一个满足的
        //那他们之间的差就是个数
        // 维护左闭右开区间 [start, end) 中的元素落在 [lower, upper] 中
        for(int i=low;i<=mid;i++){
            while(start<=high&&preSum[start]-preSum[i]<lower){
                //不在区间范围内则继续往后找
                start++;
            }
            while(end<=high&&preSum[end]-preSum[i]<=upper){
                //在范围内则继续找还有没有更大范围
                end++;
            }
            count+=end-start;
        }
        //6.开始正式合并
        int i=low,j=mid+1;
        for(int p=low;p<=high;p++){
            if(i==mid+1){
                preSum[p]=temp[j++];
            }else if(j==high+1){
                preSum[p]=temp[i++];
            }else if(temp[i]>temp[j]){
                preSum[p]=temp[j++];
            }else{
                preSum[p]=temp[i++];
            }
        }
    }
}

正确答案:

class Solution {
    int lower, upper;

    public int countRangeSum(int[] nums, int lower, int upper) {
        this.lower = lower;
        this.upper = upper;
        long[] preSum = new long[nums.length + 1];
        for (int i = 0; i < nums.length; i++) {
            preSum[i + 1] = (long) nums[i] + preSum[i];
        }
        sort(preSum);
        return count;
    }

    // 用于辅助合并有序数组
    private long[] temp;
    private int count = 0;

    public void sort(long[] nums) {
        // 先给辅助数组开辟内存空间
        temp = new long[nums.length];
        // 排序整个数组(原地修改)
        sort(nums, 0, nums.length - 1);
    }

    // 定义:将子数组 nums[lo..hi] 进行排序
    private void sort(long[] nums, int lo, int hi) {
        if (lo == hi) {
            // 单个元素不用排序
            return;
        }
        // 这样写是为了防止溢出,效果等同于 (hi + lo) / 2
        int mid = lo + (hi - lo) / 2;
        // 先对左半部分数组 nums[lo..mid] 排序
        sort(nums, lo, mid);
        // 再对右半部分数组 nums[mid+1..hi] 排序
        sort(nums, mid + 1, hi);
        // 将两部分有序数组合并成一个有序数组
        merge(nums, lo, mid, hi);
    }

    // 将 nums[lo..mid] 和 nums[mid+1..hi] 这两个有序数组合并成一个有序数组
    private void merge(long[] nums, int lo, int mid, int hi) {
        // 先把 nums[lo..hi] 复制到辅助数组中
        // 以便合并后的结果能够直接存入 nums
        for (int i = lo; i <= hi; i++) {
            temp[i] = nums[i];
        }

        // 这段代码会超时
        // for (int i = lo; i <= mid; i++) {
        //     // 在区间 [mid + 1, hi] 中寻找 lower <= delta <= upper 的元素
        //     for (int k = mid + 1; k <= hi; k++) {
        //         long delta = nums[k] - nums[i];
        //         if (delta <= upper && delta >= lower) {
        //             count++;
        //         }
        //     }
        // }

        // 进行效率优化
        // 维护左闭右开区间 [start, end) 中的元素落在 [lower, upper] 中
        int start = mid + 1, end = mid + 1;
        for (int i = lo; i <= mid; i++) {
            while (start <= hi && nums[start] - nums[i] < lower) {
                start++;
            }
            while (end <= hi && nums[end] - nums[i] <= upper) {
                end++;
            }
            count += end - start;
        }

        // 数组双指针技巧,合并两个有序数组
        int i = lo, j = mid + 1;
        for (int p = lo; p <= hi; p++) {
            if (i == mid + 1) {
                // 左半边数组已全部被合并
                nums[p] = temp[j++];
            } else if (j == hi + 1) {
                // 右半边数组已全部被合并
                nums[p] = temp[i++];
            } else if (temp[i] > temp[j]) {
                nums[p] = temp[j++];
            } else {
                nums[p] = temp[i++];
            }
        }
    }
}

总结

所有递归的算法,本质上都是在遍历一棵(递归)树,然后在节点(前中后序位置)上执行代码。你要写递归算法,本质上就是要告诉每个节点需要做什么

对于这个专题的归并排序算法,递归的 sort 函数就是二叉树的遍历函数,而 merge 函数就是在每个节点上做的事情

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值