线程池ForkJoinPool工作原理分析
算法题:如何充分利用多核CPU的性能,快速对一个2千万大小的数组进行排序?
基于归并排序算法实现
归并排序(Merge Sort)是一种基于分治思想的排序算法
归并排序的时间复杂度为O(nlogn),空间复杂度为O(n),其中n为数组的长度
分治思想是将一个规模为N的问题分解为K个规模较小的子问题,这些子问题相互独立且与原问题性质相同。求出子问题的解,就可得到原问题的解。步骤为:分解 → 求解 → 合并。
单线程实现归并排序
Fork/Join并行归并排序
并行实现归并排序的优化和注意事项
在实际应用中,我们需要考虑数据分布的均匀性、内存使用情况、线程切换开销等因素,以充分利用多核CPU并保证算法的正确性和效率。
- 任务的大小:任务大小的选择会影响并行算法的效率和负载均衡,如果任务太小,会造成任务划分和合并的开销过大;如果任务太大,会导致任务无法充分利用多核CPU并行处理能力。
- 负载均衡:并行算法需要保证负载均衡,即各个线程执行的任务大小和时间应该尽可能相等,否则会导致某些线程负载过重,而其他线程负载过轻的情况。
- 数据分布:数据分布的均匀性也会影响并行算法的效率和负载均衡。
- 内存使用:并行算法需要考虑内存的使用情况,特别是在处理大规模数据时,内存的使用情况会对算法的执行效率产生重要影响。
- 线程切换:线程切换是并行算法的一个重要开销,需要尽量减少线程的切换次数,以提高算法的执行效率。
Fork/Join框架介绍
Fork/Join是一个是一个并行计算的框架,主要就是用来支持分治任务模型的,这个计算框架里的 Fork 对应的是分治任务模型里的任务分解,Join 对应的是结果合并。它的核心思想是将一个大任务分成许多小任务,然后并行执行这些小任务,最终将它们的结果合并成一个大的结果。
应用场景
- 用于递归分解型的任务,例如排序、归并、遍历等。
- 用于数组的处理,例如数组的排序、查找、统计等。
- 用于并行化算法的实现,例如并行化的图像处理算法、并行化的机器学习算法等。
- 用于大数据处理,例如大型日志文件的处理、大型数据库的查询等。
Fork/Join使用
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;
public class ForkJoinPoolDemo {
public static void main(String[] args) {
var d = new ForkJoinPoolDemo();
//============
// d.fibonacciDemo();
//======归并排序演示======
d.mergeDemo();
var i = 0;
for (; ; ) {
i++;
if (i > 50000) {
break;
}
}
System.out.println("======= main end =======");
}
void fibonacciDemo() {
ForkJoinPool pool = new ForkJoinPool();
FibonacciDemo task = new FibonacciDemo(10);
//提交任务并一直阻塞直到任务 执行完成返回合并结果。
int result = pool.invoke(task);
System.out.println(result);
}
private class FibonacciDemo extends RecursiveTask<Integer> {
final int n;
public FibonacciDemo(int i) {
n = i;
}
@Override
protected Integer compute() {
if (n <= 1)
return n;
FibonacciDemo f1 = new FibonacciDemo(n - 1);
//提交任务
f1.fork();
FibonacciDemo f2 = new FibonacciDemo(n - 2);
//合并结果
return f2.compute() + f1.join();
}
}
void mergeDemo() {
//生成测试数组 用于归并排序
int[] arrayToSortByMergeSort1 = Utils.buildRandomIntArray(2000000);
int[] arrayToSortByMergeSort2 = Arrays.copyOf(arrayToSortByMergeSort1, arrayToSortByMergeSort1.length);
int[] arrayToSortByMergeSort3 = Arrays.copyOf(arrayToSortByMergeSort1, arrayToSortByMergeSort1.length);
//获取处理器数量
int processors = Runtime.getRuntime().availableProcessors();
System.out.println(processors);
MergeSortTask mergeSortTask = new MergeSortTask(arrayToSortByMergeSort3, processors * 1000);
//构建 ForkJoinPool
ForkJoinPool forkJoinPool = new ForkJoinPool(processors);
// System.out.println("排序前数组");
// System.out.println(Arrays.toString(arrayToSortByMergeSort));
// System.out.println("排序后数组");
// System.out.println(Arrays.toString(sortedArr));
runSort("单线程归并排序时间: ", () -> MergeSort.sequentialSort(arrayToSortByMergeSort1, 10000));
runSort("快排排序时间: ", () -> Arrays.sort(arrayToSortByMergeSort2));
runSort("多线程归并排序时间: ", () -> forkJoinPool.invoke(mergeSortTask));
}
void runSort(String name, Runnable task) {
long startTime = System.currentTimeMillis();
task.run();
long duration = System.currentTimeMillis() - startTime;
System.out.println(name + (duration) + "毫秒");
}
private class MergeSortTask extends RecursiveAction {
private final int threshold; //拆分的阈值,低于此阈值就不再进行拆分
private int[] arrayToSort; //要排序的数组
public MergeSortTask(final int[] arrayToSort, final int threshold) {
this.arrayToSort = arrayToSort;
this.threshold = threshold;
}
@Override
protected void compute() {
//拆分后的数组长度小于阈值,直接进行排序
if (arrayToSort.length <= threshold) {
// 调用jdk提供的排序方法
Arrays.sort(arrayToSort);
return;
}
// 对数组进行拆分
int midpoint = arrayToSort.length / 2;
int[] leftArray = Arrays.copyOfRange(arrayToSort, 0, midpoint);
int[] rightArray = Arrays.copyOfRange(arrayToSort, midpoint, arrayToSort.length);
MergeSortTask leftTask = new MergeSortTask(leftArray, threshold);
MergeSortTask rightTask = new MergeSortTask(rightArray, threshold);
//调用任务
invokeAll(leftTask, rightTask);
//合并排序结果
arrayToSort = MergeSort.merge(leftTask.getSortedArray(), rightTask.getSortedArray());
}
public int[] getSortedArray() {
return arrayToSort;
}
}
public class MergeSort {
private final int[] arrayToSort; //要排序的数组
private final int threshold; //拆分的阈值,低于此阈值就不再进行拆分
public MergeSort(final int[] arrayToSort, final int threshold) {
this.arrayToSort = arrayToSort;
this.threshold = threshold;
}
/**
* 排序
*
* @return
*/
public int[] sequentialSort() {
return sequentialSort(arrayToSort, threshold);
}
public static int[] sequentialSort(final int[] arrayToSort, int threshold) {
//拆分后的数组长度小于阈值,直接进行排序
if (arrayToSort.length < threshold) {
//调用jdk提供的排序方法
Arrays.sort(arrayToSort);
return arrayToSort;
}
int midpoint = arrayToSort.length / 2;
//对数组进行拆分
int[] leftArray = Arrays.copyOfRange(arrayToSort, 0, midpoint);
int[] rightArray = Arrays.copyOfRange(arrayToSort, midpoint, arrayToSort.length);
//递归调用
leftArray = sequentialSort(leftArray, threshold);
rightArray = sequentialSort(rightArray, threshold);
//合并排序结果
return merge(leftArray, rightArray);
}
public static int[] merge(final int[] leftArray, final int[] rightArray) {
//定义用于合并结果的数组
int[] mergedArray = new int[leftArray.length + rightArray.length];
int mergedArrayPos = 0;
int leftArrayPos = 0;
int rightArrayPos = 0;
while (leftArrayPos < leftArray.length && rightArrayPos < rightArray.length) {
if (leftArray[leftArrayPos] <= rightArray[rightArrayPos]) {
mergedArray[mergedArrayPos] = leftArray[leftArrayPos];
leftArrayPos++;
} else {
mergedArray[mergedArrayPos] = rightArray[rightArrayPos];
rightArrayPos++;
}
mergedArrayPos++;
}
while (leftArrayPos < leftArray.length) {
mergedArray[mergedArrayPos] = leftArray[leftArrayPos];
leftArrayPos++;
mergedArrayPos++;
}
while (rightArrayPos < rightArray.length) {
mergedArray[mergedArrayPos] = rightArray[rightArrayPos];
rightArrayPos++;
mergedArrayPos++;
}
return mergedArray;
}
}
}
总结
Fork/Join是一种基于分治思想的模型,在并发处理计算型任务时有着显著的优势。
- 任务切分:将大的任务分割成更小粒度的小任务,让更多的线程参与执行;
- 任务窃取:通过任务窃取,充分地利用空闲线程,并减少竞争。