归并排序 -- 多线程版本

71 篇文章 1 订阅
该博客介绍了使用Java实现多线程版本的归并排序,通过划分数组、并发排序和合并操作提高效率。作者通过线程池、代理模式和并发工具类展示了不同排序策略,包括使用FutureTask、CountDownLatch和PriorityQueue进行结果合并。并通过测试验证了排序的正确性和性能提升,结果显示在某些场景下性能可提升4至8倍。
摘要由CSDN通过智能技术生成

近期学了Java多线程的一些关键字还有设计模式的代理模式,在这应用一下, 写一个多线程版本的归并排序。先写下版本1,以后在逐渐进行改进。

环境

  • jdk 8+
  • maven
  • cglib

思路

这里主要实现了三个类。

  1. 进行数组生成和打印以及其他数组操作的SortUtil类。
  2. 进行时间统计的代理类TimeProxy,使用cglib生成代理类。
  3. 进行主逻辑实现的SortThread类。

归并的核心思路

  1. 划分数组,分成两部分,直到划分到小于等于给定len。
  2. 对于每一个划分,开启两个线程
  3. 在两个线程执行完成之后, 进行合并。

代码

  • SortUtil类
package threadBase.threadPool;

import java.util.Random;

/**
 * @author: Zekun Fu
 * @date: 2022/5/21 11:07
 * @Description: 使用线程池进行归并排序
 *
 *
 * 设计:
 * 1. 开启16个线程进行排序。
 * 2. 进行划分
 * 3. 使用快速排序进行排序。
 * 4. 把16个结果放到堆里面,取出一个,然后维护堆和原本的堆就行了。
 *
 *
 * 会快多少呢?
 * 1. 首先16个线程并行运行,快16倍
 * 2. 其次,每一个线程计算的时间复杂度是O(n/16 * log(n/16))。
 * 3. 最后最进行合并的时候,需要O(nlog16)的时间复杂度。
 * 4. 运行时间基本上可以看作是线性的。
 *
 *
 *
 *
 */
public class SortUtils {
    private int[] tmp;

    void merge_sort(int[] arr, int l, int r) {
        if (l >= r) return ;
        int mid = (l + r) >> 1;
        merge_sort(arr, l, mid);            // [l, mid]进行排序
        merge_sort(arr, mid + 1, r);      // [mid + 1, r]进行排序

        int i = l, j = mid + 1, k = l;
        while(i <= mid && j <= r) {         // 选择较小的
            tmp[k++] = arr[i] <= arr[j] ? arr[i++] : arr[j++];
        }
        while(i <= mid) tmp[k++] = arr[i++];
        while(j <= r) tmp[k++] = arr[j++];

        for (int t = l; t <= r; t++) arr[t] = tmp[t];
    }

    public static void printArr(int[] a) {
        for (int x: a) System.out.print(x + " ");
        System.out.println();
    }
    public static void printArr(int[][]arr) {
        for (int[] tmp : arr) {
            printArr(tmp);
        }
    }

    public static int[] getArr(int n, int m) {
        Random random = new Random();
        int[] arr = new int[n];
        for (int i = 0; i < n; i++) arr[i] = random.nextInt(m);
        return arr;
    }

    /*'
    * 判断两个数组是否相等
    * */
    public static boolean check(int[] arr1, int[] arr2) {
        for (int i = 0; i < arr1.length; i++) {
            if (arr1[i] != arr2[i]) return false;
        }
        return true;
    }

    public static void main(String[] args) {
        int n = 10;
        SortUtils th = new SortUtils();
        th.tmp = new int[n];
        int[] arr = getArr(n, 100);
        printArr(arr);
        th.merge_sort(arr, 0, n - 1);
        printArr(arr);


    }

}

  • SortThread类
package threadBase.threadPool;

import javafx.util.Pair;
import net.sf.cglib.proxy.Enhancer;

import javax.xml.transform.Source;
import java.lang.reflect.Array;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * @author: Zekun Fu
 * @date: 2022/5/21 14:41
 * @Description: 分治排序的主线程
 *
 * 对于自动排序: 需要指定线程数量和每个线程处理数组的大小
 * 对于自动排序的2:需要执行线程数量和.. 以及分配一个和原数组一样大的tmp数组。
 *
 */
public class SortThread {


    private ExecutorService pool;           // 线程池

    public SortThread() {
        pool = Executors.newCachedThreadPool();        // 因为8核心就只有16个线程
    }

    private int[][] div(int[] arr, int d) {  // 数组划分成多少份
        int [][] ans = new int[d][];
        int n = arr.length;
        int len = n / d;
        for (int i = 0; i < d; i++) {
            int tmp[];
            if(i != d - 1) tmp = new int[len];
            else tmp = new int[n - (d - 1) * len];
            for (int j = i * len; j < (i + 1) * len; j++) {
                tmp[j - (i * len)] = arr[j];
            }
            ans[i] = tmp;
        }
        for (int i = d * len, j = len; i < n; i++, j++) {
            ans[d - 1][j] = arr[i];
        }
        return ans;
    }


    /*
    *   使用分治进行,每次分成两个线程进行排序,然后吧排好序的进行合并。
    * 这里进行数组的划分,进行分支。
    * */
    private int[] sortByAuto(int[] arr, int l, int r, int len) throws Exception{
        if (r - l + 1 <= len) {
            int []tmp = Arrays.copyOfRange(arr, l, r + 1);
            Arrays.sort(tmp);
            return tmp;
        }
        int mid = (l + r) >> 1;

        FutureTask<int[]> task1 = new FutureTask<int[]>(()->{
            int[] ans = sortByAuto(arr, l, mid, len);
            return ans;
        });
        FutureTask<int[]>task2 = new FutureTask<int[]>(()->{
            int[] ans = sortByAuto(arr, mid + 1, r, len);
            return ans;
        });
        new Thread(task1).start();
        new Thread(task2).start();

        // 由于是阻塞式的,所以再返回之前一定是已经完成了这两组的排序了
        int[] tmpl = task1.get();
        int[] tmpr = task2.get();
        int[] ans = new int[r - l + 1];
        int i = 0, j = 0, k = 0;
        while(i < tmpl.length && j < tmpr.length) {
            ans[k++] = tmpl[i] <= tmpr[j] ? tmpl[i++] : tmpr[j++];
        }
        while(i < tmpl.length) {
            ans[k++] = tmpl[i++];
        }
        while(j < tmpr.length) {
            ans[k++] = tmpr[j++];
        }
        return ans;
    }
    /*
     *
     * 不进行复制粘贴,直接使用原数组进行排序
     * */
    private void sortByAuto2(int[] arr, int l, int r, int len, int[] tmp) {
        if (r - l + 1 <= len) {
            Arrays.sort(arr, l, r + 1);
            return;
        }
        int mid = (l + r) >> 1;
        CountDownLatch latch = new CountDownLatch(2);
        new Thread(()->{
            sortByAuto2(arr, l, mid, len, tmp);
            latch.countDown();
        }).start();
        new Thread(()->{
            sortByAuto2(arr, mid + 1, r, len, tmp);
            latch.countDown();
        }).start();

        try{
            latch.await();
        } catch (Exception e) {
            e.printStackTrace();
        }
        int i = l, j = l, k = mid + 1;
        while(j <= mid && k <= r) {
            tmp[i++] = arr[j] <= arr[k] ? arr[j++] : arr[k++];
        }
        while(j <= mid) {
            tmp[i++] = arr[j++];
        }
        while(k <= r) {
            tmp[i++] = arr[k++];
        }
        for (i = l; i <= r; i++) arr[i] = tmp[i];
    }

        /*
    *
    *   直接采用分治的方式进行
    * 是一个错误的方法不知道怎么实现同步
    * */
    @Deprecated                 // 错误方法,别用了
    private int[] sortByAuto3(int[] arr, int l, int r, int len) {


        // 1. 递归基
        int[] tmp = new int[r - l + 1];
        for (int i = l, j = 0; i <= r; i++, j++) {
            tmp[j] = arr[i];
        }
        if (r - l + 1 <= len) {
            new Thread(()->{
                Arrays.sort(tmp);
            }).start();
            return tmp;
        }
        // 2.进行左右划分
        int mid = (l + r) >> 1;
        int[] tmpl = sortByAuto3(arr, l, mid, len);
        int[] tmpr = sortByAuto3(arr, mid + 1, r, len);

        // 3. 进行合并, 必须等待所有的子线程排序完成。
        int[] merge = new int[r - l + 1];
        int i = l, j = 0, k = 0;
        while(j < tmpl.length && k < tmpr.length) {
            merge[i - l] = tmpl[j] <= tmpr[k] ? tmpl[j++] : tmpr[k++];
            i++;
        }
        while(j < tmpl.length) {
            merge[i - l] = tmpl[j++];
            i++;
        }
        while(k < tmpr.length) {
            merge[i - l] = tmpr[k++];
            i++;
        }
        return merge;
    }



    /*
    *
    *   使用线程池进行数组的排序
    * */
    private void sortByPool(int[] arr, int l, int r, int len, int[] tmp) {
        if (r - l <= len) {
            Arrays.sort(arr, l, r + 1);
            return ;
        }
        // 1. 进行划分
        int mid = (l + r) >> 1;
        Future<?> f1 = pool.submit(() -> sortByPool(arr, l, mid, len, tmp));
        Future<?> f2 = pool.submit(()-> sortByPool(arr, mid + 1, r, len, tmp));
        // 2. 等待两个线程执行完成
        try {
            f1.get();
            f2.get();
        } catch (Exception e) {
            e.printStackTrace();
        }
        // 3. 进行合并即可。
        int i = l, j = l, k = mid + 1;
        while(j <= mid && k <= r) tmp[i++] = arr[j] <= arr[k] ? arr[j++] : arr[k++];
        while(j <= mid) tmp[i++] = arr[j++];
        while(k <= r) tmp[i++] = arr[k++];
        for (i = l; i <= r; i++) arr[i] = tmp[i];
    }
    /*
    *
    *   使用手动线程进行数组的排序
    * */
    private int[] collect(int[][]sortedArr, int size) {

        class Node implements Comparable<Node>{
            int num;
            int ground;             // 第几组
            int id;                 // 第几个

            public Node(int num, int group, int id) {
                this.num = num;
                this.ground = group;
                this.id = id;
            }

            @Override
            public int compareTo(Node o) {
                return Integer.compare(this.num, o.num);
            }
        }

        int[] ans = new int[size];
        Queue<Node> heap = new PriorityQueue<>();
        int d = sortedArr.length;
        // 1.建立初始堆
        for (int i = 0; i < d; i++) {
            Node p = new Node(sortedArr[i][0], i, 0);
            heap.add(p);
        }
        // 2.取顶,放入答案,重新建堆。
        int idx = 0;
        while(!heap.isEmpty()) {
            Node top = heap.poll();
            ans[idx ++] = top.num;
            int g = top.ground, id = top.id;
            if (id  + 1 < sortedArr[g].length) {     // 看是否为空
                heap.add(new Node(sortedArr[g][id + 1], g, id + 1));
            }
        }
        return ans;
    }
    public int[] threadSortByFutT(int[] arr) {
        int n = 8;      // 线程数量, 小于等于数组的长度。
        int[][] divArr = div(arr, n);


        AtomicInteger id = new AtomicInteger();
        CountDownLatch latch = new CountDownLatch(n);
        for (int i = 0; i < n; i++) {
            new Thread(()->{
                Arrays.sort(divArr[id.getAndIncrement()]);  // 给数组进行排序
                latch.countDown();
            }).start();
        }
        // 这里必须等所有数组排好序才能继续执行
        try {
            latch.await();
        }catch (Exception e) {
            e.printStackTrace();
        }
//        System.out.println("拍好序之后的数组");
//        SortUtils.printArr(divArr);

        int[] ans = collect(divArr, arr.length);            // 得到一个新的数组
        return ans;
    }

    public void sortBySingle(int[] arr) {
        Arrays.sort(arr);
    }
    /*
    *
    * 为了方便代理的使用,
    * 这里写一个方法调用被代理对象的方法,
    * 然后执行的时间就是被代理对象的执行时间
    * */
    public int[] sortByAuto(SortThread s, int[] arr, int tn) {
        int n = arr.length;
        int len = n / tn;           // 每个线程执需要排序的长度
        int[] ans = null;
        try {
            ans = s.sortByAuto(arr, 0, n - 1, len); // 调用s的,防止递归调用打印。
        } catch (Exception e) {
            e.printStackTrace();
        }finally {
            return ans;
        }
    }

    public void sortByAuto2(SortThread s, int[] arr, int tn) {
        int n = arr.length;
        int len = n / tn;
        int[] tmp =  new int[n];
        s.sortByAuto2(arr, 0, n - 1, len, tmp);     // 调用s的
    }

    public void sortByPool(SortThread s, int[] arr, int tn) {
        int n = arr.length;
        int len = n / tn;
        int[] tmp = new int[n];
        s.sortByPool(arr, 0, n - 1, len, tmp); // 调用s的
        s.pool.shutdown();
    }


    // 进行辅助函数的测试
    public static void test() {
        int n = 16, m = 100;
        int[] arr = SortUtils.getArr(16, 100);
        int[] sortedArr = Arrays.copyOf(arr, n);
        Arrays.sort(sortedArr);
        SortThread s = new SortThread();
        SortThread st = (SortThread) new TimeProxy().getProxy(SortThread.class);
        System.out.println("原数组为:");
        SortUtils.printArr(arr);
        System.out.println("排好序的数组为:");
        SortUtils.printArr(sortedArr);

        // 测试划分是否正确
        System.out.println("划分之后的数组为:");
        int[][] divArr = s.div(arr, 3);
        SortUtils.printArr(divArr);

        // 进行sortByauto的正确性测试
        int tn = 8;
        int[] ans1 = st.sortByAuto(s, arr, tn);
        SortUtils.printArr(ans1);
        System.out.println("sortByAuto是否正确:" + SortUtils.check(sortedArr, ans1));


        // 进行sortByAuto2的正确性测试
        tn = 8;
        int[] ans2 = Arrays.copyOf(arr, n);
        st.sortByAuto2(s, ans2, tn);
        SortUtils.printArr(ans2);
        System.out.println("sortByAuto2是否正确:" + SortUtils.check(sortedArr, ans2));

        // 进行线程池排序的测试
        tn = 8;
        int[] ans3 = Arrays.copyOf(arr, n);
        st.sortByPool(s, ans3, tn);
        SortUtils.printArr(ans3);
        System.out.println("线程池排序是否正确:" + SortUtils.check(sortedArr, ans3));

    }
    public static void runMethod() {
        int n = (int)1e8, m = (int)1e9;
        int[] arr = SortUtils.getArr(n, m);       // 生成数组
        SortThread s = new SortThread();
        SortThread st = (SortThread) new TimeProxy().getProxy(SortThread.class);

        int[] sortArr = Arrays.copyOf(arr, n);
        st.sortBySingle(sortArr);                    // 进行一次标准排序,判断正误 + 统计时间

        // 测试a手动划分,然后使用堆进行合并的运行时间,这里使用了代理进行时间统计。
        int[] ans1 = st.threadSortByFutT(arr);     // 这里不对arr1打乱顺序
        System.out.println("threadSortByFutT答案是否正确:" + SortUtils.check(sortArr, ans1));

        //  测试方法1的正确性和运行时间
        int tn = 8;                      // 开启线程的个数
        int[] ans = st.sortByAuto(s, arr, tn);   // 不对arr改变顺序
        System.out.println("sortByAuto是否正确:" + SortUtils.check(ans, sortArr));

        // 测试方法2的正确性和运行的时间
        int[] ans2 = Arrays.copyOf(arr, n);       // 不改变原来的数组
        tn = 8;
        st.sortByAuto2(s, ans2, tn);
        System.out.println("sortByAuto2是否正确:" + SortUtils.check(ans2, sortArr));

        // 测试方法3的正确定和运行时间
        int[] ans3 = Arrays.copyOf(arr, n);
        tn = 8;
        st.sortByPool(s, ans3, tn);
        System.out.println("线程池的最大大小为;" + ((ThreadPoolExecutor)s.pool).getLargestPoolSize());
        System.out.println("sortByPool是否正确:" + SortUtils.check(ans3, sortArr));
    }
    public static void main(String[] args) {
//        test();
        runMethod();
    }
}

  • SortThread类
package threadBase.threadPool;

import javafx.util.Pair;
import net.sf.cglib.proxy.Enhancer;

import javax.xml.transform.Source;
import java.lang.reflect.Array;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * @author: Zekun Fu
 * @date: 2022/5/21 14:41
 * @Description: 分治排序的主线程
 *
 * 对于自动排序: 需要指定线程数量和每个线程处理数组的大小
 * 对于自动排序的2:需要执行线程数量和.. 以及分配一个和原数组一样大的tmp数组。
 *
 */
public class SortThread {


    private ExecutorService pool;           // 线程池

    public SortThread() {
        pool = Executors.newCachedThreadPool();        // 因为8核心就只有16个线程
    }

    private int[][] div(int[] arr, int d) {  // 数组划分成多少份
        int [][] ans = new int[d][];
        int n = arr.length;
        int len = n / d;
        for (int i = 0; i < d; i++) {
            int tmp[];
            if(i != d - 1) tmp = new int[len];
            else tmp = new int[n - (d - 1) * len];
            for (int j = i * len; j < (i + 1) * len; j++) {
                tmp[j - (i * len)] = arr[j];
            }
            ans[i] = tmp;
        }
        for (int i = d * len, j = len; i < n; i++, j++) {
            ans[d - 1][j] = arr[i];
        }
        return ans;
    }


    /*
    *   使用分治进行,每次分成两个线程进行排序,然后吧排好序的进行合并。
    * 这里进行数组的划分,进行分支。
    * */
    private int[] sortByAuto(int[] arr, int l, int r, int len) throws Exception{
        if (r - l + 1 <= len) {
            int []tmp = Arrays.copyOfRange(arr, l, r + 1);
            Arrays.sort(tmp);
            return tmp;
        }
        int mid = (l + r) >> 1;

        FutureTask<int[]> task1 = new FutureTask<int[]>(()->{
            int[] ans = sortByAuto(arr, l, mid, len);
            return ans;
        });
        FutureTask<int[]>task2 = new FutureTask<int[]>(()->{
            int[] ans = sortByAuto(arr, mid + 1, r, len);
            return ans;
        });
        new Thread(task1).start();
        new Thread(task2).start();

        // 由于是阻塞式的,所以再返回之前一定是已经完成了这两组的排序了
        int[] tmpl = task1.get();
        int[] tmpr = task2.get();
        int[] ans = new int[r - l + 1];
        int i = 0, j = 0, k = 0;
        while(i < tmpl.length && j < tmpr.length) {
            ans[k++] = tmpl[i] <= tmpr[j] ? tmpl[i++] : tmpr[j++];
        }
        while(i < tmpl.length) {
            ans[k++] = tmpl[i++];
        }
        while(j < tmpr.length) {
            ans[k++] = tmpr[j++];
        }
        return ans;
    }
    /*
     *
     * 不进行复制粘贴,直接使用原数组进行排序
     * */
    private void sortByAuto2(int[] arr, int l, int r, int len, int[] tmp) {
        if (r - l + 1 <= len) {
            Arrays.sort(arr, l, r + 1);
            return;
        }
        int mid = (l + r) >> 1;
        CountDownLatch latch = new CountDownLatch(2);
        new Thread(()->{
            sortByAuto2(arr, l, mid, len, tmp);
            latch.countDown();
        }).start();
        new Thread(()->{
            sortByAuto2(arr, mid + 1, r, len, tmp);
            latch.countDown();
        }).start();

        try{
            latch.await();
        } catch (Exception e) {
            e.printStackTrace();
        }
        int i = l, j = l, k = mid + 1;
        while(j <= mid && k <= r) {
            tmp[i++] = arr[j] <= arr[k] ? arr[j++] : arr[k++];
        }
        while(j <= mid) {
            tmp[i++] = arr[j++];
        }
        while(k <= r) {
            tmp[i++] = arr[k++];
        }
        for (i = l; i <= r; i++) arr[i] = tmp[i];
    }

        /*
    *
    *   直接采用分治的方式进行
    * 是一个错误的方法不知道怎么实现同步
    * */
    @Deprecated                 // 错误方法,别用了
    private int[] sortByAuto3(int[] arr, int l, int r, int len) {


        // 1. 递归基
        int[] tmp = new int[r - l + 1];
        for (int i = l, j = 0; i <= r; i++, j++) {
            tmp[j] = arr[i];
        }
        if (r - l + 1 <= len) {
            new Thread(()->{
                Arrays.sort(tmp);
            }).start();
            return tmp;
        }
        // 2.进行左右划分
        int mid = (l + r) >> 1;
        int[] tmpl = sortByAuto3(arr, l, mid, len);
        int[] tmpr = sortByAuto3(arr, mid + 1, r, len);

        // 3. 进行合并, 必须等待所有的子线程排序完成。
        int[] merge = new int[r - l + 1];
        int i = l, j = 0, k = 0;
        while(j < tmpl.length && k < tmpr.length) {
            merge[i - l] = tmpl[j] <= tmpr[k] ? tmpl[j++] : tmpr[k++];
            i++;
        }
        while(j < tmpl.length) {
            merge[i - l] = tmpl[j++];
            i++;
        }
        while(k < tmpr.length) {
            merge[i - l] = tmpr[k++];
            i++;
        }
        return merge;
    }



    /*
    *
    *   使用线程池进行数组的排序
    * */
    private void sortByPool(int[] arr, int l, int r, int len, int[] tmp) {
        if (r - l <= len) {
            Arrays.sort(arr, l, r + 1);
            return ;
        }
        // 1. 进行划分
        int mid = (l + r) >> 1;
        Future<?> f1 = pool.submit(() -> sortByPool(arr, l, mid, len, tmp));
        Future<?> f2 = pool.submit(()-> sortByPool(arr, mid + 1, r, len, tmp));
        // 2. 等待两个线程执行完成
        try {
            f1.get();
            f2.get();
        } catch (Exception e) {
            e.printStackTrace();
        }
        // 3. 进行合并即可。
        int i = l, j = l, k = mid + 1;
        while(j <= mid && k <= r) tmp[i++] = arr[j] <= arr[k] ? arr[j++] : arr[k++];
        while(j <= mid) tmp[i++] = arr[j++];
        while(k <= r) tmp[i++] = arr[k++];
        for (i = l; i <= r; i++) arr[i] = tmp[i];
    }
    /*
    *
    *   使用手动线程进行数组的排序
    * */
    private int[] collect(int[][]sortedArr, int size) {

        class Node implements Comparable<Node>{
            int num;
            int ground;             // 第几组
            int id;                 // 第几个

            public Node(int num, int group, int id) {
                this.num = num;
                this.ground = group;
                this.id = id;
            }

            @Override
            public int compareTo(Node o) {
                return Integer.compare(this.num, o.num);
            }
        }

        int[] ans = new int[size];
        Queue<Node> heap = new PriorityQueue<>();
        int d = sortedArr.length;
        // 1.建立初始堆
        for (int i = 0; i < d; i++) {
            Node p = new Node(sortedArr[i][0], i, 0);
            heap.add(p);
        }
        // 2.取顶,放入答案,重新建堆。
        int idx = 0;
        while(!heap.isEmpty()) {
            Node top = heap.poll();
            ans[idx ++] = top.num;
            int g = top.ground, id = top.id;
            if (id  + 1 < sortedArr[g].length) {     // 看是否为空
                heap.add(new Node(sortedArr[g][id + 1], g, id + 1));
            }
        }
        return ans;
    }
    public int[] threadSortByFutT(int[] arr) {
        int n = 8;      // 线程数量, 小于等于数组的长度。
        int[][] divArr = div(arr, n);


        AtomicInteger id = new AtomicInteger();
        CountDownLatch latch = new CountDownLatch(n);
        for (int i = 0; i < n; i++) {
            new Thread(()->{
                Arrays.sort(divArr[id.getAndIncrement()]);  // 给数组进行排序
                latch.countDown();
            }).start();
        }
        // 这里必须等所有数组排好序才能继续执行
        try {
            latch.await();
        }catch (Exception e) {
            e.printStackTrace();
        }
//        System.out.println("拍好序之后的数组");
//        SortUtils.printArr(divArr);

        int[] ans = collect(divArr, arr.length);            // 得到一个新的数组
        return ans;
    }

    public void sortBySingle(int[] arr) {
        Arrays.sort(arr);
    }
    /*
    *
    * 为了方便代理的使用,
    * 这里写一个方法调用被代理对象的方法,
    * 然后执行的时间就是被代理对象的执行时间
    * */
    public int[] sortByAuto(SortThread s, int[] arr, int tn) {
        int n = arr.length;
        int len = n / tn;           // 每个线程执需要排序的长度
        int[] ans = null;
        try {
            ans = s.sortByAuto(arr, 0, n - 1, len); // 调用s的,防止递归调用打印。
        } catch (Exception e) {
            e.printStackTrace();
        }finally {
            return ans;
        }
    }

    public void sortByAuto2(SortThread s, int[] arr, int tn) {
        int n = arr.length;
        int len = n / tn;
        int[] tmp =  new int[n];
        s.sortByAuto2(arr, 0, n - 1, len, tmp);     // 调用s的
    }

    public void sortByPool(SortThread s, int[] arr, int tn) {
        int n = arr.length;
        int len = n / tn;
        int[] tmp = new int[n];
        s.sortByPool(arr, 0, n - 1, len, tmp); // 调用s的
        s.pool.shutdown();
    }


    // 进行辅助函数的测试
    public static void test() {
        int n = 16, m = 100;
        int[] arr = SortUtils.getArr(16, 100);
        int[] sortedArr = Arrays.copyOf(arr, n);
        Arrays.sort(sortedArr);
        SortThread s = new SortThread();
        SortThread st = (SortThread) new TimeProxy().getProxy(SortThread.class);
        System.out.println("原数组为:");
        SortUtils.printArr(arr);
        System.out.println("排好序的数组为:");
        SortUtils.printArr(sortedArr);

        // 测试划分是否正确
        System.out.println("划分之后的数组为:");
        int[][] divArr = s.div(arr, 3);
        SortUtils.printArr(divArr);

        // 进行sortByauto的正确性测试
        int tn = 8;
        int[] ans1 = st.sortByAuto(s, arr, tn);
        SortUtils.printArr(ans1);
        System.out.println("sortByAuto是否正确:" + SortUtils.check(sortedArr, ans1));


        // 进行sortByAuto2的正确性测试
        tn = 8;
        int[] ans2 = Arrays.copyOf(arr, n);
        st.sortByAuto2(s, ans2, tn);
        SortUtils.printArr(ans2);
        System.out.println("sortByAuto2是否正确:" + SortUtils.check(sortedArr, ans2));

        // 进行线程池排序的测试
        tn = 8;
        int[] ans3 = Arrays.copyOf(arr, n);
        st.sortByPool(s, ans3, tn);
        SortUtils.printArr(ans3);
        System.out.println("线程池排序是否正确:" + SortUtils.check(sortedArr, ans3));

    }
    public static void runMethod() {
        int n = (int)1e8, m = (int)1e9;
        int[] arr = SortUtils.getArr(n, m);       // 生成数组
        SortThread s = new SortThread();
        SortThread st = (SortThread) new TimeProxy().getProxy(SortThread.class);

        int[] sortArr = Arrays.copyOf(arr, n);
        st.sortBySingle(sortArr);                    // 进行一次标准排序,判断正误 + 统计时间

        // 测试a手动划分,然后使用堆进行合并的运行时间,这里使用了代理进行时间统计。
        int[] ans1 = st.threadSortByFutT(arr);     // 这里不对arr1打乱顺序
        System.out.println("threadSortByFutT答案是否正确:" + SortUtils.check(sortArr, ans1));

        //  测试方法1的正确性和运行时间
        int tn = 8;                      // 开启线程的个数
        int[] ans = st.sortByAuto(s, arr, tn);   // 不对arr改变顺序
        System.out.println("sortByAuto是否正确:" + SortUtils.check(ans, sortArr));

        // 测试方法2的正确性和运行的时间
        int[] ans2 = Arrays.copyOf(arr, n);       // 不改变原来的数组
        tn = 8;
        st.sortByAuto2(s, ans2, tn);
        System.out.println("sortByAuto2是否正确:" + SortUtils.check(ans2, sortArr));

        // 测试方法3的正确定和运行时间
        int[] ans3 = Arrays.copyOf(arr, n);
        tn = 8;
        st.sortByPool(s, ans3, tn);
        System.out.println("线程池的最大大小为;" + ((ThreadPoolExecutor)s.pool).getLargestPoolSize());
        System.out.println("sortByPool是否正确:" + SortUtils.check(ans3, sortArr));
    }
    public static void main(String[] args) {
//        test();
        runMethod();
    }
}

  • TimeProxy代理类
package threadBase.threadPool;

import net.sf.cglib.proxy.Enhancer;
import net.sf.cglib.proxy.MethodInterceptor;
import net.sf.cglib.proxy.MethodProxy;

import java.io.ObjectInput;
import java.lang.reflect.Method;

/**
 * @author: Zekun Fu
 * @date: 2022/5/21 15:39
 * @Description: 为了测试运行时间,写一个拦截器
 */
public class TimeProxy implements MethodInterceptor {
    /*
    *
    *
    *   使用cglib代理计算使用线程与不适用线程排序所用的时间
    * 无法把final方法重写
    * */
    private Enhancer enhancer = new Enhancer();

    public Object getProxy(Class<?>clz) {
        enhancer.setSuperclass(clz);
        enhancer.setCallback(this);
        return enhancer.create();               // 创建这个代理
    }

    public Object intercept(Object o, Method method, Object[] objects, MethodProxy methodProxy) throws Throwable {
        long startTime = System.currentTimeMillis();
        Object ans = methodProxy.invokeSuper(o, objects);   // 调用o的方法
        long endTime = System.currentTimeMillis();
        // 指定为某个方法生成代理怎么指定呢?
        System.out.println("--------" + method.getName() + "的运行时间为:" + (endTime - startTime) + "ms" + "-------------");
        return ans;
    }
}

结果

  • 可以看到运行时间提升了4倍左右。
  • 使用新电脑能提升8倍左右。
  • 最大的应该是8和16倍。说明实际的数组的操作占用了大量的时间。

在这里插入图片描述

总结

  • cglib只能代理非finalpublic方法
  • 这里为了方便统计递归版本,使用了方法中放入SortThread类的引用。
  • 由于只是对这一个类进行代理,所以可以把invokeSuper改成SortThread的方法。也就把SortThrea类的引用放在代理类中。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值