今天第一次听到Fork/Join框架,感觉自己low爆了,连这个都不知道,所以回来就抓紧补知识,我们先看一下别人对这个框架的原理的介绍
Fork/Join框架是Java 7提供的一个用于并行执行任务的框架,是一个把大任务分割成若干个小任务,最终汇总每个小任务结果后得到大任务结果的框架。Fork/Join框架要完成两件事情:
1.任务分割:首先Fork/Join框架需要把大的任务分割成足够小的子任务,如果子任务比较大的话还要对子任务进行继续分割
2.执行任务并合并结果:分割的子任务分别放到双端队列里,然后几个启动线程分别从双端队列里获取任务执行。子任务执行完的结果都放在另外一个队列里,启动一个线程从队列里取数据,然后合并这些数据。
在Java的Fork/Join框架中,使用两个类完成上述操作
1.ForkJoinTask:我们要使用Fork/Join框架,首先需要创建一个ForkJoin任务。该类提供了在任务中执行fork和join的机制。通常情况下我们不需要直接集成ForkJoinTask类,只需要继承它的子类,Fork/Join框架提供了两个子类:
a.RecursiveAction:用于没有返回结果的任务
b.RecursiveTask:用于有返回结果的任务
2.ForkJoinPool:ForkJoinTask需要通过ForkJoinPool来执行
任务分割出的子任务会添加到当前工作线程所维护的双端队列中,进入队列的头部。当一个工作线程的队列里暂时没有任务时,它会随机从其他工作线程的队列的尾部获取一个任务(工作窃取算法)。
Fork/Join框架的实现原理
ForkJoinPool由ForkJoinTask数组和ForkJoinWorkerThread数组组成,ForkJoinTask数组负责将存放程序提交给ForkJoinPool,而ForkJoinWorkerThread负责执行这些任务。
接下来是我们这次分析的源代码:
public class CountTask extends RecursiveTask<Integer>{
private static final int THREAD_HOLD = 2;
private int start;
private int end;
public CountTask(int start,int end){
this.start = start;
this.end = end;
}
@Override
protected Integer compute() {
int sum = 0;
//如果任务足够小就计算
boolean canCompute = (end - start) <= THREAD_HOLD;
if(canCompute){
for(int i=start;i<=end;i++){
sum += i;
}
}else{
int middle = (start + end) / 2;
CountTask left = new CountTask(start,middle);
CountTask right = new CountTask(middle+1,end);
//执行子任务
left.fork();
right.fork();
//获取子任务结果
int lResult = left.join();
int rResult = right.join();
sum = lResult + rResult;
}
return sum;
}
public static void main(String[] args){
ForkJoinPool pool = new ForkJoinPool();
CountTask task = new CountTask(1,4);
Future<Integer> result = pool.submit(task);
try {
System.out.println(result.get());
} catch (InterruptedException e) {
e.printStackTrace();
} catch (ExecutionException e) {
e.printStackTrace();
}
}
}
首先我们创建了一个ForkJoinPool,然后新建了一个任务,将这个任务丢进池子里面去执行,我们的任务继承了RecursiveTask,在里面实现了compute,主要是处理任务的逻辑,在里面会将我们放进的任务根据具体的要求进行拆分,变成更小的任务去执行,最后合并结果。
我们接下来就对这个框架的源码进行分析:
public ForkJoinPool() {
this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
defaultForkJoinWorkerThreadFactory, null, false);
}
获取到了java虚拟机可以使用的最大处理器的个数,并且传进去了一个defaultForkJoinWorkerThreadFactory:
static final class DefaultForkJoinWorkerThreadFactory
implements ForkJoinWorkerThreadFactory {
public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
return new ForkJoinWorkerThread(pool);
}
}
这个默认的工厂实现了ForkJoinWorkerThreadFactory接口,里面只有一个方法,newThead,参数是一个pool,创建一个新的线程,并返回了它:
protected ForkJoinWorkerThread(ForkJoinPool pool) {
// Use a placeholder until a useful name can be set in registerWorker
super("aForkJoinWorkerThread");
this.pool = pool;
this.workQueue = pool.registerWorker(this);
}
它继承了Thread,在构造方法里面保存了这个pool,通过pool注册当前的类返回了一个工作队列:
final WorkQueue registerWorker(ForkJoinWorkerThread wt) {
//发生异常的一个回调函数
UncaughtExceptionHandler handler;
//设置守护进程
wt.setDaemon(true);
//给当前线程设置异常回调函数
if ((handler = ueh) != null)
wt.setUncaughtExceptionHandler(handler);
//新建一个工作队列
WorkQueue w = new WorkQueue(this, wt);
int i = 0;
int mode = config & MODE_MASK;
//获取运行状态锁;返回当前(锁定)运行状态。
int rs = lockRunState();
try {
WorkQueue[] ws; int n; // skip if no array
//通过散列将新创建的工作队列放到工作队列的集合当中
if ((ws = workQueues) != null && (n = ws.length) > 0) {
int s = indexSeed += SEED_INCREMENT; // unlikely to collide
int m = n - 1;
i = ((s << 1) | 1) & m; // odd-numbered indices
if (ws[i] != null) { // collision
int probes = 0; // step by approx half n
int step = (n <= 4) ? 2 : ((n >>> 1) & EVENMASK) + 2;
while (ws[i = (i + step) & m] != null) {
if (++probes >= n) {
workQueues = ws = Arrays.copyOf(ws, n <<= 1);
m = n - 1;
probes = 0;
}
}
}
w.hint = s; // use as random seed
w.config = i | mode;
w.scanState = i; // publication fence
ws[i] = w;
}
} finally {
unlockRunState(rs, rs & ~RSLOCK);
}
wt.setName(workerNamePrefix.concat(Integer.toString(i >>> 1)));
//返回
return w;
}
在这个方法里面首先个当前线程设置了一个发生异常的回调函数,然后生成一个工作队列和当前线程绑定,最后将生成的工作队列通过散列的方式放到工作队列的集合当中去,并返回它。
这样一个处理任务的pool就创建好了,接下来我们看一下是如何执行任务的:
Future<Integer> result = pool.submit(task);
进入submit方法:
public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
if (task == null)
throw new NullPointerException();
externalPush(task);
return task;
}
首先将这个任务放到任务队列里面去执行,执行完毕后返回结果:
final void externalPush(ForkJoinTask<?> task) {
WorkQueue[] ws; WorkQueue q; int m;
//返回当前线程的探测值
int r = ThreadLocalRandom.getProbe();
//得到当前线程池的工作状态
int rs = runState;
//如果当前工作队列的集合不为空
if ((ws = workQueues) != null &&
//如果当前工作队列的集合的size大于0
(m = (ws.length - 1)) >= 0 &&
//当前线程对应的散列的位置存在对应的工作队列
(q = ws[m & r & SQMASK]) != null &&
//探测值不为0
r != 0 &&
//线程池处于工作状态
rs > 0 &&
//成功修改了工作队列的运行状态
U.compareAndSwapInt(q, QLOCK, 0, 1)) {
ForkJoinTask<?>[] a; int am, n, s;
//将集合数组化
if ((a = q.array) != null &&
//数组里面是否由元素,由于workQueue将索引放在数组的中心(尚未分配)
(am = a.length - 1) > (n = (s = q.top) - q.base)) {
int j = ((am & s) << ASHIFT) + ABASE;
U.putOrderedObject(a, j, task);
U.putOrderedInt(q, QTOP, s + 1);
U.putIntVolatile(q, QLOCK, 0);
if (n <= 1)
//试图创建或激活一个工人,如果太少活跃。
signalWork(ws, q);
return;
}
U.compareAndSwapInt(q, QLOCK, 1, 0);
}
//用于处理不常见的情况,以及在向池首次提交第一个任务时执行二级初始化
externalSubmit(task);
}
如果满足if里面的各种条件,先将任务添加到对应的工作队列当中,如果是第一次初始化或者一些不常见的情况进行其他处理:
final void signalWork(WorkQueue[] ws, WorkQueue q) {
long c; int sp, i; WorkQueue v; Thread p;
//如果当前的ctl小于0,说明活跃的线程太少了
while ((c = ctl) < 0L) {
// 如果强制转换后等于0,说明没有空闲的工人
if ((sp = (int)c) == 0) {
//需要添加工人
if ((c & ADD_WORKER) != 0L) // too few workers
tryAddWorker(c);
break;
}
//下面这几种情况说明线程池停止工作了
if (ws == null)
break;
if (ws.length <= (i = sp & SMASK)) // terminated
break;
if ((v = ws[i]) == null) // terminating
break;
int vs = (sp + SS_SEQ) & ~INACTIVE; // next scanState
int d = sp - v.scanState; // screen CAS
long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
//说明有闲置的工人,需要唤醒当前这个工人
if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
v.scanState = vs; // activate v
if ((p = v.parker) != null)
U.unpark(p);
break;
}
if (q != null && q.base == q.top) // no more work
break;
}
}
这个方法根据线程池中工人的状态做出判断,如果没有闲置的工人,就添加工人,如果有闲置的工人就唤醒他,最主要的是根据ctl的低16位来看现在的没有被挂起的线程的数目:
private void tryAddWorker(long c) {
boolean add = false;
do {
long nc = ((AC_MASK & (c + AC_UNIT)) |
(TC_MASK & (c + TC_UNIT)));
if (ctl == c) {
int rs, stop;
//检查是否线程池是否停止工作
if ((stop = (rs = lockRunState()) & STOP) == 0)
add = U.compareAndSwapLong(this, CTL, c, nc);
unlockRunState(rs, rs & ~RSLOCK);
if (stop != 0)
break;
if (add) {
//创建工人
createWorker();
break;
}
}
} while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
}
循环中每创建要给工人就要修改ctl的值,接下来我们看一下工人是如何创建的:
private boolean createWorker() {
ForkJoinWorkerThreadFactory fac = factory;
Throwable ex = null;
ForkJoinWorkerThread wt = null;
try {
if (fac != null && (wt = fac.newThread(this)) != null) {
wt.start();
return true;
}
} catch (Throwable rex) {
ex = rex;
}
deregisterWorker(wt, ex);
return false;
}
其实就是创建一个一个线程并启动它,我们接下来看一下这个线程里面的run方法做了什么:
public void run() {
if (workQueue.array == null) { // only run once
Throwable exception = null;
try {
//模板方法,启动前做一些处理
onStart();
//启动工人去完成工作队列的任务
pool.runWorker(workQueue);
} catch (Throwable ex) {
exception = ex;
} finally {
try {
onTermination(exception);
} catch (Throwable ex) {
if (exception == null)
exception = ex;
} finally {
pool.deregisterWorker(this, exception);
}
}
}
}
final void runWorker(WorkQueue w) {
w.growArray(); // allocate queue
int seed = w.hint; // initially holds randomization hint
int r = (seed == 0) ? 1 : seed; // avoid 0 for xorShift
//不断的工作队列中获得任务去执行
for (ForkJoinTask<?> t;;) {
if ((t = scan(w, r)) != null)
w.runTask(t);
else if (!awaitWork(w, r))
break;
r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
}
}
每次循环都重新生成一个r,作为参数同任务队列里面获得要给任务去执行:
private ForkJoinTask<?> scan(WorkQueue w, int r) {
WorkQueue[] ws; int m;
//任务队列集合不为空
if ((ws = workQueues) != null &&
//size大于0
(m = ws.length - 1) > 0 &&
//当前线程绑定的任务队列不为空
w != null) {
int ss = w.scanState; // initially non-negative
for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
int b, n; long c;
//根据算出的散列值获得一个工作队列
if ((q = ws[k]) != null) {
if ((n = (b = q.base) - q.top) < 0 &&
(a = q.array) != null) { // non-empty
long i = (((a.length - 1) & b) << ASHIFT) + ABASE;
//从工作队列中获得一个任务并返回
if ((t = ((ForkJoinTask<?>)
U.getObjectVolatile(a, i))) != null &&
q.base == b) {
if (ss >= 0) {
if (U.compareAndSwapObject(a, i, t, null)) {
q.base = b + 1;
if (n < -1) // signal others
signalWork(ws, q);
return t;
}
}
else if (oldSum == 0 && // try to activate
w.scanState < 0)
tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
}
if (ss < 0) // refresh
ss = w.scanState;
r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
origin = k = r & m; // move and rescan
oldSum = checkSum = 0;
continue;
}
checkSum += b;
}
...
}
}
return null;
}
先判断当前线程是否处于未激活状态,如果是未激活就尝试激活它,从其他队列中偷取一个任务去执行
final void runTask(ForkJoinTask<?> task) {
if (task != null) {
//标记当前线程处于工作状态
scanState &= ~SCANNING;
//执行任务并返回结果
(currentSteal = task).doExec();
U.putOrderedObject(this, QCURRENTSTEAL, null); // release for GC
execLocalTasks();
ForkJoinWorkerThread thread = owner;
if (++nsteals < 0) // collect on overflow
transferStealCount(pool);
scanState |= SCANNING;
if (thread != null)
thread.afterTopLevelExec();
}
}
在这个方法里面首先回去执行窃取的任务,然后执行自己队列里面的任务,在这里面我们最关心的方法就是doExec:
final int doExec() {
int s; boolean completed;
if ((s = status) >= 0) {
try {
//执行任务并获取执行结果
completed = exec();
} catch (Throwable rex) {
return setExceptionalCompletion(rex);
}
//如果执行成功设置执行成功的标志
if (completed)
s = setCompletion(NORMAL);
}
return s;
}
将执行任务的最终任务交给了exec:
protected final boolean exec() {
result = compute();
return true;
}
就是执行我们自己的实现的逻辑了:
protected Integer compute() {
int sum = 0;
//如果任务足够小就计算
boolean canCompute = (end - start) <= THREAD_HOLD;
if(canCompute){
for(int i=start;i<=end;i++){
sum += i;
}
}else{
int middle = (start + end) / 2;
CountTask left = new CountTask(start,middle);
CountTask right = new CountTask(middle+1,end);
//执行子任务
left.fork();
right.fork();
//获取子任务结果
int lResult = left.join();
int rResult = right.join();
sum = lResult + rResult;
}
return sum;
}
首先根据我们传入的任务的大小做一个划分,如果太大了继续将他们切割变成子任务,否则就直接执行,我们看一下fork是如何执行子任务的:
public final ForkJoinTask<V> fork() {
Thread t;
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
((ForkJoinWorkerThread)t).workQueue.push(this);
else
ForkJoinPool.common.externalPush(this);
return this;
}
如果是当前线程池创建的线程就直接放到当前线程的任务队列中:
final void push(ForkJoinTask<?> task) {
ForkJoinTask<?>[] a; ForkJoinPool p;
int b = base, s = top, n;
if ((a = array) != null) { // ignore if queue removed
int m = a.length - 1; // fenced write for task visibility
U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task);
U.putOrderedInt(this, QTOP, s + 1);
if ((n = s - b) <= 1) {
if ((p = pool) != null)
p.signalWork(p.workQueues, this);
}
else if (n >= m)
growArray();
}
}
通过原子操作放到任务队列中,并且通知工人来处理它
接下来看一下join方法:
public final V join() {
int s;
if ((s = doJoin() & DONE_MASK) != NORMAL)
reportException(s);
return getRawResult();
}
调用doJoin方法,如果执行成功,返回结果:
private int doJoin() {
int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
//得到这个task的运行状态
return (s = status) < 0 ? s :
//判断是否是我们框架的线程
((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
//如果当前工作队列的top元素是当前的task,从队列中直接拿出来执行
(w = (wt = (ForkJoinWorkerThread)t).workQueue).
tryUnpush(this) &&
(s = doExec()) < 0 ? s :
wt.pool.awaitJoin(w, this, 0L) :
externalAwaitDone();
}
在这个方法中会首先查看工作队列的顶部元素是否是当前放入的task,如果是就直接执行并返回结果,否则,阻塞到worker轮询到当前的任务再去执行。最后递归返回子任务的处理结果,合并到一起,返回最终的处理结果。