目录
Fork/Join介绍
Fork/Join框架是一种分而治之思想的多线程版实现,可以理解为单机下的map/reduce任务,将一个完整的任务taskA,按照某种条件切割成N个子任务,通过fork将任务分布到多个不同的Thread运行,并通过join来获取taskA的完整运行结果。
Java中的Fork/Join框架实现,是通过ForkJoinPool和ForkJoinTask实现的,ForkJoinPool是一个动态线程池,ForkJoinTask<T>代表一个task,T表示task的返回值类型,ForkJoinTask的实现类有三个:
- java.util.concurrent.RecursiveTask<T>:对task的返回值敏感的话,可以使用这个实现类
- java.util.concurrent.RecursiveAction:不需要获取task返回值的场景,可以使用这个实现类
- java.util.concurrent.CountedCompleter<T>:这个是java8新增的ForkJoinTask的实现类,这个类相较于前两者,最大的特点是当stage complete时,可以触发相应的任务,例如taskA被切割成taskB、taskC,而taskB又被切割成taskB1、taskB2,那么当taskB1执行完,可以触发onCompletionTaskB1,当taskB1和taskB2都执行完时,可以触发onCompletionTaskB。
ForkJoinPool
ForkJoinPool是专门为Fork/Join框架设计的线程池,它与Executor不同,可以动态增减线程池中线程的数量。
ForkJoinPool内部有一个内部类java.util.concurrent.ForkJoinPool.WorkQueue,这个数据结构与work-stealing algorithm(工作窃取算法)是整个java fork/join框架的精髓所在。
在编码过程中,我们可以手动去创建一个ForkJoinPool来运行ForkJoinTask,但java自身也提供了一个静态的commonPool实现,当运行ForkJoinTask但不指定ForkJoinPool时,默认使用全局的commonPool来运行。
public ForkJoinPool(int parallelism,
ForkJoinWorkerThreadFactory factory,
UncaughtExceptionHandler handler,
boolean asyncMode) {
this(checkParallelism(parallelism),
checkFactory(factory),
handler,
asyncMode ? FIFO_QUEUE : LIFO_QUEUE,
"ForkJoinPool-" + nextPoolId() + "-worker-");
checkPermission();
}
private static int checkParallelism(int parallelism) {
if (parallelism <= 0 || parallelism > MAX_CAP) //max_cap=0x7fff
throw new IllegalArgumentException();
return parallelism;
}
private ForkJoinPool(int parallelism,
ForkJoinWorkerThreadFactory factory,
UncaughtExceptionHandler handler,
int mode,
String workerNamePrefix) {
this.workerNamePrefix = workerNamePrefix;
this.factory = factory;
this.ueh = handler;
this.config = (parallelism & SMASK) | mode;
long np = (long)(-parallelism); // offset ctl counts
this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
}
public static ForkJoinPool commonPool() {
// assert common != null : "static init error";
return common;
}
ForkJoinPool构造函数的parameter解释:
- parallelism:并行度,也就是pool中线程的数量,默认等于运行服务器的core数量,注意:并行度有上限,最大值不能超过2^15-1(32767),否则抛出异常。
- factory:用于创建fork/join thread的工厂类,需要注意:不同于Executor的threadFactory,这个factory返回的是java.util.concurrent.ForkJoinWorkerThread类,而不是java.lang.Thread
- handler:thread的未捕获异常处理器的具体实现,默认是null
- asyncMode:默认值是false,用于控制WorkQueue的数据结构类型,默认情况下是栈结构,当设置为true时,使用队列结构。
对于WorkQueue的栈结构和队列结构,各自的使用场景:
- 栈结构一般用于分而治之思想的任务处理,这种处理方式能够优先运行更细粒度的task,而不是等到所有粗粒度的task运行结束,才执行细粒度task,可以减少空间消耗。
- 队列结构一般用于处理event-style的任务,但这种模式下,使用Executor也可以实现,因此一般也不会主动去设置这个参数。
当ForkJoinPool shutdown后或者没有足够的资源运行新提交的task时,再提交task就会抛出RejectedExecutionException 异常。
common pool
common pool与new ForkJoinPool()不同的地方在于:
- common pool在调用shutdown或者shutdownNow时,不会修改pool state,也就是说,除非调用System.exit(0),否则common pool会一直接收ForkJoinTask并运行
- common pool会占用main线程作为其中一个thread worker
common pool可以通过System.setProperty来控制参数:
- java.util.concurrent.ForkJoinPool.common.parallelism:>=0,如果设置为0,表示禁用common pool,但这样做可能导致未执行任务永远不会执行
- java.util.concurrent.ForkJoinPool.common.threadFactory:ForkJoinPool.ForkJoinWorkerThreadFactory 的实现类名称,如果设置null,会禁用common pool
- java.util.concurrent.ForkJoinPool.common.exceptionHandler:Thread.UncaughtExceptionHandler的类名
ForkJoinTask
ForkJoinTask中最重要的方法就是exec()方法,ForkJoinTask在被调度运行时,触发的是doExec()方法
final int doExec() {
int s; boolean completed;
if ((s = status) >= 0) {
try {
completed = exec();
} catch (Throwable rex) {
return setExceptionalCompletion(rex);
}
if (completed)
s = setCompletion(NORMAL);
}
return s;
}
exec()方法负责运行计算任务,同时也负责异常处理和整个task的运行状态,这个方法的返回值是boolean类型,当为true时,task才会被设置为NORMAL结束标志,但需要注意:ForkJoinTask的实现类CountedCompleter的exec方法返回值一直是false,因此对于CountedCompleter来说,整个task的运行状态必须依赖tryComplete()方法来控制。
我们在使用ForkJoinTask的时候,一般都是继承ForkJoinTask的abstract subclass,也就是java.util.concurrent.RecursiveTask<T>、java.util.concurrent.RecursiveAction和java.util.concurrent.CountedCompleter<T>,并重写compute()方法,下面给出一个用于计算的累加和demo:
package com.forkjoin;
import com.sun.istack.internal.NotNull;
import java.util.concurrent.*;
/**
* 研究forkjoinpool的内部源码用
*/
public class ForkJoinPoolAnalyze1 extends RecursiveTask<Integer> {
private int[] intvals;
public ForkJoinPoolAnalyze1(@NotNull int[] intvals) {
this.intvals = intvals;
}
public static void main(String[] args) throws NoSuchFieldException {
int[] intvals = new int[400];
for (int i = 0; i < 400; i++) {
intvals[i] = i;
}
ForkJoinPoolAnalyze1 forkJoinPoolAnalyze1 = new ForkJoinPoolAnalyze1(intvals);
ForkJoinPool forkJoinPool = new ForkJoinPool();
long start = System.currentTimeMillis();
try {
ForkJoinTask<Integer> future = forkJoinPool.submit(forkJoinPoolAnalyze1);
int result = future.join();
System.out.println("计算结果:" + result + ",总共耗时:" + (System.currentTimeMillis() - start) + " ms");
} catch (Exception e) {
System.out.println("出错误啦!!!");
e.printStackTrace();
}
}
@Override
protected Integer compute() {
if (intvals.length <= 5) {
int result = 0;
for (int i : intvals) {
// int x = 1/i; //1/0的问题
try {
TimeUnit.MILLISECONDS.sleep(10);
} catch (InterruptedException e) {
e.printStackTrace();
}
result += i;
}
System.out.println("ThreadName:" + Thread.currentThread().getName() + ",计算结果:" + result);
return result;
} else {
int size = intvals.length / 5;
int modval = intvals.length % 5;
if (modval > 0)
size += 1;
ForkJoinPoolAnalyze1[] forkJoinPoolAnalyze1s = new ForkJoinPoolAnalyze1[size];
for (int i = 0; i < size; i++) {
int arrsize = 5;
if (modval > 0 && i == size - 1) {
arrsize = modval;
}
int[] newIntArrays = new int[arrsize];
System.arraycopy(this.intvals, i * 5, newIntArrays, 0, arrsize);
ForkJoinPoolAnalyze1 forkJoinPoolAnalyze1 = new ForkJoinPoolAnalyze1(newIntArrays);
forkJoinPoolAnalyze1s[i] = forkJoinPoolAnalyze1;
}
invokeAll(forkJoinPoolAnalyze1s);
int result = 0;
for (ForkJoinPoolAnalyze1 forkJoinPoolAnalyze1 : forkJoinPoolAnalyze1s) {
result += forkJoinPoolAnalyze1.getRawResult();
}
return result;
}
}
}
CountedCompleter
这个实现类的最大特点是需要维护一个pending counter,用于记录每个task的当前阻塞情况,CountedCompleter类似于一个树形结构,Main Thread提交的CountedCompleter属于root,特点就是CountedCompleter.completer=null;而由其fork出来的task,comleter=root,以此类推。
protected CountedCompleter(CountedCompleter<?> completer,
int initialPendingCount) {
this.completer = completer; //直接父节点
this.pending = initialPendingCount;
}
CountedCompleter中维护一个pendingCount,用于记录当前task的pending数量,只有当this.pendingCount归0后,才认为当前的task运行结束,才有机会触发parent node task的完成,例如root中存在两个fork task,root.pendingCount=2,只有当fork task都执行结束时,才会触发root task的的完成,并将完整的task设置成完成状态,tryComplete()方法的源码如下:
public final void tryComplete() {
CountedCompleter<?> a = this, s = a;
for (int c;;) {
if ((c = a.pending) == 0) {
a.onCompletion(s);
if ((a = (s = a).completer) == null) {
s.quietlyComplete();
return;
}
}
else if (U.compareAndSwapInt(a, PENDING, c, c - 1))
return;
}
}
通过源码可以看到,只有root task执行结束,才认为整个forkjointask执行结束,而root task要执行结束,首先要求fork task要将pendingCount设置为0,并显示的触发一次tryComplete()方法。
注意:pendingCount可以手动设置,因此并不一定要与fork task的数量相等,而要完成一个task,也并非必须通过tryComplete方法,也可以通过complete方法直接完成task
/**
* 可以看到,任意的一个task都可以执行complete来直接完成整个countedcompleter task
*/
public void complete(T rawResult) {
CountedCompleter<?> p;
setRawResult(rawResult);
onCompletion(this);
quietlyComplete();
if ((p = completer) != null)
p.tryComplete();
}
使用CountedCompleter计算的累加和demo:
package com.forkjoin;
import java.util.Random;
import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
/**
* @Author: wanglonglong
* @Description: 使用CountedCompleter计算0~12的累加和
*/
public class CountedCompleterTest1 extends CountedCompleter<Integer> {
public static AtomicInteger sum_val = new AtomicInteger(0);
private final int[] intvals;
private int totalNum = 0;
private String threadName;
public CountedCompleterTest1(int[] intvals,CountedCompleterTest1 parentCountedCompleter) {
super(parentCountedCompleter);
this.intvals = intvals;
}
@Override
public Integer getRawResult() {
return this.totalNum;
}
@Override
public void onCompletion(CountedCompleter<?> caller) {
CountedCompleterTest1 countedCompleterTest1 = (CountedCompleterTest1) caller;
System.out.println("line[32]---"+this.threadName+".onCompletion("+countedCompleterTest1.threadName+")");
}
public static void main(String[] args) throws ExecutionException, InterruptedException {
int[] intvals = new int[13];
for (int i = 0; i < 13; i++) {
intvals[i] = i;
}
CountedCompleter<Integer> countedCompleter = new CountedCompleterTest1(intvals,null);
countedCompleter.invoke();
System.out.println("line[42]---总和="+CountedCompleterTest1.sum_val.get());
}
@Override
public void compute() {
this.threadName = Thread.currentThread().getName();
String caculateStr = "";
if (intvals.length <= 5) {
for (int i : intvals) {
caculateStr += i+"+";
try {
TimeUnit.MILLISECONDS.sleep(new Random().nextInt(10));
} catch (InterruptedException e) {
e.printStackTrace();
}
this.totalNum += i;
}
System.out.println("line[59]---"+"ThreadName="+this.threadName+"的计算表达式:"+caculateStr.substring(0,caculateStr.length()-1)+"="+this.totalNum);
sum_val.addAndGet(this.totalNum);
} else {
int size = intvals.length / 5;
int modval = intvals.length % 5;
if (modval > 0)
size += 1;
for (int i = 0; i < size; i++) {
int arrsize = 5;
if (modval > 0 && i == size - 1) {
arrsize = modval;
}
int[] newIntArrays = new int[arrsize];
System.arraycopy(this.intvals, i * 5, newIntArrays, 0, arrsize);
CountedCompleterTest1 countedCompleterTest1 = new CountedCompleterTest1(newIntArrays,this);
countedCompleterTest1.fork();
}
this.setPendingCount(size);
}
/*
* 注意:getPendingCount()是瞬时结果,并不能准确反映真实情况
* */
/*System.out.println(this.threadName+" 尝试完成任务,当前pendingCount="+this.getPendingCount()+",parent pendingCount="
+(this.getRoot()==null?"null":this.getRoot().getPendingCount()));*/
tryComplete();
}
}
ManagedBlocker
ManagedBlocker是ForkJoinPool的内部类,其源码如下:
public static void managedBlock(ManagedBlocker blocker)
throws InterruptedException {
ForkJoinPool p;
ForkJoinWorkerThread wt;
Thread t = Thread.currentThread();
if ((t instanceof ForkJoinWorkerThread) &&
(p = (wt = (ForkJoinWorkerThread)t).pool) != null) {
WorkQueue w = wt.workQueue;
while (!blocker.isReleasable()) {
if (p.tryCompensate(w)) {
try {
do {} while (!blocker.isReleasable() &&
!blocker.block());
} finally {
U.getAndAddLong(p, CTL, AC_UNIT);
}
break;
}
}
}
else {
do {} while (!blocker.isReleasable() &&
!blocker.block());
}
ManagedBlocker用于阻塞ForkJoinTask任务,相较于Thread.wait等操作,ManagedBlocker在阻塞时,提供一个补偿机制,会在阻塞的同时,尝试创建新的线程来弥补ForkJoinPool中被占用的线程。
ManagedBlocker的实现类需要实现isReleasable()和block()方法,isReleasable用于判断是否需要阻塞,block执行具体的阻塞和唤醒操作。
demo:
package com.forkjoin;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.TimeUnit;
/**
* @Author: wanglonglong
* @Description:
*/
public class BlockerTest1 extends RecursiveAction {
private int[] intvals;
private long start = System.currentTimeMillis();
public BlockerTest1(int[] intvals) {
this.intvals = intvals;
}
public static void main(String[] args) throws ExecutionException, InterruptedException {
int[] intvals = new int[20];
for (int i = 0; i < 20; i++) {
intvals[i] = i;
}
BlockerTest1 blockerTest1 = new BlockerTest1(intvals);
ForkJoinPool forkJoinPool = new ForkJoinPool(2);
forkJoinPool.submit(blockerTest1);
blockerTest1.get();
}
@Override
protected void compute() {
if (intvals.length <= 5) {
int result = 0;
for (int i : intvals) {
/*try {
TimeUnit.MILLISECONDS.sleep(500);
} catch (InterruptedException e) {
e.printStackTrace();
}*/
try {
ForkJoinPool.managedBlock(new Blocker());
} catch (InterruptedException e) {
e.printStackTrace();
}
result += i;
}
/**
* 使用sleep暂停的输出结果:
* ThreadName:ForkJoinPool-1-worker-1,计算结果:10,计算耗时:2504
* ThreadName:ForkJoinPool-1-worker-0,计算结果:85,计算耗时:2504
* ThreadName:ForkJoinPool-1-worker-1,计算结果:35,计算耗时:5007
* ThreadName:ForkJoinPool-1-worker-0,计算结果:60,计算耗时:5007
*
* 使用blocker暂停的输出结果:
* ThreadName:ForkJoinPool-1-worker-0,计算结果:85,计算耗时:2506
* ThreadName:ForkJoinPool-1-worker-2,计算结果:60,计算耗时:2553
* ThreadName:ForkJoinPool-1-worker-1,计算结果:10,计算耗时:2601
* ThreadName:ForkJoinPool-1-worker-3,计算结果:35,计算耗时:2601
*
*/
System.out.println("ThreadName:" + Thread.currentThread().getName() + ",计算结果:" + result + ",计算耗时:" +(System.currentTimeMillis()-start));
} else {
int size = intvals.length / 5;
int modval = intvals.length % 5;
if (modval > 0)
size += 1;
BlockerTest1[] blockerTest1s = new BlockerTest1[size];
for (int i = 0; i < size; i++) {
int arrsize = 5;
if (modval > 0 && i == size - 1) {
arrsize = modval;
}
int[] newIntArrays = new int[arrsize];
System.arraycopy(this.intvals, i * 5, newIntArrays, 0, arrsize);
BlockerTest1 blockerTest1 = new BlockerTest1(newIntArrays);
blockerTest1s[i] = blockerTest1;
}
invokeAll(blockerTest1s);
}
}
static class Blocker implements ForkJoinPool.ManagedBlocker {
volatile boolean flag = false;
@Override
public boolean block() throws InterruptedException {
TimeUnit.MILLISECONDS.sleep(500);
this.flag = true;
return true;
}
@Override
public boolean isReleasable() {
return flag;
}
}
}
工作窃取算法
work-stealing是fork/join的精髓,本人水平有限,算法的源码实在让人头秃,因此通过各方资料来分析算法的实现原理。
首先是来自国内网站的一篇文章:Java多线程进阶(四三)—— J.U.C之executors框架:Fork/Join框架(1) 原理 - SegmentFault 思否,这篇文章用图例的方式生动的展示了工作窃取算法的工作流程,但注意:文章中存在一些错误,因此只需要大概阅读以下,对工作窃取算法有个大概的认识即可。
工作窃取算法是Doug Lea大神的作品,大神发表了一篇有关算法的论文,地址:http://gee.cs.oswego.edu/dl/papers/fj.pdf,这个是对工作窃取算法最权威的解释。
我根据上面两篇资料进行了整理,如果有错误的地方欢迎指正。
fork/join框架的核心是work-stealing algorithm(任务窃取算法),算法的核心是Thread Worker Array和WorkQueue Array,其中Thread Worker由ForkJoinPool的defaultThreadFactory来创建,但需要注意:不同于其他的多线程框架,fork/join框架的worker thread是可以回收的。
WorkQueue Array主要的作用就是记录所有的Task,task根据产生方式可以分为external task和inner task:
- external task表示由外部代码提交的任务,例如由用户代码通过ForkJoinPool.submit(ForkJoinTask)提交的任务,此时的ForkJoinTask就是external task
- inner task表示由ForkJoinPool中的workerThread提交的任务,例如在RecursiveTask.compute方法中,调用了(new MyRecursiveTask()).fork(),此时MyRecursiveTask就是inner task
workQueue Array的偶数位保存的是external task,而奇数位保存的是peer Thread Worker的inner task。
在ForkJoinPool.asyncMode=false(默认值)时,无论是external task还是inner task,都使用LIFO(也就是栈结构)来保存task;当asyncMode=true时,使用FIFO的算法保存,这样其实也很好理解,就是fork/join框架优先要处理粒度更小的task,用来防止workqueue中保存了大量的task才开始计算。
在工作窃取算法中,WorkQueue Array中存储着每个Thread独占的WorkQueue,而WorkQueue是一个由数组和其他信息维护的双向链表结构,核心的成员属性是int base和int top,其抽象结构可以表示成如下图:
对于WorkQueue,只有owner thread才有全系push task,但获取task的操作任何thread都可以,这也就是工作窃取名称的由来,但owner提取task的方式和other thread提取task的方式不同。
- push task的时候,base不变,top++
- owner thread pop task时,base不变,top--
- other thread poll(take) task时,top不变,base++,也就是说,任务窃取使用的队列结构
对于external workqueue,只会被poll出数据,由ForkJointWorkerThread从base开始窃取external task并运行;
对于inner task,可以被pop或者poll,其中ForkJointWorkerThread的主要工作就是从奇数位的workqueue(这个workqueue与ForkJointWorkerThread是绑定的,one-to-one)提取task并运行,只有当private innter task workqueue都执行结束后,才尝试使用poll从其他workqueue中poll task。