package com.huang.test.concurrent;
import java.util.concurrent.*;
/**
 * ForkJoin demo: solving a problem by divide and conquer.
 *
 * <p>{@link #main(String[])} fills a large {@code long[]} and sums it with a {@link MyTask} submitted
 * to a {@link ForkJoinPool}. It then sums the same array with a plain sequential loop. Each result is
 * printed with its elapsed time in nanoseconds so the two approaches can be compared.
 */
public class ForkJoinTest {

    /** Parallelism level of the pool used by the demo. */
    private static final int PARALLELISM = 16;

    /** Length of the demo array; at 8 bytes per {@code long} this needs roughly 800 MB of heap. */
    private static final int ARRAY_SIZE = 100_000_000;

    /** Offset added to each index when filling the demo array. */
    private static final long VALUE_OFFSET = 654632563456754L;

    /**
     * Leaf threshold used by the demo. With {@link MyTask#DEFAULT_THRESHOLD} (10), the 10^8-element
     * demo array would be split into one task per ~11 elements, i.e. tens of millions of tasks.
     */
    private static final int DEMO_THRESHOLD = 10_000;

    /** Runs the demo. Command-line arguments are ignored. */
    public static void main(String[] args) {
        ForkJoinTest fjt = new ForkJoinTest();
        fjt.test();
    }

    /**
     * Sums {@code array[start..end]} (both bounds inclusive) by recursive halving.
     *
     * <p>Declared {@code static}. As a non-static inner class, every task carried a hidden reference
     * to the enclosing {@code ForkJoinTest}, which is not {@code Serializable}. Serializing a task
     * (which inherits {@code Serializable} from {@link ForkJoinTask}) therefore failed.
     */
    static class MyTask extends RecursiveTask<Long> {
        private static final long serialVersionUID = -2358776679174639606L;

        /** Threshold used by the 3-arg constructor; equals the original hard-coded value. */
        static final int DEFAULT_THRESHOLD = 10;

        private final long[] array;
        private final int start;
        private final int end;
        /** A range with {@code end - start <= threshold} is summed sequentially. */
        private final int threshold;

        /**
         * Creates a task over {@code array[start..end]} using {@link #DEFAULT_THRESHOLD}.
         *
         * @param array the values to sum
         * @param start first index to include
         * @param end last index to include; {@code end < start} denotes an empty range
         */
        MyTask(long[] array, int start, int end) {
            this(array, start, end, DEFAULT_THRESHOLD);
        }

        /**
         * Creates a task over {@code array[start..end]} with an explicit leaf threshold.
         *
         * @param array the values to sum
         * @param start first index to include
         * @param end last index to include; {@code end < start} denotes an empty range
         * @param threshold ranges with {@code end - start <= threshold} are summed sequentially
         * @throws IllegalArgumentException if {@code threshold} is negative (splitting would never
         *     reach a leaf and would recurse forever)
         */
        MyTask(long[] array, int start, int end, int threshold) {
            if (threshold < 0) {
                throw new IllegalArgumentException("threshold must be >= 0: " + threshold);
            }
            this.array = array;
            this.start = start;
            this.end = end;
            this.threshold = threshold;
        }

        @Override
        protected Long compute() {
            if (end - start <= threshold) {
                // Sequential leaf. long addition wraps silently on overflow (the demo data does
                // overflow). Two's-complement addition is associative, so every split yields the
                // same wrapped total as a single sequential loop.
                long sum = 0;
                for (int i = start; i <= end; i++) {
                    sum += array[i];
                }
                return sum;
            }
            // Overflow-safe midpoint; (start + end) / 2 can overflow for large indices.
            int middle = start + (end - start) / 2;
            MyTask left = new MyTask(array, start, middle, threshold);
            MyTask right = new MyTask(array, middle + 1, end, threshold);
            // Fork one half and compute the other in this worker instead of forking both and
            // leaving the current thread idle in join().
            left.fork();
            long rightSum = right.compute();
            long leftSum = left.join();
            return leftSum + rightSum;
        }
    }

    /** Fills the demo array, sums it on a fork/join pool, prints the result, then runs the baseline. */
    private void test() {
        long[] arr = new long[ARRAY_SIZE];
        for (int i = 0; i < arr.length; i++) {
            arr[i] = i + VALUE_OFFSET;
        }
        ForkJoinPool pool = new ForkJoinPool(PARALLELISM);
        try {
            MyTask task = new MyTask(arr, 0, arr.length - 1, DEMO_THRESHOLD);
            long start = System.nanoTime();
            ForkJoinTask<Long> fur = pool.submit(task);
            long res = fur.get();
            long end = System.nanoTime();
            System.out.println("res:" + res + ", use time:" + (end - start));
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the interruption is not lost.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (ExecutionException e) {
            e.printStackTrace();
        } finally {
            pool.shutdown();
        }
        testSum(arr);
    }

    /**
     * Sums {@code arr} with a plain sequential loop and prints the sum and elapsed nanoseconds, as a
     * baseline for the fork/join result. The sum wraps on overflow exactly as {@link MyTask} does.
     */
    private void testSum(long[] arr) {
        long start = System.nanoTime();
        long sum = 0;
        for (int i = 0; i < arr.length; i++) {
            sum += arr[i];
        }
        long end = System.nanoTime();
        System.out.println("test sum:" + sum + ", use time:" + (end - start));
    }
}