多核时代,编程语言如果不支持多核编程就OUT了,Java为了迎头赶上,在Java 8 版本增加大量支持多核编程的类库,如Stream等,Java 7开始支持的ForkJoin框架也是为了更好的支持多核编程。


设计思想:化整为零再化零为整,另外还要加上一种团队精神,即能者多劳。化整为零(split up)就是把一个复杂的任务分为许多足够小的任务计算;化零为整(merge)就是把小任务的计算结果不断往上合并值到得出最终结果;团队精神:ForkJoin使用了Work-Stealing算法,即先完成任务的线程不会闲着,会主动去偷别的线程待处理任务队列中的任务来帮忙处理,直到全部任务都处理完大伙才能停下来休息。


使用ForkJoin框架经常使用到两个类RecursiveTask 和 RecursiveAction,RecursiveTask 用于定义有返回值的任务,RecursiveAction用于定义没有返回值的任务,从类名看这两个类应该跟递归有一腿?经确认,ForkJoin框架处理的任务基本都能使用递归处理,比如求斐波那契数列等,但递归算法的缺陷是:一只会只用单线程处理,二是递归次数过多时会导致堆栈溢出;ForkJoin解决了这两个问题,使用多线程并发处理,充分利用计算资源来提高效率,同时避免堆栈溢出发生。当然像求斐波那契数列这种小问题直接使用线性算法搞定可能更简单,实际应用中完全没必要使用ForkJoin框架,所以ForkJoin是核弹,是用来对付大家伙的,比如超大数组排序。


最佳应用场景:多核、多内存、可以分割计算再合并的计算密集型任务。


ForkJoinTask类的几个重要方法:


fork()方法:将任务放入队列并安排异步执行,一个任务应该只调用一次fork()函数,除非已经执行完毕并重新初始化。


tryUnfork()方法:尝试把任务从队列中拿出单独处理,但不一定成功。


join()方法:等待计算完成并返回计算结果。


isCompletedAbnormally()方法:用于判断任务计算是否发生异常。


ForkJoinPool的使用与其它ExecutorService类似。


示例代码:

package com.stevex.app.forkjoin;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;

public class ForkJoinTest {
	public static void main(String[] args) {
		long beginTime = System.nanoTime();		
		System.out.println("The sum from 1 to 1000 is " + sum(1, 1000));
		System.out.println("Time consumed(nano second) By recursive algorithm : " + (System.nanoTime() - beginTime));
		
		
		beginTime = System.nanoTime();	
		System.out.println("The sum from 1 to 1000000000 is " + sum1(1, 1000000000));	
		System.out.println("Time consumed(nano second) By loop algorithm : " + (System.nanoTime() - beginTime));
		
		
		ForkJoinTest app = new ForkJoinTest();
		ForkJoinPool forkJoinPool = new ForkJoinPool();
		CountTask task = app.new CountTask(1,1000000000);
		beginTime = System.nanoTime();
		Future<Long> result = forkJoinPool.submit(task);
		try{
			System.out.println("The sum from 1 to 1000000000 is " + result.get());			
		}
		catch(Exception e){
			e.printStackTrace();
		}
		
		System.out.println("Time consumed(nano second) By ForkJoin algorithm : " + (System.nanoTime() - beginTime));
	}

	private static long sum1(long start, long end) {
		long s = 0l;
		
		for(long i=start; i<= end; i++){
			s += i;
		}
		
		return s;
	}

	private static long sum(long start, long end){
		if(end > start){
			return end + sum(start, end-1);
		}
		else{
			return start;
		}
	}
	
	private class CountTask extends RecursiveTask<Long>{
		private static final int THRESHOLD = 10000;
		private int start;
		private int end;
		
		public CountTask(int start, int end){
			this.start = start;
			this.end = end;
		}
		
		protected Long compute(){
			//System.out.println("Thread ID: " + Thread.currentThread().getId());
			
			Long sum = 0l;
			
			if((end -start) <= THRESHOLD){
				sum = sum1(start, end);
			}
			else{
				int middle = (start + end) / 2;
				CountTask leftTask = new CountTask(start, middle);
				CountTask rightTask = new CountTask(middle + 1, end);
				leftTask.fork();
				rightTask.fork();
				
				Long leftResult = leftTask.join();
				Long rightResult = rightTask.join();
				
				sum = leftResult + rightResult;
			}
			
			return sum;
		}
	}
}