/**
 * Fork/join task that merge-sorts an {@code int} array.
 *
 * <p>Arrays whose length is at most {@code threshold} (and every array with
 * fewer than two elements) are sorted directly with {@link Arrays#sort(int[])}.
 * Longer arrays are copied into two halves, each half is sorted by a child
 * task, and the two sorted halves are merged into a new array.
 *
 * <p>The result must be read through {@link #getSortedArray()} after the task
 * has completed: when the input is split, the array passed to the constructor
 * is left untouched and the sorted data lives in a newly allocated array.
 */
public class MergeSortAction extends RecursiveAction {

    /** Arrays of this length or shorter are sorted directly instead of being split. */
    private final int threshold;

    /** The input array; replaced by the merged result once a split task has finished. */
    private int[] arrayToSort;

    /**
     * Creates a sorting task.
     *
     * @param arrayToSort the array to sort; sorted in place only when it is not split
     * @param threshold   maximum length that is sorted directly without splitting
     */
    public MergeSortAction(final int[] arrayToSort, final int threshold) {
        this.arrayToSort = arrayToSort;
        this.threshold = threshold;
    }

    /**
     * Sorts the array directly when it is small enough, otherwise splits it,
     * sorts both halves through child tasks and merges the results.
     */
    @Override
    protected void compute() {
        // The "length < 2" guard guarantees termination: without it a one-element
        // array with threshold <= 0 (or an empty array with a negative threshold)
        // would keep producing a half of the same length and recurse forever.
        if (arrayToSort.length < 2 || arrayToSort.length <= threshold) {
            Arrays.sort(arrayToSort);
            return;
        }

        final int midpoint = arrayToSort.length / 2;
        final MergeSortAction left =
                new MergeSortAction(Arrays.copyOfRange(arrayToSort, 0, midpoint), threshold);
        final MergeSortAction right =
                new MergeSortAction(
                        Arrays.copyOfRange(arrayToSort, midpoint, arrayToSort.length), threshold);

        // Run both halves and continue only after both have completed.
        invokeAll(left, right);

        arrayToSort = MergeSortMain.merge(left.getSortedArray(), right.getSortedArray());
    }

    /**
     * Returns the array currently held by this task.
     *
     * @return the sorted array once {@link #compute()} has finished; before that,
     *         the array that was passed to the constructor
     */
    public int[] getSortedArray() {
        return arrayToSort;
    }
}
/**
 * Sequential merge sort plus a small driver that compares it with the
 * fork/join based {@code MergeSortAction}.
 */
public class MergeSortMain {

    /** Array handed to {@link #sequentialSort()}. */
    private final int[] arrayToSort;

    /** Arrays shorter than this are sorted directly instead of being split. */
    private final int threshold;

    /**
     * Creates a sorter bound to one array and one threshold.
     *
     * @param arrayToSort the array to sort
     * @param threshold   arrays shorter than this are sorted directly
     */
    public MergeSortMain(final int[] arrayToSort, final int threshold) {
        this.arrayToSort = arrayToSort;
        this.threshold = threshold;
    }

    /**
     * Sorts the array given to the constructor using the configured threshold.
     *
     * @return the sorted array, see {@link #sequentialSort(int[], int)}
     */
    public int[] sequentialSort() {
        return sequentialSort(arrayToSort, threshold);
    }

    /**
     * Recursively merge-sorts {@code arrayToSort} on the calling thread.
     *
     * <p>An array shorter than {@code threshold}, or with fewer than two
     * elements, is sorted in place with {@link Arrays#sort(int[])} and returned
     * as is. Any other array is copied into two halves that are sorted
     * recursively and merged into a new array; the input is not modified then.
     *
     * @param arrayToSort the array to sort
     * @param threshold   arrays shorter than this are sorted directly
     * @return a sorted array: the input itself when it was sorted directly,
     *         otherwise a newly allocated array
     */
    public static int[] sequentialSort(final int[] arrayToSort, int threshold) {
        // The "length < 2" guard guarantees termination: without it a one-element
        // array with threshold <= 1 (or an empty array with threshold <= 0) would
        // split into a half of the same length and recurse until the stack overflows.
        if (arrayToSort.length < 2 || arrayToSort.length < threshold) {
            Arrays.sort(arrayToSort);
            return arrayToSort;
        }

        final int midpoint = arrayToSort.length / 2;
        final int[] sortedLeft =
                sequentialSort(Arrays.copyOfRange(arrayToSort, 0, midpoint), threshold);
        final int[] sortedRight =
                sequentialSort(
                        Arrays.copyOfRange(arrayToSort, midpoint, arrayToSort.length), threshold);
        return merge(sortedLeft, sortedRight);
    }

    /**
     * Merges two arrays that are each sorted in ascending order.
     *
     * @param leftArray  first sorted input
     * @param rightArray second sorted input
     * @return a new array holding every element of both inputs in ascending order
     */
    public static int[] merge(final int[] leftArray, final int[] rightArray) {
        final int[] mergedArray = new int[leftArray.length + rightArray.length];
        int leftPos = 0;
        int rightPos = 0;
        int mergedPos = 0;

        // Take the smaller head element until one input runs out; ties take the left one.
        while (leftPos < leftArray.length && rightPos < rightArray.length) {
            if (leftArray[leftPos] <= rightArray[rightPos]) {
                mergedArray[mergedPos++] = leftArray[leftPos++];
            } else {
                mergedArray[mergedPos++] = rightArray[rightPos++];
            }
        }

        // At most one input still has elements; copy whatever remains in bulk.
        final int leftRemaining = leftArray.length - leftPos;
        System.arraycopy(leftArray, leftPos, mergedArray, mergedPos, leftRemaining);
        mergedPos += leftRemaining;
        System.arraycopy(rightArray, rightPos, mergedArray, mergedPos, rightArray.length - rightPos);

        return mergedArray;
    }

    /**
     * Demo entry point. First checks the sequential sort against
     * {@link Arrays#sort(int[])}, then checks that the fork/join sort produces
     * the same result as the sequential sort on a second array.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // NOTE(review): the processor count is also used as the split threshold,
        // which makes the direct-sort cutoff only a handful of elements - confirm
        // that this is intended.
        final int nofProcessors = Runtime.getRuntime().availableProcessors();

        // Sequential sort must agree with the JDK sort.
        int[] arrayToSort = Utils.buildRandomIntArray(20000000);
        int[] expectedArray = Arrays.copyOf(arrayToSort, arrayToSort.length);
        int[] actualArray = new MergeSortMain(arrayToSort, nofProcessors).sequentialSort();
        Arrays.sort(expectedArray);
        assertThat(actualArray, is(expectedArray));

        int[] arrayToSortSingleThread = Utils.buildRandomIntArray(20000000);
        int[] arrayToSortMultiThread =
                Arrays.copyOf(arrayToSortSingleThread, arrayToSortSingleThread.length);

        // SINGLE THREADED
        MergeSortMain sequentialSorter = new MergeSortMain(arrayToSortSingleThread, nofProcessors);
        int[] sortSingleThreadArray = sequentialSorter.sequentialSort();

        // MULTI THREADED
        MergeSortAction mergeSortAction = new MergeSortAction(arrayToSortMultiThread, nofProcessors);
        ForkJoinPool forkJoinPool = new ForkJoinPool(nofProcessors);
        try {
            forkJoinPool.invoke(mergeSortAction);
        } finally {
            // Release the pool's worker threads once the sort has finished or failed.
            forkJoinPool.shutdown();
        }

        assertArrayEquals(sortSingleThreadArray, mergeSortAction.getSortedArray());
    }
}
这是我看到的一个 ForkJoin 的 demo，我觉得很不错，贴出来和大家分享。