Fork/Join框架

今天第一次听到Fork/Join框架,感觉自己low爆了,连这个都不知道,所以回来就抓紧补知识,我们先看一下别人对这个框架的原理的介绍

Fork/Join框架是Java 7提供的一个用于并行执行任务的框架,是一个把大任务分割成若干个小任务,最终汇总每个小任务结果后得到大任务结果的框架。Fork/Join框架要完成两件事情:
1.任务分割:首先Fork/Join框架需要把大的任务分割成足够小的子任务,如果子任务比较大的话还要对子任务进行继续分割
 2.执行任务并合并结果:分割的子任务分别放到双端队列里,然后几个启动线程分别从双端队列里获取任务执行。子任务执行完的结果都放在另外一个队列里,启动一个线程从队列里取数据,然后合并这些数据。
 在Java的Fork/Join框架中,使用两个类完成上述操作
 1.ForkJoinTask:我们要使用Fork/Join框架,首先需要创建一个ForkJoin任务。该类提供了在任务中执行fork和join的机制。通常情况下我们不需要直接集成ForkJoinTask类,只需要继承它的子类,Fork/Join框架提供了两个子类:
a.RecursiveAction:用于没有返回结果的任务
b.RecursiveTask:用于有返回结果的任务
2.ForkJoinPool:ForkJoinTask需要通过ForkJoinPool来执行
任务分割出的子任务会添加到当前工作线程所维护的双端队列中,进入队列的头部。当一个工作线程的队列里暂时没有任务时,它会随机从其他工作线程的队列的尾部获取一个任务(工作窃取算法)。
Fork/Join框架的实现原理
ForkJoinPool由ForkJoinTask数组和ForkJoinWorkerThread数组组成,ForkJoinTask数组负责将存放程序提交给ForkJoinPool,而ForkJoinWorkerThread负责执行这些任务。

接下来是我们这次分析的源代码:

public class CountTask extends RecursiveTask<Integer>{

    private static final int THREAD_HOLD = 2;

    private int start;
    private int end;

    public CountTask(int start,int end){
        this.start = start;
        this.end = end;
    }

    @Override
    protected Integer compute() {
        int sum = 0;
        //如果任务足够小就计算
        boolean canCompute = (end - start) <= THREAD_HOLD;
        if(canCompute){
            for(int i=start;i<=end;i++){
                sum += i;
            }
        }else{
            int middle = (start + end) / 2;
            CountTask left = new CountTask(start,middle);
            CountTask right = new CountTask(middle+1,end);
            //执行子任务
            left.fork();
            right.fork();
            //获取子任务结果
            int lResult = left.join();
            int rResult = right.join();
            sum = lResult + rResult;
        }
        return sum;
    }

    public static void main(String[] args){
        ForkJoinPool pool = new ForkJoinPool();
        CountTask task = new CountTask(1,4);
        Future<Integer> result = pool.submit(task);
        try {
            System.out.println(result.get());
        } catch (InterruptedException e) {
            e.printStackTrace();
        } catch (ExecutionException e) {
            e.printStackTrace();
        }
    }
}

首先我们创建了一个ForkJoinPool,然后新建了一个任务,将这个任务丢进池子里面去执行,我们的任务继承了RecursiveTask,在里面实现了compute,主要是处理任务的逻辑,在里面会将我们放进的任务根据具体的要求进行拆分,变成更小的任务去执行,最后合并结果。
我们接下来就对这个框架的源码进行分析:

public ForkJoinPool() {
        this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
             defaultForkJoinWorkerThreadFactory, null, false);
    }

获取到了java虚拟机可以使用的最大处理器的个数,并且传进去了一个defaultForkJoinWorkerThreadFactory:

static final class DefaultForkJoinWorkerThreadFactory
        implements ForkJoinWorkerThreadFactory {
        public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
            return new ForkJoinWorkerThread(pool);
        }
    }

这个默认的工厂实现了ForkJoinWorkerThreadFactory接口,里面只有一个方法,newThead,参数是一个pool,创建一个新的线程,并返回了它:

protected ForkJoinWorkerThread(ForkJoinPool pool) {
        // Use a placeholder until a useful name can be set in registerWorker
        super("aForkJoinWorkerThread");
        this.pool = pool;
        this.workQueue = pool.registerWorker(this);
    }

它继承了Thread,在构造方法里面保存了这个pool,通过pool注册当前的类返回了一个工作队列:

final WorkQueue registerWorker(ForkJoinWorkerThread wt) {
		//发生异常的一个回调函数
        UncaughtExceptionHandler handler;
        //设置守护进程
        wt.setDaemon(true); 
        //给当前线程设置异常回调函数
        if ((handler = ueh) != null)
            wt.setUncaughtExceptionHandler(handler);
        //新建一个工作队列
        WorkQueue w = new WorkQueue(this, wt);
        int i = 0; 
        int mode = config & MODE_MASK;
        //获取运行状态锁;返回当前(锁定)运行状态。
        int rs = lockRunState();
        try {
            WorkQueue[] ws; int n;                    // skip if no array
            //通过散列将新创建的工作队列放到工作队列的集合当中
            if ((ws = workQueues) != null && (n = ws.length) > 0) {
                int s = indexSeed += SEED_INCREMENT;  // unlikely to collide
                int m = n - 1;
                i = ((s << 1) | 1) & m;               // odd-numbered indices
                if (ws[i] != null) {                  // collision
                    int probes = 0;                   // step by approx half n
                    int step = (n <= 4) ? 2 : ((n >>> 1) & EVENMASK) + 2;
                    while (ws[i = (i + step) & m] != null) {
                        if (++probes >= n) {
                            workQueues = ws = Arrays.copyOf(ws, n <<= 1);
                            m = n - 1;
                            probes = 0;
                        }
                    }
                }
                w.hint = s;                           // use as random seed
                w.config = i | mode;
                w.scanState = i;                      // publication fence
                ws[i] = w;
            }
        } finally {
            unlockRunState(rs, rs & ~RSLOCK);
        }
        wt.setName(workerNamePrefix.concat(Integer.toString(i >>> 1)));
        //返回
        return w;
    }

在这个方法里面首先个当前线程设置了一个发生异常的回调函数,然后生成一个工作队列和当前线程绑定,最后将生成的工作队列通过散列的方式放到工作队列的集合当中去,并返回它。
这样一个处理任务的pool就创建好了,接下来我们看一下是如何执行任务的:

Future<Integer> result = pool.submit(task);

进入submit方法:

public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
        if (task == null)
            throw new NullPointerException();
        externalPush(task);
        return task;
    }

首先将这个任务放到任务队列里面去执行,执行完毕后返回结果:

final void externalPush(ForkJoinTask<?> task) {
        WorkQueue[] ws; WorkQueue q; int m;
        //返回当前线程的探测值
        int r = ThreadLocalRandom.getProbe();
        //得到当前线程池的工作状态
        int rs = runState;
        //如果当前工作队列的集合不为空
        if ((ws = workQueues) != null && 
        	//如果当前工作队列的集合的size大于0
        	(m = (ws.length - 1)) >= 0 &&
        	//当前线程对应的散列的位置存在对应的工作队列
            (q = ws[m & r & SQMASK]) != null &&
            //探测值不为0
             r != 0 && 
             //线程池处于工作状态
             rs > 0 &&
             //成功修改了工作队列的运行状态
            U.compareAndSwapInt(q, QLOCK, 0, 1)) {
            ForkJoinTask<?>[] a; int am, n, s;
            //将集合数组化
            if ((a = q.array) != null &&
            	//数组里面是否由元素,由于workQueue将索引放在数组的中心(尚未分配)
                (am = a.length - 1) > (n = (s = q.top) - q.base)) {
                int j = ((am & s) << ASHIFT) + ABASE;
                U.putOrderedObject(a, j, task);
                U.putOrderedInt(q, QTOP, s + 1);
                U.putIntVolatile(q, QLOCK, 0);
                if (n <= 1)
                	//试图创建或激活一个工人,如果太少活跃。
                    signalWork(ws, q);
                return;
            }
            U.compareAndSwapInt(q, QLOCK, 1, 0);
        }
        //用于处理不常见的情况,以及在向池首次提交第一个任务时执行二级初始化
        externalSubmit(task);
    }

如果满足if里面的各种条件,先将任务添加到对应的工作队列当中,如果是第一次初始化或者一些不常见的情况进行其他处理:

 final void signalWork(WorkQueue[] ws, WorkQueue q) {
        long c; int sp, i; WorkQueue v; Thread p;
        //如果当前的ctl小于0,说明活跃的线程太少了
        while ((c = ctl) < 0L) { 
        	// 如果强制转换后等于0,说明没有空闲的工人
            if ((sp = (int)c) == 0) {
            	//需要添加工人
                if ((c & ADD_WORKER) != 0L)            // too few workers
                    tryAddWorker(c);
                break;
            }
            //下面这几种情况说明线程池停止工作了
            if (ws == null)  
                break;
            if (ws.length <= (i = sp & SMASK))         // terminated
                break;
            if ((v = ws[i]) == null)                   // terminating
                break;
            int vs = (sp + SS_SEQ) & ~INACTIVE;        // next scanState
            int d = sp - v.scanState;                  // screen CAS
            long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
            //说明有闲置的工人,需要唤醒当前这个工人
            if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
                v.scanState = vs;                      // activate v
                if ((p = v.parker) != null)
                    U.unpark(p);
                break;
            }
            if (q != null && q.base == q.top)          // no more work
                break;
        }
    }

这个方法根据线程池中工人的状态做出判断,如果没有闲置的工人,就添加工人,如果有闲置的工人就唤醒他,最主要的是根据ctl的低16位来看现在的没有被挂起的线程的数目:

private void tryAddWorker(long c) {
        boolean add = false;
        do {
            long nc = ((AC_MASK & (c + AC_UNIT)) |
                       (TC_MASK & (c + TC_UNIT)));
            if (ctl == c) {
                int rs, stop; 
                //检查是否线程池是否停止工作
                if ((stop = (rs = lockRunState()) & STOP) == 0)
                    add = U.compareAndSwapLong(this, CTL, c, nc);
                unlockRunState(rs, rs & ~RSLOCK);
                if (stop != 0)
                    break;
                if (add) {
                	//创建工人
                    createWorker();
                    break;
                }
            }
        } while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
    }

循环中每创建要给工人就要修改ctl的值,接下来我们看一下工人是如何创建的:

private boolean createWorker() {
        ForkJoinWorkerThreadFactory fac = factory;
        Throwable ex = null;
        ForkJoinWorkerThread wt = null;
        try {
            if (fac != null && (wt = fac.newThread(this)) != null) {
                wt.start();
                return true;
            }
        } catch (Throwable rex) {
            ex = rex;
        }
        deregisterWorker(wt, ex);
        return false;
    }

其实就是创建一个一个线程并启动它,我们接下来看一下这个线程里面的run方法做了什么:

public void run() {
        if (workQueue.array == null) { // only run once
            Throwable exception = null;
            try {
            	//模板方法,启动前做一些处理
                onStart();
                //启动工人去完成工作队列的任务
                pool.runWorker(workQueue);
            } catch (Throwable ex) {
                exception = ex;
            } finally {
                try {
                    onTermination(exception);
                } catch (Throwable ex) {
                    if (exception == null)
                        exception = ex;
                } finally {
                    pool.deregisterWorker(this, exception);
                }
            }
        }
    }
final void runWorker(WorkQueue w) {
        w.growArray();                   // allocate queue
        int seed = w.hint;               // initially holds randomization hint
        int r = (seed == 0) ? 1 : seed;  // avoid 0 for xorShift
        //不断的工作队列中获得任务去执行
        for (ForkJoinTask<?> t;;) {
            if ((t = scan(w, r)) != null)
                w.runTask(t);
            else if (!awaitWork(w, r))
                break;
            r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
        }
    }

每次循环都重新生成一个r,作为参数同任务队列里面获得要给任务去执行:

private ForkJoinTask<?> scan(WorkQueue w, int r) {
        WorkQueue[] ws; int m;
        //任务队列集合不为空
        if ((ws = workQueues) != null && 
        //size大于0
        (m = ws.length - 1) > 0 && 
        //当前线程绑定的任务队列不为空
        w != null) {
            int ss = w.scanState;                     // initially non-negative
            for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
                WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
                int b, n; long c;
                //根据算出的散列值获得一个工作队列
                if ((q = ws[k]) != null) {
                    if ((n = (b = q.base) - q.top) < 0 &&
                        (a = q.array) != null) {      // non-empty
                        long i = (((a.length - 1) & b) << ASHIFT) + ABASE;
                        //从工作队列中获得一个任务并返回
                        if ((t = ((ForkJoinTask<?>)
                                  U.getObjectVolatile(a, i))) != null &&
                            q.base == b) {
                            if (ss >= 0) {
                                if (U.compareAndSwapObject(a, i, t, null)) {
                                    q.base = b + 1;
                                    if (n < -1)       // signal others
                                        signalWork(ws, q);
                                    return t;
                                }
                            }
                            else if (oldSum == 0 &&   // try to activate
                                     w.scanState < 0)
                                tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
                        }
                        if (ss < 0)                   // refresh
                            ss = w.scanState;
                        r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
                        origin = k = r & m;           // move and rescan
                        oldSum = checkSum = 0;
                        continue;
                    }
                    checkSum += b;
                }
              ...
            }
        }
        return null;
    }

先判断当前线程是否处于未激活状态,如果是未激活就尝试激活它,从其他队列中偷取一个任务去执行

final void runTask(ForkJoinTask<?> task) {
            if (task != null) {
            	//标记当前线程处于工作状态
                scanState &= ~SCANNING;
                //执行任务并返回结果
                (currentSteal = task).doExec();
                U.putOrderedObject(this, QCURRENTSTEAL, null); // release for GC
                execLocalTasks();
                ForkJoinWorkerThread thread = owner;
                if (++nsteals < 0)      // collect on overflow
                    transferStealCount(pool);
                scanState |= SCANNING;
                if (thread != null)
                    thread.afterTopLevelExec();
            }
        }

在这个方法里面首先回去执行窃取的任务,然后执行自己队列里面的任务,在这里面我们最关心的方法就是doExec:

final int doExec() {
        int s; boolean completed;
        if ((s = status) >= 0) {
            try {
            	//执行任务并获取执行结果
                completed = exec();
            } catch (Throwable rex) {
                return setExceptionalCompletion(rex);
            }
            //如果执行成功设置执行成功的标志
            if (completed)
                s = setCompletion(NORMAL);
        }
        return s;
    }

将执行任务的最终任务交给了exec:

protected final boolean exec() {
        result = compute();
        return true;
    }

就是执行我们自己的实现的逻辑了:

 protected Integer compute() {
        int sum = 0;
        //如果任务足够小就计算
        boolean canCompute = (end - start) <= THREAD_HOLD;
        if(canCompute){
            for(int i=start;i<=end;i++){
                sum += i;
            }
        }else{
            int middle = (start + end) / 2;
            CountTask left = new CountTask(start,middle);
            CountTask right = new CountTask(middle+1,end);
            //执行子任务
            left.fork();
            right.fork();
            //获取子任务结果
            int lResult = left.join();
            int rResult = right.join();
            sum = lResult + rResult;
        }
        return sum;
    }

首先根据我们传入的任务的大小做一个划分,如果太大了继续将他们切割变成子任务,否则就直接执行,我们看一下fork是如何执行子任务的:

public final ForkJoinTask<V> fork() {
        Thread t;
        if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
            ((ForkJoinWorkerThread)t).workQueue.push(this);
        else
            ForkJoinPool.common.externalPush(this);
        return this;
    }

如果是当前线程池创建的线程就直接放到当前线程的任务队列中:

final void push(ForkJoinTask<?> task) {
            ForkJoinTask<?>[] a; ForkJoinPool p;
            int b = base, s = top, n;
            if ((a = array) != null) {    // ignore if queue removed
                int m = a.length - 1;     // fenced write for task visibility
                U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task);
                U.putOrderedInt(this, QTOP, s + 1);
                if ((n = s - b) <= 1) {
                    if ((p = pool) != null)
                        p.signalWork(p.workQueues, this);
                }
                else if (n >= m)
                    growArray();
            }
        }

通过原子操作放到任务队列中,并且通知工人来处理它
接下来看一下join方法:

public final V join() {
        int s;
        if ((s = doJoin() & DONE_MASK) != NORMAL)
            reportException(s);
        return getRawResult();
    }

调用doJoin方法,如果执行成功,返回结果:

private int doJoin() {
        int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
        //得到这个task的运行状态
        return (s = status) < 0 ? s :
        	//判断是否是我们框架的线程
            ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
            //如果当前工作队列的top元素是当前的task,从队列中直接拿出来执行
            (w = (wt = (ForkJoinWorkerThread)t).workQueue).
            tryUnpush(this) && 
            (s = doExec()) < 0 ? s :
            wt.pool.awaitJoin(w, this, 0L) :
            externalAwaitDone();
    }

在这个方法中会首先查看工作队列的顶部元素是否是当前放入的task,如果是就直接执行并返回结果,否则,阻塞到worker轮询到当前的任务再去执行。最后递归返回子任务的处理结果,合并到一起,返回最终的处理结果。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值