import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
/**
 * Fork/join task that sums the elements of a {@code long[]} over the half-open
 * index range {@code [start, end)}.
 *
 * <p>A range of at most {@link #THRESHOLD} elements is summed with a plain loop;
 * a longer range is cut at its midpoint into two sub-tasks whose results are added.
 * Progress is traced to {@code System.out} at every split, leaf computation and merge.
 */
public class SumTask extends RecursiveTask<Long> {
    /** Largest range length (in elements) that is summed directly instead of being split. */
    static final int THRESHOLD = 100;

    /** Array whose elements are summed; this task only reads it. */
    final long[] array;
    /** Inclusive lower index of the range handled by this task. */
    final int start;
    /** Exclusive upper index of the range handled by this task. */
    final int end;

    /**
     * Creates a task covering {@code array[start]} .. {@code array[end - 1]}.
     *
     * @param array the values to sum
     * @param start first index to include
     * @param end   index one past the last element to include
     */
    SumTask(long[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    /**
     * Returns the index halfway between {@code lo} and {@code hi}, rounded towards {@code lo}.
     *
     * <p>Written as {@code lo + (hi - lo) / 2} rather than {@code (lo + hi) / 2} so that the
     * intermediate value cannot overflow {@code int} when both indices are large.
     *
     * @param lo lower index; expected to satisfy {@code 0 <= lo <= hi}
     * @param hi upper index
     * @return the midpoint of the two indices
     */
    static int midpoint(int lo, int hi) {
        return lo + (hi - lo) / 2;
    }

    /**
     * Sums this task's range, splitting it into two sub-tasks when it is longer
     * than {@link #THRESHOLD}.
     *
     * @return the sum of {@code array[start]} .. {@code array[end - 1]}
     */
    @Override
    protected Long compute() {
        if (end - start <= THRESHOLD) {
            // Small enough: sum sequentially on the current thread.
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            System.out.println(String.format("compute %d~%d = %d", start, end, sum));
            return sum;
        }

        int middle = midpoint(start, end);
        System.out.println(String.format("split%d~%d ==> %d~%d, %d~%d", start, end, start, middle, middle, end));
        SumTask left = new SumTask(this.array, start, middle);
        SumTask right = new SumTask(this.array, middle, end);

        // invokeAll runs both halves and returns once both are done, so the
        // join() calls below only fetch the already-computed results.
        invokeAll(left, right);
        long leftSum = left.join();
        long rightSum = right.join();

        long result = leftSum + rightSum;
        System.out.println("result = " + leftSum + " + " + rightSum + " ==> " + result);
        return result;
    }

    /**
     * Demo entry point: fills an 800-element array with 1..800, sums it on a
     * four-thread {@link ForkJoinPool}, and prints the sum with the elapsed time.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        long[] array = new long[800];
        for (int i = 0; i < array.length; i++) {
            array[i] = i + 1;
        }

        // fork/join task
        ForkJoinPool fjp = new ForkJoinPool(4);
        try {
            ForkJoinTask<Long> task = new SumTask(array, 0, array.length);
            // nanoTime is monotonic, so the measured interval is unaffected by clock adjustments.
            long startTime = System.nanoTime();
            long result = fjp.invoke(task);
            long elapsedMillis = (System.nanoTime() - startTime) / 1_000_000L;
            System.out.println("Fork/Join sum: " + result + " in: " + elapsedMillis + " ms");
        } finally {
            // Release the pool's worker threads even if the task failed.
            fjp.shutdown();
        }
    }
}
程序执行结果