什么是归并排序?
- 首先归并排序是10大排序中最常用面试常考的,时间复杂度是nlogn,其次我们学到的二叉树中后序遍历实际上就是归并排序来的。
- 归并排序,先对左右两部分的进行排好序,在进行合并,二叉树的后序遍历也是先将左右子树遍历好后,在进行操作,比如二叉树的最大深度:先递归遍历对应二叉子树的最大深度在进行叠加
- 接下来看看归并排序的数组模板
public class mergsort2 {
public static void main(String[] args) {
int[] arr = {3, 4, 6, 8, 2, 5, 7, 8, 9, 11, 1, 34, 22};
sort(arr, 0, arr.length - 1);
print(arr);
}
static void sort(int[] arr, int left, int right) {
if (left == right) {
return;
}
int mid = left + (right - left) / 2;
sort(arr, left, mid);
sort(arr, mid + 1, right);
merge(arr, left, mid + 1, right);
}
public static void merge(int[] arr, int leftPtr, int rightPtr, int rightBound) {
int[] temp = new int[rightBound - leftPtr + 1];
int mid = rightPtr - 1;
int i = leftPtr;
int j = rightPtr;
int k = 0;
while (i <= mid && j <= rightBound) {
temp[k++] = arr[i] <= arr[j] ? arr[i++] : arr[j++];
}
while (i <= mid) {
temp[k++] = arr[i++];
}
while ((j <= rightBound)) {
temp[k++] = arr[j++];
}
for (int l = 0; l < temp.length; l++) {
arr[leftPtr + l] = temp[l];
}
}
static void print(int[] arr) {
for (int i = 0; i < arr.length; i++) {
System.out.print(arr[i] + " ");
}
}
}
int maxDepth(TreeNode root){
if(root==null){
return 0;
}
int leftMaxDepth = maxDepth(root.left);
int rightMaxDpeth = maxDpeth(root.right);
int res = Math.max(leftMaxDpeth,rightMaxDepth)+1;
return res;
}
- 因此,归并排序的过程可以在逻辑上抽象成一颗二叉树,书上每个节点的值可以认为是nums[lo…hi],叶子节点就是数组中的每个元素
![在这里插入图片描述](https://img-blog.csdnimg.cn/e08e77bcad72467fa80bd51a20ea17d3.png)
接下来我们看几个题,来进一步加深对归并的理解
1.1 912排序数组
![在这里插入图片描述](https://img-blog.csdnimg.cn/81923f57983848a2af8fa82dd270d3b4.png)
class Solution {
public int[] sortArray(int[] nums) {
Merge.sort(nums);
return nums;
}
class Merge{
private static int[]temp;
public static void sort(int[]nums){
temp = new int[nums.length];
sort(nums,0,nums.length-1);
}
private static void sort(int[]nums,int lo,int hi){
if(lo==hi){
return;
}
int mid = lo+(hi-lo)/2;
sort(nums,lo,mid);
sort(nums,mid+1,hi);
merge(nums,lo,mid,hi);
}
private static void merge(int[]nums,int lo,int mid,int hi){
for(int i = lo;i<=hi;i++){
temp[i] = nums[i];
}
int i = lo, j = mid+1;
for(int p = lo;p<=hi;p++){
if(i==mid+1){
nums[p]=temp[j++];
}else if(j==hi+1){
nums[p]=temp[i++];
}else if(temp[i]>temp[j]){
nums[p] = temp[j++];
}else{
nums[p]= temp[i++];
}
}
}
}
}
1.2 315计算右侧小于当前元素的个数
- 我们知道归并排序的左右两个是排好序的,我们只需要记录小于当前的数的数量就可以,因此这个怎么记录就是一个难点了,这里需要申明一个类,Pair,构造中声明对应的val和index值,因此当我们遇到小于当前的值得时候可以利用索引下标进行相减
- 我们只需要在归并排序当中进行改动几行就可
![在这里插入图片描述](https://img-blog.csdnimg.cn/b5c473cf5e3e4343899ef6a9ec959729.png)
class Solution {
private class Pair{
int val,id;
Pair(int val,int id){
this.val = val;
this.id = id;
}
}
private Pair[]temp;
private int[]count;
public List<Integer> countSmaller(int[] nums) {
int n = nums.length;
temp = new Pair[n];
count = new int[n];
Pair[]arr = new Pair[n];
for(int i = 0;i<n;i++){
arr[i] = new Pair(nums[i],i);
}
sort(arr,0,nums.length-1);
LinkedList<Integer>res = new LinkedList<>();
for(int a:count){
res.add(a);
}
return res;
}
private void sort(Pair[]arr,int lo,int hi){
if(lo==hi){
return;
}
int mid = lo + (hi-lo)/2;
sort(arr,lo,mid);
sort(arr,mid+1,hi);
merge(arr,lo,mid,hi);
}
private void merge(Pair[]arr,int lo,int mid,int hi){
for(int i = lo;i<=hi;i++){
temp[i] = arr[i];
}
int i = lo,j = mid+1;
for(int p =lo;p<=hi;p++){
if(i == mid+1){
arr[p] = temp[j++];
}else if(j==hi+1){
arr[p] = temp[i++];
count[arr[p].id] += j-mid-1;
}else if(temp[i].val > temp[j].val){
arr[p] = temp[j++];
}else{
arr[p] = temp[i++];
count[arr[p].id] += j-mid-1;
}
}
}
}
1.3 翻转对
![在这里插入图片描述](https://img-blog.csdnimg.cn/4020d5eb8e3145f6a69a84abc809fc3b.png)
class Solution {
private int[]temp;
private int count = 0;
public int reversePairs(int[] nums) {
sort(nums);
return count;
}
public void sort(int[]nums){
temp = new int[nums.length];
sort(nums,0,nums.length-1);
}
private void sort(int[]nums,int lo,int hi){
if(lo==hi){
return;
}
int mid = lo + (hi-lo)/2;
sort(nums,lo,mid);
sort(nums,mid+1,hi);
merge(nums,lo,mid,hi);
}
private void merge(int[]nums,int lo,int mid,int hi){
for(int i = lo;i<=hi;i++){
temp[i] = nums[i];
}
int end = mid+1;
for(int i = lo;i<=mid;i++){
while(end<=hi && (long)nums[i]>(long)nums[end]*2){
end++;
}
count += end-(mid+1);
}
int i =lo,j=mid+1;
for(int p = lo;p<=hi;p++){
if(i==mid+1){
nums[p] = temp[j++];
}else if(j==hi+1){
nums[p] = temp[i++];
}else if(temp[i]>temp[j]){
nums[p] = temp[j++];
}else{
nums[p] = temp[i++];
}
}
}
}