详解Arrays.sort(T[] a, Comparator<? super T> c) Lambda版

详解Arrays.sort(T[] a, Comparator<? super T> c) Lambda版


写在开头

public static void main(String[] args) {
 String[] arr = new String[]{"aa", "a", "aaa"};
 // 根据字符串的长度排序
 Arrays.sort(arr, (a, b) -> a.length() - b.length());

 System.out.println(Arrays.toString(arr)); // 输出:[a, aa, aaa]
}

今天在写根据字符串长度排序的时候,突然很好奇sort()方法底层是如何实现的,于是便有了接下来的故事。

1. public static <T> void sort(T[] a, Comparator<? super T> c)

在Arrays工具类中,包含了许多sort的重载方法,如public static void sort(float[] a)public static void sort(int[] a)等等,可能有些同学就好奇了,这不是有泛型的方法吗,还要写重载这么多基础类型的干嘛呢?如果你也有这个念头,那么恭喜你,你比他们多获得了一个知识点:

  • 泛型的数据类型不能为基础类型

言归正传,我们回到该方法的源码:

public static <T> void sort(T[] a, Comparator<? super T> c) {
    if (c == null) {
        sort(a);
    } else {
        if (LegacyMergeSort.userRequested)
            legacyMergeSort(a, c);
        else
            TimSort.sort(a, 0, a.length, c, null, 0, 0);
    }
}

可以看到,如果传入的Comparator参数为空时,则会调用sort方法:

public static void sort(Object[] a) {
    // 这里是为了兼顾老版本的排序,需要用户在运行时设置参数,jdk表示会在之后废弃该该方法
    if (LegacyMergeSort.userRequested)
        // 传统归并排序
        legacyMergeSort(a);
    else
        ComparableTimSort.sort(a, 0, a.length, null, 0, 0);
}
2. private static void legacyMergeSort(Object[] a)

虽然这个方法目前来说用不到,但是我们本着认真负责的态度,还是探究一下源码。

legacyMergeSort顾名思义,传统归并排序

private static void legacyMergeSort(Object[] a) {
    Object[] aux = a.clone();
    mergeSort(aux, a, 0, a.length, 0);
}

可以看到,legacyMergeSort方法先是克隆了一份原数组,然后在对克隆的副本进行归并排序

private static void mergeSort(Object[] src,  // 克隆的副本
                              Object[] dest, // 目标排序数组
                              int low, // 0
                              int high, // a.length()
                              int off) {
    int length = high - low;

    // 插入排序: 当数组的长度小于 7 时,则直接使用插入排序并返回
    if (length < INSERTIONSORT_THRESHOLD) {
        for (int i=low; i<high; i++)
            for (int j=i; j>low &&
                 ((Comparable) dest[j-1]).compareTo(dest[j])>0; j--)
                swap(dest, j, j-1);
        return;
    }

    // 将 dest 的一半进行递归排序到 src
    int destLow  = low;
    int destHigh = high;
    low  += off; // 因为 off 传入的是0, 因此可以忽略
    high += off; // 同上
    int mid = (low + high) >>> 1;  // 无符号数右移,因为 low + high >= 0,因此等价于 (low + high)>> 1 等价于 (low + high)/ 2
    mergeSort(dest, src, low, mid, -off); // 归并排序的分治思想,分别对左、右两个子数组递归排序,直到满足长度小于7时,使用插入排序将子数组进行排序。
    mergeSort(dest, src, mid, high, -off);

    // 如果数组已经是有序的,则直接将src复制到dest即可
    if (((Comparable)src[mid-1]).compareTo(src[mid]) <= 0) {
        System.arraycopy(src, low, dest, destLow, length);
        return;
    }

    // 将已经排序好的一半合并到dest中
    // 逻辑很简单,就是将两个已经排序好的子数组合并,举个例子
    // 子数组src [2, 5, 3, 4], p = low = 0, q = mid = 2, high = 4
    // p = 0 < q = 2 , 并且 src[0] = 2 < src[2] = 3
    // 所以 dest[0] = src[p++] ==> dest[0] = src[p] = src[0] = 2, p = p + 1 = 1
    // p = 1 < q = 2 , 并且 src[1] = 5 > src[2] = 3
    // 所以 dest[1] = src[q++] ==> dest[1] = src[q] = src[2] = 3, q = q + 1 = 3
    // ......
    // dest = [2, 3, 4, 5]
    for(int i = destLow, p = low, q = mid; i < destHigh; i++) {
        if (q >= high || p < mid && ((Comparable)src[p]).compareTo(src[q])<=0)
            dest[i] = src[p++];
        else
            dest[i] = src[q++];
    }
}
3. static void sort(Object[] a, int lo, int hi, Object[] work, int workBase, int workLen)

分析完传统归并方法,接下来进入jdk默认的方法ComparableTimSort类下的sort,源码如下:

static void sort(Object[] a, // 目标排序数组
                 int lo, // 0   第一个需要被排序元素的索引
                 int hi, // a.length   最后一个需要被排序的元素索引
                 Object[] work, // null   工作数组
                 int workBase, // 0   工作数组可用空间
                 int workLen) { // 0   工作数组已用大小
  	// 断言:检查参数可用性,结果为false则抛出异常。
    assert a != null && lo >= 0 && lo <= hi && hi <= a.length;

    int nRemaining  = hi - lo;
    if (nRemaining < 2)
        return;  // 数组长度小于2 直接返回 无需排序

    // 如果数组长度小于 32 ,则使用 mini-TimSort
    if (nRemaining < MIN_MERGE) {
        // 见 3.1 源码分析
        int initRunLen = countRunAndMakeAscending(a, lo, hi);
        // 见 3.2 源码分析
        binarySort(a, lo, hi, lo + initRunLen);
        return;
    }
    /**
     * March over the array once, left to right, finding natural runs,
     * extending short natural runs to minRun elements, and merging runs
     * to maintain stack invariant.
     */
    ComparableTimSort ts = new ComparableTimSort(a, work, workBase, workLen);
    // 返回最小可运行长度 
    // a) 如果数组大小为2的N次幂,则返回16(MIN_MERGE / 2);
  // b) 其他情况下,逐位向右位移(即除以2),直到找到介于16和32间的一个数;
    int minRun = minRunLength(nRemaining);
    do {
        // 最大连续递增序列 见 3.1 源码分析
        int runLen = countRunAndMakeAscending(a, lo, hi);

        // If run is short, extend to min(minRun, nRemaining)
        // 如果 最大连续递增序列的长度 < 最小可运行长度,先进行二分排序 使得 runLen == minRun
        if (runLen < minRun) {
            int force = nRemaining <= minRun ? nRemaining : minRun;
            binarySort(a, lo, lo + force, lo + runLen);
            runLen = force;
        }

        // 将runLen 放入stack 满足条件时合并
        ts.pushRun(lo, runLen);
        // 见 3.3 源码分析
        ts.mergeCollapse();

        // Advance to find next run
        lo += runLen;
        nRemaining -= runLen;
    } while (nRemaining != 0);

    // 合并剩下的 run
    assert lo == hi;
    ts.mergeForceCollapse();
    assert ts.stackSize == 1;
}
3.1 private static int countRunAndMakeAscending(Object[] a, int lo, int hi)

countRunAndMakeAscending方法的作用是从数组开始处找到数组中最大连续递增或递减序列长度,如果为递减序列,还需要反转该序列。

private static int countRunAndMakeAscending(Object[] a, int lo, int hi) {
    assert lo < hi;
    int runHi = lo + 1;
    if (runHi == hi)
        return 1;

    // Find end of run, and reverse range if descending
    // tips: ((Comparable) a[runHi++]).compareTo(a[lo]) ==> ((Comparable) a[runHi]).compareTo(a[lo]); runHi++;
    // 逻辑很简单,比较数组a中第二个元素与第一个元素的大小,如果 2 > 1 则 为递增,反之 2 < 1 则为递减。
    if (((Comparable) a[runHi++]).compareTo(a[lo]) < 0) { // 递减
        // 循环找到最大递减序列
        while (runHi < hi && ((Comparable) a[runHi]).compareTo(a[runHi - 1]) < 0)
            runHi++;
        // 反转前 runHi 个元素,使其为递增序列
        // 反转逻辑也很简单,就是交换第一个元素a[low]和左后一个元素a[high],然后 low++, high--循环至low >= high
        reverseRange(a, lo, runHi);
    } else {                              // 递增
        while (runHi < hi && ((Comparable) a[runHi]).compareTo(a[runHi - 1]) >= 0)
            runHi++;
    }

    return runHi - lo;
}
3.2 private static void binarySort(Object[] a, int lo, int hi, int start)

通过二分查找的方式,循环将所有未排序的数组元素插入到有序的数组中。

private static void binarySort(Object[] a, int lo, int hi, int start) {
    assert lo <= start && start <= hi;
    if (start == lo)
        start++;
    // 循环将 未排序 的数组元素通过 二分查找 的方式找到在有序序列中对应位置插入 即二分排序。
    for ( ; start < hi; start++) {
        // 待排序元素
        Comparable pivot = (Comparable) a[start];

        // Set left (and right) to the index where a[start] (pivot) belongs
        int left = lo;
        int right = start;
        assert left <= right;
        /*
         * Invariants:
         *   pivot >= all in [lo, left).
         *   pivot <  all in [right, start).
         */
        // 二分查找(二分查找为基础算法,这里就不多赘述,有兴趣的小伙伴可以自己查阅资料)
        while (left < right) {
            int mid = (left + right) >>> 1;
            if (pivot.compareTo(a[mid]) < 0)
                right = mid;
            else
                left = mid + 1;
        }
        assert left == right;

        /*
         * The invariants still hold: pivot >= all in [lo, left) and
         * pivot < all in [left, start), so pivot belongs at left.  Note
         * that if there are elements equal to pivot, left points to the
         * first slot after them -- that's why this sort is stable.
         * Slide elements over to make room for pivot.
         */
        int n = start - left;  // 需要移动元素的个数
        // 这里switch的作用是对 System.arraycopy()方法的的优化,
        // 注意因为 case 2未添加break,因此如果 n == 2,会同时执行case 2 和case 1.
        switch (n) {
            case 2:  a[left + 2] = a[left + 1];
            case 1:  a[left + 1] = a[left];
                break;
            // 此方法为 native 方法,各参数代表的意思如下
            default: System.arraycopy(a, left, a, left + 1, n);
        }
        a[left] = pivot;
    }
}

public static native void arraycopy(Object src,  // 原数组
                                    int  srcPos, // 原数组的起始位置
                                    Object dest, // 目标数组
                                    int destPos, // 目标数组的起始位置
                                    int length); // 需要copy的数组长度
3.3 private void mergeCollapse()

本方法的作用也是优化合并,后两个分区的和大于前一个分区,则中间的分区与最小的分区先合并,否则合并后连个分区

private void mergeCollapse() {
    while (stackSize > 1) {
        int n = stackSize - 2;
        if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1] ||
            n > 1 && runLen[n-2] <= runLen[n] + runLen[n-1]) {
            if (runLen[n - 1] < runLen[n + 1])
                n--;
        } else if (n < 0 || runLen[n] > runLen[n + 1]) {
            break; // Invariant is established
        }
        mergeAt(n);
    }
}
小结

上面介绍了Arrays.sort()方法当 Comparator == null 时的源码,其实当通过观察源码我们会发现TimSort的代码与ComparableTimSort几乎一样,有兴趣的小伙伴可以自己去看看。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值