近期学了Java多线程的一些关键字还有设计模式的代理模式,在这应用一下, 写一个多线程版本的归并排序。先写下版本1,以后在逐渐进行改进。
- 代码地址
- gitte
环境
- jdk 8+
- maven
- cglib
思路
这里主要实现了三个类。
- 进行数组生成和打印以及其他数组操作的SortUtil类。
- 进行时间统计的代理类TimeProxy,使用cglib生成代理类。
- 进行主逻辑实现的SortThread类。
归并的核心思路
- 划分数组,分成两部分,直到划分到小于等于给定len。
- 对于每一个划分,开启两个线程
- 在两个线程执行完成之后, 进行合并。
代码
- SortUtil类
package threadBase.threadPool;
import java.util.Random;
/**
* @author: Zekun Fu
* @date: 2022/5/21 11:07
* @Description: 使用线程池进行归并排序
*
*
* 设计:
* 1. 开启16个线程进行排序。
* 2. 进行划分
* 3. 使用快速排序进行排序。
* 4. 把16个结果放到堆里面,取出一个,然后维护堆和原本的堆就行了。
*
*
* 会快多少呢?
* 1. 首先16个线程并行运行,快16倍
* 2. 其次,每一个线程计算的时间复杂度是O(n/16 * log(n/16))。
* 3. 最后最进行合并的时候,需要O(nlog16)的时间复杂度。
* 4. 运行时间基本上可以看作是线性的。
*
*
*
*
*/
public class SortUtils {
private int[] tmp;
void merge_sort(int[] arr, int l, int r) {
if (l >= r) return ;
int mid = (l + r) >> 1;
merge_sort(arr, l, mid); // [l, mid]进行排序
merge_sort(arr, mid + 1, r); // [mid + 1, r]进行排序
int i = l, j = mid + 1, k = l;
while(i <= mid && j <= r) { // 选择较小的
tmp[k++] = arr[i] <= arr[j] ? arr[i++] : arr[j++];
}
while(i <= mid) tmp[k++] = arr[i++];
while(j <= r) tmp[k++] = arr[j++];
for (int t = l; t <= r; t++) arr[t] = tmp[t];
}
public static void printArr(int[] a) {
for (int x: a) System.out.print(x + " ");
System.out.println();
}
public static void printArr(int[][]arr) {
for (int[] tmp : arr) {
printArr(tmp);
}
}
public static int[] getArr(int n, int m) {
Random random = new Random();
int[] arr = new int[n];
for (int i = 0; i < n; i++) arr[i] = random.nextInt(m);
return arr;
}
/*'
* 判断两个数组是否相等
* */
public static boolean check(int[] arr1, int[] arr2) {
for (int i = 0; i < arr1.length; i++) {
if (arr1[i] != arr2[i]) return false;
}
return true;
}
public static void main(String[] args) {
int n = 10;
SortUtils th = new SortUtils();
th.tmp = new int[n];
int[] arr = getArr(n, 100);
printArr(arr);
th.merge_sort(arr, 0, n - 1);
printArr(arr);
}
}
- SortThread类
package threadBase.threadPool;
import javafx.util.Pair;
import net.sf.cglib.proxy.Enhancer;
import javax.xml.transform.Source;
import java.lang.reflect.Array;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
/**
* @author: Zekun Fu
* @date: 2022/5/21 14:41
* @Description: 分治排序的主线程
*
* 对于自动排序: 需要指定线程数量和每个线程处理数组的大小
* 对于自动排序的2:需要执行线程数量和.. 以及分配一个和原数组一样大的tmp数组。
*
*/
public class SortThread {
private ExecutorService pool; // 线程池
public SortThread() {
pool = Executors.newCachedThreadPool(); // 因为8核心就只有16个线程
}
private int[][] div(int[] arr, int d) { // 数组划分成多少份
int [][] ans = new int[d][];
int n = arr.length;
int len = n / d;
for (int i = 0; i < d; i++) {
int tmp[];
if(i != d - 1) tmp = new int[len];
else tmp = new int[n - (d - 1) * len];
for (int j = i * len; j < (i + 1) * len; j++) {
tmp[j - (i * len)] = arr[j];
}
ans[i] = tmp;
}
for (int i = d * len, j = len; i < n; i++, j++) {
ans[d - 1][j] = arr[i];
}
return ans;
}
/*
* 使用分治进行,每次分成两个线程进行排序,然后吧排好序的进行合并。
* 这里进行数组的划分,进行分支。
* */
private int[] sortByAuto(int[] arr, int l, int r, int len) throws Exception{
if (r - l + 1 <= len) {
int []tmp = Arrays.copyOfRange(arr, l, r + 1);
Arrays.sort(tmp);
return tmp;
}
int mid = (l + r) >> 1;
FutureTask<int[]> task1 = new FutureTask<int[]>(()->{
int[] ans = sortByAuto(arr, l, mid, len);
return ans;
});
FutureTask<int[]>task2 = new FutureTask<int[]>(()->{
int[] ans = sortByAuto(arr, mid + 1, r, len);
return ans;
});
new Thread(task1).start();
new Thread(task2).start();
// 由于是阻塞式的,所以再返回之前一定是已经完成了这两组的排序了
int[] tmpl = task1.get();
int[] tmpr = task2.get();
int[] ans = new int[r - l + 1];
int i = 0, j = 0, k = 0;
while(i < tmpl.length && j < tmpr.length) {
ans[k++] = tmpl[i] <= tmpr[j] ? tmpl[i++] : tmpr[j++];
}
while(i < tmpl.length) {
ans[k++] = tmpl[i++];
}
while(j < tmpr.length) {
ans[k++] = tmpr[j++];
}
return ans;
}
/*
*
* 不进行复制粘贴,直接使用原数组进行排序
* */
private void sortByAuto2(int[] arr, int l, int r, int len, int[] tmp) {
if (r - l + 1 <= len) {
Arrays.sort(arr, l, r + 1);
return;
}
int mid = (l + r) >> 1;
CountDownLatch latch = new CountDownLatch(2);
new Thread(()->{
sortByAuto2(arr, l, mid, len, tmp);
latch.countDown();
}).start();
new Thread(()->{
sortByAuto2(arr, mid + 1, r, len, tmp);
latch.countDown();
}).start();
try{
latch.await();
} catch (Exception e) {
e.printStackTrace();
}
int i = l, j = l, k = mid + 1;
while(j <= mid && k <= r) {
tmp[i++] = arr[j] <= arr[k] ? arr[j++] : arr[k++];
}
while(j <= mid) {
tmp[i++] = arr[j++];
}
while(k <= r) {
tmp[i++] = arr[k++];
}
for (i = l; i <= r; i++) arr[i] = tmp[i];
}
/*
*
* 直接采用分治的方式进行
* 是一个错误的方法不知道怎么实现同步
* */
@Deprecated // 错误方法,别用了
private int[] sortByAuto3(int[] arr, int l, int r, int len) {
// 1. 递归基
int[] tmp = new int[r - l + 1];
for (int i = l, j = 0; i <= r; i++, j++) {
tmp[j] = arr[i];
}
if (r - l + 1 <= len) {
new Thread(()->{
Arrays.sort(tmp);
}).start();
return tmp;
}
// 2.进行左右划分
int mid = (l + r) >> 1;
int[] tmpl = sortByAuto3(arr, l, mid, len);
int[] tmpr = sortByAuto3(arr, mid + 1, r, len);
// 3. 进行合并, 必须等待所有的子线程排序完成。
int[] merge = new int[r - l + 1];
int i = l, j = 0, k = 0;
while(j < tmpl.length && k < tmpr.length) {
merge[i - l] = tmpl[j] <= tmpr[k] ? tmpl[j++] : tmpr[k++];
i++;
}
while(j < tmpl.length) {
merge[i - l] = tmpl[j++];
i++;
}
while(k < tmpr.length) {
merge[i - l] = tmpr[k++];
i++;
}
return merge;
}
/*
*
* 使用线程池进行数组的排序
* */
private void sortByPool(int[] arr, int l, int r, int len, int[] tmp) {
if (r - l <= len) {
Arrays.sort(arr, l, r + 1);
return ;
}
// 1. 进行划分
int mid = (l + r) >> 1;
Future<?> f1 = pool.submit(() -> sortByPool(arr, l, mid, len, tmp));
Future<?> f2 = pool.submit(()-> sortByPool(arr, mid + 1, r, len, tmp));
// 2. 等待两个线程执行完成
try {
f1.get();
f2.get();
} catch (Exception e) {
e.printStackTrace();
}
// 3. 进行合并即可。
int i = l, j = l, k = mid + 1;
while(j <= mid && k <= r) tmp[i++] = arr[j] <= arr[k] ? arr[j++] : arr[k++];
while(j <= mid) tmp[i++] = arr[j++];
while(k <= r) tmp[i++] = arr[k++];
for (i = l; i <= r; i++) arr[i] = tmp[i];
}
/*
*
* 使用手动线程进行数组的排序
* */
private int[] collect(int[][]sortedArr, int size) {
class Node implements Comparable<Node>{
int num;
int ground; // 第几组
int id; // 第几个
public Node(int num, int group, int id) {
this.num = num;
this.ground = group;
this.id = id;
}
@Override
public int compareTo(Node o) {
return Integer.compare(this.num, o.num);
}
}
int[] ans = new int[size];
Queue<Node> heap = new PriorityQueue<>();
int d = sortedArr.length;
// 1.建立初始堆
for (int i = 0; i < d; i++) {
Node p = new Node(sortedArr[i][0], i, 0);
heap.add(p);
}
// 2.取顶,放入答案,重新建堆。
int idx = 0;
while(!heap.isEmpty()) {
Node top = heap.poll();
ans[idx ++] = top.num;
int g = top.ground, id = top.id;
if (id + 1 < sortedArr[g].length) { // 看是否为空
heap.add(new Node(sortedArr[g][id + 1], g, id + 1));
}
}
return ans;
}
public int[] threadSortByFutT(int[] arr) {
int n = 8; // 线程数量, 小于等于数组的长度。
int[][] divArr = div(arr, n);
AtomicInteger id = new AtomicInteger();
CountDownLatch latch = new CountDownLatch(n);
for (int i = 0; i < n; i++) {
new Thread(()->{
Arrays.sort(divArr[id.getAndIncrement()]); // 给数组进行排序
latch.countDown();
}).start();
}
// 这里必须等所有数组排好序才能继续执行
try {
latch.await();
}catch (Exception e) {
e.printStackTrace();
}
// System.out.println("拍好序之后的数组");
// SortUtils.printArr(divArr);
int[] ans = collect(divArr, arr.length); // 得到一个新的数组
return ans;
}
public void sortBySingle(int[] arr) {
Arrays.sort(arr);
}
/*
*
* 为了方便代理的使用,
* 这里写一个方法调用被代理对象的方法,
* 然后执行的时间就是被代理对象的执行时间
* */
public int[] sortByAuto(SortThread s, int[] arr, int tn) {
int n = arr.length;
int len = n / tn; // 每个线程执需要排序的长度
int[] ans = null;
try {
ans = s.sortByAuto(arr, 0, n - 1, len); // 调用s的,防止递归调用打印。
} catch (Exception e) {
e.printStackTrace();
}finally {
return ans;
}
}
public void sortByAuto2(SortThread s, int[] arr, int tn) {
int n = arr.length;
int len = n / tn;
int[] tmp = new int[n];
s.sortByAuto2(arr, 0, n - 1, len, tmp); // 调用s的
}
public void sortByPool(SortThread s, int[] arr, int tn) {
int n = arr.length;
int len = n / tn;
int[] tmp = new int[n];
s.sortByPool(arr, 0, n - 1, len, tmp); // 调用s的
s.pool.shutdown();
}
// 进行辅助函数的测试
public static void test() {
int n = 16, m = 100;
int[] arr = SortUtils.getArr(16, 100);
int[] sortedArr = Arrays.copyOf(arr, n);
Arrays.sort(sortedArr);
SortThread s = new SortThread();
SortThread st = (SortThread) new TimeProxy().getProxy(SortThread.class);
System.out.println("原数组为:");
SortUtils.printArr(arr);
System.out.println("排好序的数组为:");
SortUtils.printArr(sortedArr);
// 测试划分是否正确
System.out.println("划分之后的数组为:");
int[][] divArr = s.div(arr, 3);
SortUtils.printArr(divArr);
// 进行sortByauto的正确性测试
int tn = 8;
int[] ans1 = st.sortByAuto(s, arr, tn);
SortUtils.printArr(ans1);
System.out.println("sortByAuto是否正确:" + SortUtils.check(sortedArr, ans1));
// 进行sortByAuto2的正确性测试
tn = 8;
int[] ans2 = Arrays.copyOf(arr, n);
st.sortByAuto2(s, ans2, tn);
SortUtils.printArr(ans2);
System.out.println("sortByAuto2是否正确:" + SortUtils.check(sortedArr, ans2));
// 进行线程池排序的测试
tn = 8;
int[] ans3 = Arrays.copyOf(arr, n);
st.sortByPool(s, ans3, tn);
SortUtils.printArr(ans3);
System.out.println("线程池排序是否正确:" + SortUtils.check(sortedArr, ans3));
}
public static void runMethod() {
int n = (int)1e8, m = (int)1e9;
int[] arr = SortUtils.getArr(n, m); // 生成数组
SortThread s = new SortThread();
SortThread st = (SortThread) new TimeProxy().getProxy(SortThread.class);
int[] sortArr = Arrays.copyOf(arr, n);
st.sortBySingle(sortArr); // 进行一次标准排序,判断正误 + 统计时间
// 测试a手动划分,然后使用堆进行合并的运行时间,这里使用了代理进行时间统计。
int[] ans1 = st.threadSortByFutT(arr); // 这里不对arr1打乱顺序
System.out.println("threadSortByFutT答案是否正确:" + SortUtils.check(sortArr, ans1));
// 测试方法1的正确性和运行时间
int tn = 8; // 开启线程的个数
int[] ans = st.sortByAuto(s, arr, tn); // 不对arr改变顺序
System.out.println("sortByAuto是否正确:" + SortUtils.check(ans, sortArr));
// 测试方法2的正确性和运行的时间
int[] ans2 = Arrays.copyOf(arr, n); // 不改变原来的数组
tn = 8;
st.sortByAuto2(s, ans2, tn);
System.out.println("sortByAuto2是否正确:" + SortUtils.check(ans2, sortArr));
// 测试方法3的正确定和运行时间
int[] ans3 = Arrays.copyOf(arr, n);
tn = 8;
st.sortByPool(s, ans3, tn);
System.out.println("线程池的最大大小为;" + ((ThreadPoolExecutor)s.pool).getLargestPoolSize());
System.out.println("sortByPool是否正确:" + SortUtils.check(ans3, sortArr));
}
public static void main(String[] args) {
// test();
runMethod();
}
}
- SortThread类
package threadBase.threadPool;
import javafx.util.Pair;
import net.sf.cglib.proxy.Enhancer;
import javax.xml.transform.Source;
import java.lang.reflect.Array;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
/**
* @author: Zekun Fu
* @date: 2022/5/21 14:41
* @Description: 分治排序的主线程
*
* 对于自动排序: 需要指定线程数量和每个线程处理数组的大小
* 对于自动排序的2:需要执行线程数量和.. 以及分配一个和原数组一样大的tmp数组。
*
*/
public class SortThread {
private ExecutorService pool; // 线程池
public SortThread() {
pool = Executors.newCachedThreadPool(); // 因为8核心就只有16个线程
}
private int[][] div(int[] arr, int d) { // 数组划分成多少份
int [][] ans = new int[d][];
int n = arr.length;
int len = n / d;
for (int i = 0; i < d; i++) {
int tmp[];
if(i != d - 1) tmp = new int[len];
else tmp = new int[n - (d - 1) * len];
for (int j = i * len; j < (i + 1) * len; j++) {
tmp[j - (i * len)] = arr[j];
}
ans[i] = tmp;
}
for (int i = d * len, j = len; i < n; i++, j++) {
ans[d - 1][j] = arr[i];
}
return ans;
}
/*
* 使用分治进行,每次分成两个线程进行排序,然后吧排好序的进行合并。
* 这里进行数组的划分,进行分支。
* */
private int[] sortByAuto(int[] arr, int l, int r, int len) throws Exception{
if (r - l + 1 <= len) {
int []tmp = Arrays.copyOfRange(arr, l, r + 1);
Arrays.sort(tmp);
return tmp;
}
int mid = (l + r) >> 1;
FutureTask<int[]> task1 = new FutureTask<int[]>(()->{
int[] ans = sortByAuto(arr, l, mid, len);
return ans;
});
FutureTask<int[]>task2 = new FutureTask<int[]>(()->{
int[] ans = sortByAuto(arr, mid + 1, r, len);
return ans;
});
new Thread(task1).start();
new Thread(task2).start();
// 由于是阻塞式的,所以再返回之前一定是已经完成了这两组的排序了
int[] tmpl = task1.get();
int[] tmpr = task2.get();
int[] ans = new int[r - l + 1];
int i = 0, j = 0, k = 0;
while(i < tmpl.length && j < tmpr.length) {
ans[k++] = tmpl[i] <= tmpr[j] ? tmpl[i++] : tmpr[j++];
}
while(i < tmpl.length) {
ans[k++] = tmpl[i++];
}
while(j < tmpr.length) {
ans[k++] = tmpr[j++];
}
return ans;
}
/*
*
* 不进行复制粘贴,直接使用原数组进行排序
* */
private void sortByAuto2(int[] arr, int l, int r, int len, int[] tmp) {
if (r - l + 1 <= len) {
Arrays.sort(arr, l, r + 1);
return;
}
int mid = (l + r) >> 1;
CountDownLatch latch = new CountDownLatch(2);
new Thread(()->{
sortByAuto2(arr, l, mid, len, tmp);
latch.countDown();
}).start();
new Thread(()->{
sortByAuto2(arr, mid + 1, r, len, tmp);
latch.countDown();
}).start();
try{
latch.await();
} catch (Exception e) {
e.printStackTrace();
}
int i = l, j = l, k = mid + 1;
while(j <= mid && k <= r) {
tmp[i++] = arr[j] <= arr[k] ? arr[j++] : arr[k++];
}
while(j <= mid) {
tmp[i++] = arr[j++];
}
while(k <= r) {
tmp[i++] = arr[k++];
}
for (i = l; i <= r; i++) arr[i] = tmp[i];
}
/*
*
* 直接采用分治的方式进行
* 是一个错误的方法不知道怎么实现同步
* */
@Deprecated // 错误方法,别用了
private int[] sortByAuto3(int[] arr, int l, int r, int len) {
// 1. 递归基
int[] tmp = new int[r - l + 1];
for (int i = l, j = 0; i <= r; i++, j++) {
tmp[j] = arr[i];
}
if (r - l + 1 <= len) {
new Thread(()->{
Arrays.sort(tmp);
}).start();
return tmp;
}
// 2.进行左右划分
int mid = (l + r) >> 1;
int[] tmpl = sortByAuto3(arr, l, mid, len);
int[] tmpr = sortByAuto3(arr, mid + 1, r, len);
// 3. 进行合并, 必须等待所有的子线程排序完成。
int[] merge = new int[r - l + 1];
int i = l, j = 0, k = 0;
while(j < tmpl.length && k < tmpr.length) {
merge[i - l] = tmpl[j] <= tmpr[k] ? tmpl[j++] : tmpr[k++];
i++;
}
while(j < tmpl.length) {
merge[i - l] = tmpl[j++];
i++;
}
while(k < tmpr.length) {
merge[i - l] = tmpr[k++];
i++;
}
return merge;
}
/*
*
* 使用线程池进行数组的排序
* */
private void sortByPool(int[] arr, int l, int r, int len, int[] tmp) {
if (r - l <= len) {
Arrays.sort(arr, l, r + 1);
return ;
}
// 1. 进行划分
int mid = (l + r) >> 1;
Future<?> f1 = pool.submit(() -> sortByPool(arr, l, mid, len, tmp));
Future<?> f2 = pool.submit(()-> sortByPool(arr, mid + 1, r, len, tmp));
// 2. 等待两个线程执行完成
try {
f1.get();
f2.get();
} catch (Exception e) {
e.printStackTrace();
}
// 3. 进行合并即可。
int i = l, j = l, k = mid + 1;
while(j <= mid && k <= r) tmp[i++] = arr[j] <= arr[k] ? arr[j++] : arr[k++];
while(j <= mid) tmp[i++] = arr[j++];
while(k <= r) tmp[i++] = arr[k++];
for (i = l; i <= r; i++) arr[i] = tmp[i];
}
/*
*
* 使用手动线程进行数组的排序
* */
private int[] collect(int[][]sortedArr, int size) {
class Node implements Comparable<Node>{
int num;
int ground; // 第几组
int id; // 第几个
public Node(int num, int group, int id) {
this.num = num;
this.ground = group;
this.id = id;
}
@Override
public int compareTo(Node o) {
return Integer.compare(this.num, o.num);
}
}
int[] ans = new int[size];
Queue<Node> heap = new PriorityQueue<>();
int d = sortedArr.length;
// 1.建立初始堆
for (int i = 0; i < d; i++) {
Node p = new Node(sortedArr[i][0], i, 0);
heap.add(p);
}
// 2.取顶,放入答案,重新建堆。
int idx = 0;
while(!heap.isEmpty()) {
Node top = heap.poll();
ans[idx ++] = top.num;
int g = top.ground, id = top.id;
if (id + 1 < sortedArr[g].length) { // 看是否为空
heap.add(new Node(sortedArr[g][id + 1], g, id + 1));
}
}
return ans;
}
public int[] threadSortByFutT(int[] arr) {
int n = 8; // 线程数量, 小于等于数组的长度。
int[][] divArr = div(arr, n);
AtomicInteger id = new AtomicInteger();
CountDownLatch latch = new CountDownLatch(n);
for (int i = 0; i < n; i++) {
new Thread(()->{
Arrays.sort(divArr[id.getAndIncrement()]); // 给数组进行排序
latch.countDown();
}).start();
}
// 这里必须等所有数组排好序才能继续执行
try {
latch.await();
}catch (Exception e) {
e.printStackTrace();
}
// System.out.println("拍好序之后的数组");
// SortUtils.printArr(divArr);
int[] ans = collect(divArr, arr.length); // 得到一个新的数组
return ans;
}
public void sortBySingle(int[] arr) {
Arrays.sort(arr);
}
/*
*
* 为了方便代理的使用,
* 这里写一个方法调用被代理对象的方法,
* 然后执行的时间就是被代理对象的执行时间
* */
public int[] sortByAuto(SortThread s, int[] arr, int tn) {
int n = arr.length;
int len = n / tn; // 每个线程执需要排序的长度
int[] ans = null;
try {
ans = s.sortByAuto(arr, 0, n - 1, len); // 调用s的,防止递归调用打印。
} catch (Exception e) {
e.printStackTrace();
}finally {
return ans;
}
}
public void sortByAuto2(SortThread s, int[] arr, int tn) {
int n = arr.length;
int len = n / tn;
int[] tmp = new int[n];
s.sortByAuto2(arr, 0, n - 1, len, tmp); // 调用s的
}
public void sortByPool(SortThread s, int[] arr, int tn) {
int n = arr.length;
int len = n / tn;
int[] tmp = new int[n];
s.sortByPool(arr, 0, n - 1, len, tmp); // 调用s的
s.pool.shutdown();
}
// 进行辅助函数的测试
public static void test() {
int n = 16, m = 100;
int[] arr = SortUtils.getArr(16, 100);
int[] sortedArr = Arrays.copyOf(arr, n);
Arrays.sort(sortedArr);
SortThread s = new SortThread();
SortThread st = (SortThread) new TimeProxy().getProxy(SortThread.class);
System.out.println("原数组为:");
SortUtils.printArr(arr);
System.out.println("排好序的数组为:");
SortUtils.printArr(sortedArr);
// 测试划分是否正确
System.out.println("划分之后的数组为:");
int[][] divArr = s.div(arr, 3);
SortUtils.printArr(divArr);
// 进行sortByauto的正确性测试
int tn = 8;
int[] ans1 = st.sortByAuto(s, arr, tn);
SortUtils.printArr(ans1);
System.out.println("sortByAuto是否正确:" + SortUtils.check(sortedArr, ans1));
// 进行sortByAuto2的正确性测试
tn = 8;
int[] ans2 = Arrays.copyOf(arr, n);
st.sortByAuto2(s, ans2, tn);
SortUtils.printArr(ans2);
System.out.println("sortByAuto2是否正确:" + SortUtils.check(sortedArr, ans2));
// 进行线程池排序的测试
tn = 8;
int[] ans3 = Arrays.copyOf(arr, n);
st.sortByPool(s, ans3, tn);
SortUtils.printArr(ans3);
System.out.println("线程池排序是否正确:" + SortUtils.check(sortedArr, ans3));
}
public static void runMethod() {
int n = (int)1e8, m = (int)1e9;
int[] arr = SortUtils.getArr(n, m); // 生成数组
SortThread s = new SortThread();
SortThread st = (SortThread) new TimeProxy().getProxy(SortThread.class);
int[] sortArr = Arrays.copyOf(arr, n);
st.sortBySingle(sortArr); // 进行一次标准排序,判断正误 + 统计时间
// 测试a手动划分,然后使用堆进行合并的运行时间,这里使用了代理进行时间统计。
int[] ans1 = st.threadSortByFutT(arr); // 这里不对arr1打乱顺序
System.out.println("threadSortByFutT答案是否正确:" + SortUtils.check(sortArr, ans1));
// 测试方法1的正确性和运行时间
int tn = 8; // 开启线程的个数
int[] ans = st.sortByAuto(s, arr, tn); // 不对arr改变顺序
System.out.println("sortByAuto是否正确:" + SortUtils.check(ans, sortArr));
// 测试方法2的正确性和运行的时间
int[] ans2 = Arrays.copyOf(arr, n); // 不改变原来的数组
tn = 8;
st.sortByAuto2(s, ans2, tn);
System.out.println("sortByAuto2是否正确:" + SortUtils.check(ans2, sortArr));
// 测试方法3的正确定和运行时间
int[] ans3 = Arrays.copyOf(arr, n);
tn = 8;
st.sortByPool(s, ans3, tn);
System.out.println("线程池的最大大小为;" + ((ThreadPoolExecutor)s.pool).getLargestPoolSize());
System.out.println("sortByPool是否正确:" + SortUtils.check(ans3, sortArr));
}
public static void main(String[] args) {
// test();
runMethod();
}
}
- TimeProxy代理类
package threadBase.threadPool;
import net.sf.cglib.proxy.Enhancer;
import net.sf.cglib.proxy.MethodInterceptor;
import net.sf.cglib.proxy.MethodProxy;
import java.io.ObjectInput;
import java.lang.reflect.Method;
/**
* @author: Zekun Fu
* @date: 2022/5/21 15:39
* @Description: 为了测试运行时间,写一个拦截器
*/
public class TimeProxy implements MethodInterceptor {
/*
*
*
* 使用cglib代理计算使用线程与不适用线程排序所用的时间
* 无法把final方法重写
* */
private Enhancer enhancer = new Enhancer();
public Object getProxy(Class<?>clz) {
enhancer.setSuperclass(clz);
enhancer.setCallback(this);
return enhancer.create(); // 创建这个代理
}
public Object intercept(Object o, Method method, Object[] objects, MethodProxy methodProxy) throws Throwable {
long startTime = System.currentTimeMillis();
Object ans = methodProxy.invokeSuper(o, objects); // 调用o的方法
long endTime = System.currentTimeMillis();
// 指定为某个方法生成代理怎么指定呢?
System.out.println("--------" + method.getName() + "的运行时间为:" + (endTime - startTime) + "ms" + "-------------");
return ans;
}
}
结果
- 可以看到运行时间提升了4倍左右。
- 使用新电脑能提升8倍左右。
- 最大的应该是8和16倍。说明实际的数组的操作占用了大量的时间。
总结
- cglib只能代理非final的public方法
- 这里为了方便统计递归版本,使用了方法中放入SortThread类的引用。
- 由于只是对这一个类进行代理,所以可以把invokeSuper改成SortThread的方法。也就把SortThrea类的引用放在代理类中。