Fork/Join是一个是一个并行计算的框架,主要用来支持分治任务模型
Fork 对应的是分治任务模型里的任务分解,Join 对应的是结果合并。
核心思想是将一个大任务分成许多小任务,然后并行执行这些小任务,最终将它们的结果合并成一个大的结果。
应用场景
- 递归分解型任务
- 数组处理
- 并行化算法
- 大数据处理
ForkJoinPool
ForkJoinPool是一个线程池,它用于管理ForkJoin任务的执行
主要方法:
- submit()提交任务
- invoke()执行任务
- shutdown()关闭线程池
- awaitTermination()等待任务的执行结果。
ForkJoinPool类中还包括一些参数,例如线程池的大小、工作线程的优先级、任务队列的容量等,可以根据具体的应用场景进行设置。
ForkJoinTask
ForkJoinTask是一个抽象类,用于表示可以被分割成更小部分的任务
使用
单线程归并排序
import java.util.Arrays;
public class MergeSort {
private final int[] arrayToSort; //要排序的数组
private final int threshold; //拆分的阈值,低于此阈值就不再进行拆分
public MergeSort(final int[] arrayToSort, final int threshold) {
this.arrayToSort = arrayToSort;
this.threshold = threshold;
}
/**
* 排序
* @return
*/
public int[] sequentialSort() {
return sequentialSort(arrayToSort, threshold);
}
public static int[] sequentialSort(final int[] arrayToSort, int threshold) {
//拆分后的数组长度小于阈值,直接进行排序
if (arrayToSort.length < threshold) {
//调用jdk提供的排序方法
Arrays.sort(arrayToSort);
return arrayToSort;
}
int midpoint = arrayToSort.length / 2;
//对数组进行拆分
int[] leftArray = Arrays.copyOfRange(arrayToSort, 0, midpoint);
int[] rightArray = Arrays.copyOfRange(arrayToSort, midpoint, arrayToSort.length);
//递归调用
leftArray = sequentialSort(leftArray, threshold);
rightArray = sequentialSort(rightArray, threshold);
//合并排序结果
return merge(leftArray, rightArray);
}
public static int[] merge(final int[] leftArray, final int[] rightArray) {
//定义用于合并结果的数组
int[] mergedArray = new int[leftArray.length + rightArray.length];
int mergedArrayPos = 0;
int leftArrayPos = 0;
int rightArrayPos = 0;
while (leftArrayPos < leftArray.length && rightArrayPos < rightArray.length) {
if (leftArray[leftArrayPos] <= rightArray[rightArrayPos]) {
mergedArray[mergedArrayPos] = leftArray[leftArrayPos];
leftArrayPos++;
} else {
mergedArray[mergedArrayPos] = rightArray[rightArrayPos];
rightArrayPos++;
}
mergedArrayPos++;
}
while (leftArrayPos < leftArray.length) {
mergedArray[mergedArrayPos] = leftArray[leftArrayPos];
leftArrayPos++;
mergedArrayPos++;
}
while (rightArrayPos < rightArray.length) {
mergedArray[mergedArrayPos] = rightArray[rightArrayPos];
rightArrayPos++;
mergedArrayPos++;
}
return mergedArray;
}
}
Fork/Join框架归并排序
import java.util.Arrays;
import java.util.concurrent.RecursiveAction;
public class MergeSortTask extends RecursiveAction {
private final int threshold; //拆分的阈值,低于此阈值就不再进行拆分
private int[] arrayToSort; //要排序的数组
public MergeSortTask(final int[] arrayToSort, final int threshold) {
this.arrayToSort = arrayToSort;
this.threshold = threshold;
}
@Override
protected void compute() {
//拆分后的数组长度小于阈值,直接进行排序
if (arrayToSort.length <= threshold) {
// 调用jdk提供的排序方法
Arrays.sort(arrayToSort);
return;
}
// 对数组进行拆分
int midpoint = arrayToSort.length / 2;
int[] leftArray = Arrays.copyOfRange(arrayToSort, 0, midpoint);
int[] rightArray = Arrays.copyOfRange(arrayToSort, midpoint, arrayToSort.length);
MergeSortTask leftTask = new MergeSortTask(leftArray, threshold);
MergeSortTask rightTask = new MergeSortTask(rightArray, threshold);
//调用任务
invokeAll(leftTask,rightTask);
// 合并排序结果
arrayToSort = MergeSort.merge(leftTask.getSortedArray(), rightTask.getSortedArray());
}
public int[] getSortedArray() {
return arrayToSort;
}
}
测试
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
public class ArrayToSortMain {
public static void main(String[] args) {
//生成测试数组 用于归并排序
int[] arrayToSortByMergeSort = buildRandomIntArray(20000000);
//生成测试数组 用于forkjoin排序
int[] arrayToSortByForkJoin = Arrays.copyOf(arrayToSortByMergeSort, arrayToSortByMergeSort.length);
//获取处理器数量
int processors = Runtime.getRuntime().availableProcessors();
MergeSort mergeSort = new MergeSort(arrayToSortByMergeSort, processors);
long startTime = System.nanoTime();
// 归并排序
mergeSort.sequentialSort();
long duration = System.nanoTime()-startTime;
System.out.println("单线程归并排序时间: "+(duration/(1000f*1000f))+"毫秒");
//利用forkjoin排序
MergeSortTask mergeSortTask = new MergeSortTask(arrayToSortByForkJoin, processors);
//构建forkjoin线程池
ForkJoinPool forkJoinPool = new ForkJoinPool(processors);
startTime = System.nanoTime();
//执行排序任务
forkJoinPool.invoke(mergeSortTask);
duration = System.nanoTime()-startTime;
System.out.println("forkjoin排序时间: "+(duration/(1000f*1000f))+"毫秒");
}
/**
* 随机生成数组
* @param size 数组的大小
* @return
*/
public static int[] buildRandomIntArray(final int size) {
int[] arrayToCalculateSumOf = new int[size];
Random generator = new Random();
for (int i = 0; i < arrayToCalculateSumOf.length; i++) {
arrayToCalculateSumOf[i] = generator.nextInt(100000000);
}
return arrayToCalculateSumOf;
}
}
结果如下:
原理
待补充...