Executor(五):ForkJoinPool详解 jdk1.8

首先声明这个类在jdk11中有比较大的改动,如果使用的jdk11这篇文章可能对你帮助不大。

ForkJoinPool在1.7引入,它只被用来运行ForkJoinTask的子类任务。这个线程池和其他的线程池的不同之处在于它使用分而治之和工作窃取算法去执行任务。有效的去处理大多数任务能衍生出小任务的问题。笔者也是刚接触ForkJoinPool,这个类比较复杂如果有错误还望指正。

对于ForkJoinPool由于设计比较复杂,所以jdk文档给了很长的说明。本来翻译了一波,但是看到网上有更好的,所以大家先看看jdk的说明,对这个类有一个感性的认识

工作窃取算法

看看维基百科对工作窃取算法的描述:在并行计算中,工作窃取是多线程计算机程序的调度策略。它解决了在具有固定数量的处理器(或内核)的静态多线程计算机上执行动态多线程计算的问题,该计算可以“产生”新的执行线程。它在执行时间,内存使用和处理器间通信方面都非常有效。

一般是一个双端队列和一个工作线程绑定,如下图所示。工作线程从绑定的队列的头部取任务执行,从别的队列(一般是随机)的底部偷取任务。

主要成员变量

ForkJoinPool为了节省内存的使用,将一些信息打包到一个变量中存放。要读懂ForkJoinPool首先要弄清楚里面一些成员变量中存放了什么信息。

工作队列WorkQueue

WorkQueue作为ForkJoinPool里的一个内部类,也是工作窃取算法的核心。它是一个双端队列,使用 @Contented 注解修饰防止伪共享。伪共享状态:缓存系统中是以缓存行(cache line)为单位存储的。缓存行是2的整数幂个连续字节,一般为32-256个字节。最常见的缓存行大小是64个字节。当多线程修改互相独立的变量时,如果这些变量共享同一个缓存行,就会无意中影响彼此的性能,这就是伪共享。

WorkQueue中比较重要的成员变量

scanState:表示线程是否激活,队列是否正在执行任务。

stackPred:使用场景在空闲队列激活或失活。如果当前队列失活则当前队列在工作队列数组中的下标会替换原来存放在ctl低 32上存放的值,原来的会存放在当前队列的stackPred。形成一个空闲队列栈。

hint:记录偷窃自己任务的队列,用于帮助偷窃自己任务的队列执行任务(反偷),方便快速定位小偷。如果没有这个值需要遍历工作队列去寻找小偷队列。

config:存放了队列在队列数组中的索引低15位,和队列的模式(FIFO,FILO)

qlock:外部提交任务使用的锁

owner:绑定的工作线程,如果是共享则为null

//工作偷取队列数组的初始容量。一定是2次幂。至少为4,但是应该更大一些去减少或消除缓存在队列之间的共享。
static final int INITIAL_QUEUE_CAPACITY = 1 << 13;

        /**
      
         * 队列数组的最大值。小于或等于1<<31-数组入口的宽度去确保缺乏
         * 全面的索引计算 。但是定义一个值略小于这个去帮助用户去捕获偷跑
         * 程序要系统饱和前
         */
        static final int MAXIMUM_QUEUE_CAPACITY = 1 << 26; // 64M

        //最高位表示是否激活,17位到31位表示版本号,低16位表示工作队列数组下标
        //最低位为扫描位 1:为正在扫描,0:为正在执行任务 ,奇数最后一位都是1,偶数最后一位都是0
        volatile int scanState;    // versioned, <0: inactive; odd:scanning
        //持堆栈的前身
        int stackPred;             // pool stack (ctl) predecessor
        //偷取数
        int nsteals;               // number of steals
        //随机化和偷取者索引暗示。记录小偷
        int hint;                  // randomization and stealer index hint
        //线程池的索引和模式
        int config;                // pool index and mode
        //等于1为为锁定,小于0为终止,其他为0
        volatile int qlock;        // 1: locked, < 0: terminate; else 0
        //poll的下一个索引槽
        volatile int base;         // index of next slot for poll
        //push的下一个索引槽
        int top;                   // index of next slot for push
        //元素
        ForkJoinTask<?>[] array;   // the elements (initially unallocated)
        //线程池,可能为null
        final ForkJoinPool pool;   // the containing pool (may be null)
        //拥有者,如果是共享则为null
        final ForkJoinWorkerThread owner; // owning thread or null if shared
        volatile Thread parker;    // == owner during call to park; else null
        volatile ForkJoinTask<?> currentJoin;  // task being joined in awaitJoin
        volatile ForkJoinTask<?> currentSteal; // mainly used by helpStealer

ctl

ctl控制中心,里面存放的信息以及位数如下图。这个很重要,一定要清楚里面表示的信息。

1~32位是某个空闲队列的scanState字段。

33~ 48位初始值时线程池的最大并行数对应的负数。也就是在创建新线程时只用判断线程总数符号位是否为1就能知道是否能创建新线程了。

49~64位初始值和33~48一样。它表示的是活跃的线程数。

这个线程池支持的最大线程数为32767,如果超过会抛出异常。这里只支持32767的并行度是因为ctl的组成关系。只有16位用来存放线程数,最高位表示正负,所以只有15位来表示,也就是2的15次方减一。

构造方法

private ForkJoinPool(int parallelism,
                         ForkJoinWorkerThreadFactory factory,
                         UncaughtExceptionHandler handler,
                         int mode,
                         String workerNamePrefix) {
        this.workerNamePrefix = workerNamePrefix;
        this.factory = factory;
        this.ueh = handler;
        this.config = (parallelism & SMASK) | mode;
        //这里设置为负数是为了得到对应二进制补码时第三十二位为1,
        //这样在进行np << AC_SHIFT) & AC_MASK操作时表示总线程数的值的第十六位为1,则为负数
        long np = (long)(-parallelism); // offset ctl counts
        //ctl低32位为0
        this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
    }

public ForkJoinPool(int parallelism,
                        ForkJoinWorkerThreadFactory factory,
                        UncaughtExceptionHandler handler,
                        boolean asyncMode) {
        this(checkParallelism(parallelism),
             checkFactory(factory),
             handler,
             asyncMode ? FIFO_QUEUE : LIFO_QUEUE,
             "ForkJoinPool-" + nextPoolId() + "-worker-");
        checkPermission();
    }

主要方法

ForkJoinPool的实现比较复杂,所以我画了一下一个任务提交后方法大概调用情况。先对任务的提交过程有一个大概的认识 ,我也会根据这个过程一个一个方法的介绍。先看一下工作队列数组的一个分布情况,它的大小一定是2次幂,奇数位和偶数位存放的虽然都是任务队列。但是奇数位是带工作线程的存放fork出的子任务的队列,偶数队列存放的是外部提交的任务。

外部任务提交的方法调用过程

ForkJoinPool.externalPush 

我们先从精简版的任务提交方法开始。ForkJoinPool的invoke,execute,submit都会调用这个方法来进行任务的提交。这个方法之所以成为精简版的任务提交是因为它没有处理线程池的初始化等问题,如果随机到的偶数槽位队列可以提交任务,则就会直接将任务推入队列。否则会调用完整版任务提交方法externalSubmit。

final void externalPush(ForkJoinTask<?> task) {
        //存放工作队列的队列
        WorkQueue[] ws;
        //随机选取的工作队列
        WorkQueue q;
        //m为存放工作队列的队列的长度
        int m;
        //获取随机探针
        int r = ThreadLocalRandom.getProbe();
        //线程池运行状态
        int rs = runState;
        //如果工作队列的队列不为空&&存放工作队列的队列长度大于0且
        //随机到的槽位不为空,随机探针不为0,线程池状态不为0且设置qlock锁从0到1成功
        if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
            (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
            U.compareAndSwapInt(q, QLOCK, 0, 1)) {//获取随机偶数槽位的workQueue
            //选取槽位的队列的数组
            ForkJoinTask<?>[] a;
            //am:数组长度,n:数组使用的数量,s:top指定的下标,
            int am, n, s;
            //如果队列对应的数组不为空且数组长度大于已经使用的空间
            if ((a = q.array) != null &&
                (am = a.length - 1) > (n = (s = q.top) - q.base)) {
                //计算队列top的偏移量。ASHIFT是每个ForkJoinTask的占用空间
                //ASHIFT*(am&s)也就是现在队列中任务所占用的空间
                //然后加上ABASE就是新加入任务所在的位置。
                int j = ((am & s) << ASHIFT) + ABASE;
                //队列数组j的地方放入task
                U.putOrderedObject(a, j, task);
                //队列TOP加一
                U.putOrderedInt(q, QTOP, s + 1);
                //解锁
                U.putIntVolatile(q, QLOCK, 0);
                //如果选定的工作队列任务先前小于等于1则唤醒工作线程
                if (n <= 1)
                    signalWork(ws, q);
                return;
            }
            U.compareAndSwapInt(q, QLOCK, 1, 0);
        }
        //初始化
        externalSubmit(task);
    }

ForkJoinPool.externalSubmit

完整版外部的任务提交方法。

第一步,如果线程池没有初始化会先进行初始化操作,比如工作队列数组的空间分配还有线程池的状态修改等。

第二步,如果随机的偶数槽位队列不为空,则将任务推入队列并调用signalWork方法唤醒线程。

如果第二步槽位为null,则第三步为这个槽位创建队列后再重复循环。如果发生竞争会重新随机槽位。

private void externalSubmit(ForkJoinTask<?> task) {
        int r;                                    // initialize caller's probe
        //初始化调用线程的探针值,用于计算WorkQueue索引。
        if ((r = ThreadLocalRandom.getProbe()) == 0) {
            ThreadLocalRandom.localInit();
            r = ThreadLocalRandom.getProbe();
        }
        for (;;) {
            WorkQueue[] ws; WorkQueue q; int rs, m, k;
            boolean move = false;
            //运行状态小于0,说明线程池已经关闭
            if ((rs = runState) < 0) {
                tryTerminate(false, false);     // help terminate
                throw new RejectedExecutionException();
            }
            //初始化
            else if ((rs & STARTED) == 0 ||     // initialize
                     ((ws = workQueues) == null || (m = ws.length - 1) < 0)) {
                int ns = 0;
                //加锁
                rs = lockRunState();
                try {
                    //再次检测有没有启动
                    if ((rs & STARTED) == 0) {
                        //初始化偷窃线程数
                        U.compareAndSwapObject(this, STEALCOUNTER, null,
                                               new AtomicLong());
                        // create workQueues array with size a power of two
                        //创建一个workQueues容量为2的幂次方
                        int p = config & SMASK; // ensure at least 2 slots
                        int n = (p > 1) ? p - 1 : 1;
                        n |= n >>> 1; n |= n >>> 2;  n |= n >>> 4;
                        n |= n >>> 8; n |= n >>> 16; n = (n + 1) << 1;
                        workQueues = new WorkQueue[n];
                        ns = STARTED;
                    }
                } finally {
                    //解锁,并设置线程池状态为STARTED
                    unlockRunState(rs, (rs & ~RSLOCK) | ns);
                }
            }
            //随机的偶数槽位,如果对应的队列不为空。 
//如果这里第一次为null,第二次循环到时还是同样的k值,如果探针没有变的话。
            //所以第二次到这里,大概率是有对应的队列这这个槽位了。
            else if ((q = ws[k = r & m & SQMASK]) != null) {
                //加锁是否成功
                if (q.qlock == 0 && U.compareAndSwapInt(q, QLOCK, 0, 1)) {
                    ForkJoinTask<?>[] a = q.array;
                    int s = q.top;
                    boolean submitted = false; // initial submission or resizing
                    try {                      // locked version of push
                        if ((a != null && a.length > s + 1 - q.base) ||
                            (a = q.growArray()) != null) {
                            //ASHIFT是每个ForkJoinTask的大小对应2的多少次幂,
                            // top左移ASHIFT相当于top*ForkJoinTask的大小
                            //下面这一整句的意思就是找到下个任务的偏移量。
                            int j = (((a.length - 1) & s) << ASHIFT) + ABASE;
                            //将任务放到对应的偏移量
                            U.putOrderedObject(a, j, task);
                            //头部+1
                            U.putOrderedInt(q, QTOP, s + 1);
                            submitted = true;
                        }
                    } finally {
                        //解锁
                        U.compareAndSwapInt(q, QLOCK, 1, 0);
                    }
                    if (submitted) {
                        //提交任务成功,唤醒线程。
                        signalWork(ws, q);
                        return;
                    }
                }
                move = true;                   // move on failure
            }
            //判断是否有上锁
            else if (((rs = runState) & RSLOCK) == 0) { // create new queue
                //创建一个新队列
                q = new WorkQueue(this, null);
                //探针
                q.hint = r;
                //共享模式的
                q.config = k | SHARED_QUEUE;
                q.scanState = INACTIVE;
                rs = lockRunState();           // publish index
                //判断是否终结
                if (rs > 0 &&  (ws = workQueues) != null &&
                    k < ws.length && ws[k] == null)
                    ws[k] = q;//将队列放入工作队列数组                 // else terminated
                //解锁
                unlockRunState(rs, rs & ~RSLOCK);
            }
            //如果被另外线程上锁,则会修改探针的值。
            else
                move = true;                   // move if busy
            if (move)
                r = ThreadLocalRandom.advanceProbe(r);
        }
    }

ForkJoinPool.signalWork

任务提交成功后会调用这个方法,它的作用就是激活一个空闲线程或创建一个线程并绑定一个队列在队列数组的奇数槽位。

如果清楚ForkJoinPool中主要成员变量所代表的含义这个方法就可以很容易的理解。它首先去判断是否有空闲的队列也就是通过ctl的低32位,如果没有则会判断是否能在添加线程,可以就会创建。如果有空闲线程则会进行激活。具体实现可以看下面代码:

final void signalWork(WorkQueue[] ws, WorkQueue q) {
        //ctl的值
        long c;
        //sp:ctl的低三十二位,表示等待队列
        int sp, i;
        WorkQueue v;
        Thread p;
        while ((c = ctl) < 0L) {                       // too few active
            //如果没有空闲线程,ctl低32位初始值为0
            if ((sp = (int)c) == 0) {                  // no idle workers
                //如果总线程数最高位为负数则表示可以添加线程
                //因为ctl表示的是最高并行数的负数
                if ((c & ADD_WORKER) != 0L)            // too few workers
                    tryAddWorker(c);
                break;
            }
            //工作队列数组如果为空则说明没有启动或者已经终止
            if (ws == null)                            // unstarted/terminated
                break;
            //取低十六位赋值给i,如果大于工作队列数组的长度则说明终止了
            if (ws.length <= (i = sp & SMASK))         // terminated
                break;
            //如果i下标为null,说明在终止。低16位存放着最近被灭活的队列
            if ((v = ws[i]) == null)                   // terminating
                break;
            //增加版本号,避免ABA问题
            int vs = (sp + SS_SEQ) & ~INACTIVE;        // next scanState

            int d = sp - v.scanState;                  // screen CAS
            //增加活跃线程数,修改低32位为刚激活的队列的stackPred
            long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
            if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
                //赋值新版本号
                v.scanState = vs;                      // activate v
                if ((p = v.parker) != null)
                    U.unpark(p);
                break;
            }
            //如果q中没有任务了就会跳出循环
            if (q != null && q.base == q.top)          // no more work
                break;
        }
    }

创建线程方法

 private void tryAddWorker(long c) {
        boolean add = false;
        do {
            //活跃线程数和线程总数加一
            long nc = ((AC_MASK & (c + AC_UNIT)) |
                       (TC_MASK & (c + TC_UNIT)));
            if (ctl == c) {
                int rs, stop;                 // check if terminating
                if ((stop = (rs = lockRunState()) & STOP) == 0)
                    add = U.compareAndSwapLong(this, CTL, c, nc);
                unlockRunState(rs, rs & ~RSLOCK);
                if (stop != 0)
                    break;
                if (add) {
                    createWorker();
                    break;
                }
            }
            //判断第48位是否为1,也就是线程总数是否为负数。
            //ctl中线程总数应该是对应线程并行数的负数。
            //所以为负数应该是可以继续添加线程的
        } while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
    }
private boolean createWorker() {
        ForkJoinWorkerThreadFactory fac = factory;
        Throwable ex = null;
        ForkJoinWorkerThread wt = null;
        try {
            //创建新线程并注册
            if (fac != null && (wt = fac.newThread(this)) != null) {
                wt.start();
                return true;
            }
        } catch (Throwable rex) {
            ex = rex;
        }
        //从工作队列数组移除
        deregisterWorker(wt, ex);
        return false;
    }

 上面的两个方法理解理解不是很难,最后会调用registerWorker方法来进行线程和队列的绑定并注册到工作队列数组中。

final WorkQueue registerWorker(ForkJoinWorkerThread wt) {
        UncaughtExceptionHandler handler;
        //设置为守护线程,这样保证用户线程都已释放的情况下关闭虚拟机.
        wt.setDaemon(true);                           // configure thread
        if ((handler = ueh) != null)
            wt.setUncaughtExceptionHandler(handler);
        //设置所属线程池和所属队列
        WorkQueue w = new WorkQueue(this, wt);
        int i = 0;                                    // assign a pool index
        //队列模式 先进先出,先进后出,共享
        int mode = config & MODE_MASK;
        //加锁
        int rs = lockRunState();
        try {
            WorkQueue[] ws; int n;                    // skip if no array
            //如果工作队列数组为空就跳过
            if ((ws = workQueues) != null && (n = ws.length) > 0) {
                int s = indexSeed += SEED_INCREMENT;  // unlikely to collide
                int m = n - 1;
                //去的一个奇数下标
                i = ((s << 1) | 1) & m;               // odd-numbered indices
                //如果产生碰撞
                if (ws[i] != null) {                  // collision
                    int probes = 0;                   // step by approx half n
                    int step = (n <= 4) ? 2 : ((n >>> 1) & EVENMASK) + 2;
                    while (ws[i = (i + step) & m] != null) {

                        if (++probes >= n) {
                            //说明已经进行了n次尝试还是没有找到没有碰撞的点,则
                            //进行数组扩容
                            workQueues = ws = Arrays.copyOf(ws, n <<= 1);
                            m = n - 1;
                            probes = 0;
                        }
                    }
                }
                //使用的随机种子
                w.hint = s;                           // use as random seed
                //存放了队列在队列数组中的索引低15位,和队列的模式
                w.config = i | mode;
                //scanState设置为当前下标奇数值
                w.scanState = i;                      // publication fence
                //新队列设置于i处
                ws[i] = w;
            }
        } finally {
            //解锁
            unlockRunState(rs, rs & ~RSLOCK);
        }
        //线程名称
        wt.setName(workerNamePrefix.concat(Integer.toString(i >>> 1)));
        return w;
    }

任务执行

当任务提交到队列,并已经创建了线程和队列后,线程就会启动。调用ForkJoinWorkerThread.run方法,这个方法中会调用ForkJoinPool的runWorker方法。当然run方法中也会处理一些异常,解绑等操作,读者可以自行查看。

final void runWorker(WorkQueue w) {
        //分配数组空间
        w.growArray();                   // allocate queue
        //hint是一个随机数
        int seed = w.hint;               // initially holds randomization hint
        int r = (seed == 0) ? 1 : seed;  // avoid 0 for xorShift
        for (ForkJoinTask<?> t;;) {
            if ((t = scan(w, r)) != null)
                w.runTask(t);
            else if (!awaitWork(w, r))
                break;
            r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
        }
    }

final void runTask(ForkJoinTask<?> task) {
            if (task != null) {
                //最低位设置为0,表示在运行任务
                scanState &= ~SCANNING; // mark as busy
                //设置currentSteal当前执行的偷窃任务
                (currentSteal = task).doExec();
                U.putOrderedObject(this, QCURRENTSTEAL, null); // release for GC
                execLocalTasks();
                ForkJoinWorkerThread thread = owner;
                //增加偷取数,如果偷取数溢出则将这个偷取数加到线程池的偷取数
                if (++nsteals < 0)      // collect on overflow
                    transferStealCount(pool);
                //运行完任务将最低位设置为1.
                scanState |= SCANNING;
                if (thread != null)
                    thread.afterTopLevelExec();
            }
        }

在runWorker方法中首先会进行数组的扩容,因为前面创建队列时并没有给队列的数组分配空间。在执行自己队列任务前会去使用scan方法去偷取任务,如果偷取到任务则执行偷取的任务后然后再执行自己队列中的任务。

ForkJoinPool.scan

扫描整个队列连续出现两次扫描的checkSum的值相同,说明所有的队列都是空的了 需要去灭活当前的队列。因为两次checkSum的值相同说明两次都便利了所有的队列的base 也就是都是线性的增加k的值,如果有的队列有元素发生竞争失败了会随机移动下标, 很大概率不会形成两次一样checkSum的。

如果scan没有扫描到任务会将这个队列失活,并放入将队列的scanState字段方法ctl的低32位,替换原来的值并将原来的值放入当前队列的stackPred字段构成一个栈。scan没有扫描到任务返回后,runWork方法会调用awaitWork方法阻塞线程。

private ForkJoinTask<?> scan(WorkQueue w, int r) {
        //m:任务队列数组的长度-1
        WorkQueue[] ws; int m;
        if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
            //初始是一个非负的值
            int ss = w.scanState;                     // initially non-negative
            for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
                WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
                int b, n; long c;
                //k是一个随机的数,如果不等于null则
                if ((q = ws[k]) != null) {
                    //队列中存在任务
                    if ((n = (b = q.base) - q.top) < 0 &&
                        (a = q.array) != null) {      // non-empty
                        //计算base的偏移量
                        long i = (((a.length - 1) & b) << ASHIFT) + ABASE;
                        //取出base下标的任务
                        if ((t = ((ForkJoinTask<?>)
                                  U.getObjectVolatile(a, i))) != null &&
                            q.base == b) {
                            //ss初始时为非负数,(队列在队列数组中的下标)
                            if (ss >= 0) {
                                //将base下标处的任务置为null
                                if (U.compareAndSwapObject(a, i, t, null)) {
                                    //更新队列下标
                                    q.base = b + 1;
                                    //队列中不止一个元素,则唤醒其他线程
                                    if (n < -1)       // signal others
                                        signalWork(ws, q);
                                    //返回任务
                                    return t;
                                }
                            }
                            else if (oldSum == 0 &&   // try to activate
                                     w.scanState < 0)
                                //如果scanState为负数且oldsum为0
                                //scanState什么时候会变为负数,在队列失活的时候
                                //尝试去激活ctl低32位的队列
                                tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
                        }
                        if (ss < 0)                   // refresh
                            //刷新ss的值,避免被其他线程修改了未更新
                            ss = w.scanState;
                        //发生竞争随机移动
                        r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
                        origin = k = r & m;           // move and rescan
                        oldSum = checkSum = 0;
                        //继续扫描
                        continue;
                    }
                    //如果数组为空,则会checkSum的值会加上队列q的base
                    checkSum += b;
                }
                //能到这里说明ws[k]为null或为空或出现了竞争,
                // k线性加1,直到发现已经从一个origin转满了一圈或n圈.
                if ((k = (k + 1) & m) == origin) {    // continue until stable
                    条件:scanState表示活跃,或者满足当前线程工作队列w的ss未改变,
                    // oldSum依旧等于最新的checkSum(校验和未改变)
                    if ((ss >= 0 || (ss == (ss = w.scanState))) &&
                        oldSum == (oldSum = checkSum)) {
                        //能进入这里说明w[k]不为null,而是队列为空,所以需要灭活
                        //ss小于0说明队列被灭活了,队列的qlock小于0说明已经终止了
                        if (ss < 0 || w.qlock < 0)    // already inactive
                            break;
                        int ns = ss | INACTIVE;       // try to inactivate
                        //将低三十二位替换成scanState 活跃线程减1
                        long nc = ((SP_MASK & ns) |
                                   (UC_MASK & ((c = ctl) - AC_UNIT)));
                       //将先前的栈顶替换存放在新栈顶的stackPred上
                        w.stackPred = (int)c;         // hold prev stack top
                        //将w的scanState设置为新的值,和ctl的低三十二位一样
                        U.putInt(w, QSCANSTATE, ns);
                        if (U.compareAndSwapLong(this, CTL, c, nc))
                            //CAS成功,将ss更新为新值
                            ss = ns;
                        else
                            //CAS失败,还原
                            w.scanState = ss;         // back out
                    }
                    checkSum = 0;
                }
            }
        }
        return null;
    }

到这里任务的提交和执行涉及到的主要方法都解读了一遍。看到这可能会有疑问,那工作窃取算法是怎么运用的?

这个时候就需要介绍一个ForkJoinTask,它是一个抽象类,但是一般使用fork/join时提交的任务也不是直接继承它。而是继承RecursiveTask,RecursiveAction还有CountedCompleter(它们各自的不同点读者可以自行研究)。这些方法中有一个exec方法,这个方法会在doExec方法中调用。在exec会调用compute方法,所以一般继承RecursiveTask方法需要实现compute方法,在这个方法中将任务进行拆分成更小的子任务,通过调用fork来实现任务提交。然后调用join方法等待任务的执行完毕。工作窃取算法的运用就是在这,我们可以看看join方法的实现。

ForkJoinTask.join

public final V join() {
        int s;
        if ((s = doJoin() & DONE_MASK) != NORMAL)
            reportException(s);//可能被取消或发生异常
        return getRawResult();
    }

private int doJoin() {
        int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
        return (s = status) < 0 ? s :
            ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
            (w = (wt = (ForkJoinWorkerThread)t).workQueue).
            tryUnpush(this) && (s = doExec()) < 0 ? s :
            wt.pool.awaitJoin(w, this, 0L) :
            externalAwaitDone();
        //是否已执行完
            //是 直接返回任务状态
            //否 当前线程是否是ForkJoinWorkerThread
                    //是 执行workQueue的tryUnPush方法和doExec方法。这里的意思是移除在top的当前任务,然后自己主动执行
                        //移除成功 返回任务状态
                        //移除失败 调用awaitJoin方法
                    //否 执行externalAwaitDone
    }

final int awaitJoin(WorkQueue w, ForkJoinTask<?> task, long deadline) {
        int s = 0;
        if (task != null && w != null) {
            //将当前任务变为队列的正在join的任务,先前的放到task的pervJoin,形成一个栈。
            ForkJoinTask<?> prevJoin = w.currentJoin;
            U.putOrderedObject(w, QCURRENTJOIN, task);
            CountedCompleter<?> cc = (task instanceof CountedCompleter) ?
                (CountedCompleter<?>)task : null;
            for (;;) {
                //小于0说明完成
                if ((s = task.status) < 0)
                    break;
                //CountedCompleter任务由helpComplete来完成join
                if (cc != null)
                    helpComplete(w, cc, 0);
                //如果队列为空或 执行任务没有成功则会去帮助偷窃.
                        //执行失败说明任务被偷了。
                //tryRemoveAndExec任务执行成功则会返回false
                //在当前队列任务执行完了或者
                    // 或者(在队列中没有找到这个任务且任务没有执行)
                //则这个任务是被偷了,偷窃任务的可能是在join。
                //所以去帮助偷窃者执行他的任务。
                else if (w.base == w.top || w.tryRemoveAndExec(task))
                    helpStealer(w, task);
                //如果任务执行成功则会跳出循环
                if ((s = task.status) < 0)
                    break;
                long ms, ns;
                if (deadline == 0L)
                    ms = 0L;
                else if ((ns = deadline - System.nanoTime()) <= 0L)
                    break;
                else if ((ms = TimeUnit.NANOSECONDS.toMillis(ns)) <= 0L)
                    ms = 1L;
                //尝试补偿,在里面有进行
                if (tryCompensate(w)) {
                    //等待
                    task.internalWait(ms);
                    //活跃线程加1
                    U.getAndAddLong(this, CTL, AC_UNIT);
                }
            }
            //还原当前队列正在join的任务
            U.putOrderedObject(w, QCURRENTJOIN, prevJoin);
        }
        //返回任务的状态
        return s;
    }

在join方法中会调用dojoin方法,在doJoin中如果任务没有执行,会调用awaitJoin方法会调用tryRemoveAndExec去自己队列中寻找这个任务。

final boolean tryRemoveAndExec(ForkJoinTask<?> task) {
            // a:队列数组;m:队列最大下标 ;s:top,b:base;n = top - base
            ForkJoinTask<?>[] a; int m, s, b, n;
            if ((a = array) != null && (m = a.length - 1) >= 0 &&
                task != null) {
                //如果队列中有任务。
                while ((n = (s = top) - (b = base)) > 0) {
                    for (ForkJoinTask<?> t;;) {      // traverse from s to b
                        //偏移量 第一次为top后面则递减
                        long j = ((--s & m) << ASHIFT) + ABASE;
                        //获取的偏移量的任务为空则返回。
                        if ((t = (ForkJoinTask<?>)U.getObject(a, j)) == null)
                            return s + 1 == top;     // shorter than expected
                        //如果任务是给定的任务
                        else if (t == task) {
                            boolean removed = false;
                            //这个任务正好是在栈顶top的位置,则直接移除任务且修改top的值
                            if (s + 1 == top) {      // pop
                                if (U.compareAndSwapObject(a, j, task, null)) {
                                    U.putOrderedInt(this, QTOP, s);
                                    removed = true;
                                }
                            }
                            //如果队列base没有变则将一个空任务放置在原来的偏移量位置。
                            //这个空任务的状态是NORMAL
                            else if (base == b)      // replace with proxy
                                removed = U.compareAndSwapObject(
                                    a, j, task, new EmptyTask());
                            //如果被移除会执行任务
                            if (removed)
                                task.doExec();
                            break;
                        }
                        //如果任务已经被执行,且任务是在队列的top,
                        //则将任务对应的偏移量置为null,top减1
                        else if (t.status < 0 && s + 1 == top) {
                            if (U.compareAndSwapObject(a, j, t, null))
                                U.putOrderedInt(this, QTOP, s);
                            break;                  // was cancelled
                        }
                        //如果任务队列遍历完,则返回
                        if (--n == 0)
                            return false;
                    }
                    //任务已经执行
                    if (task.status < 0)
                        return false;
                }
            }
            return true;
        }

tryRemoveAndExec如果找到这个任务会直接执行,然后用一个空任务放入原来的位置。如果没有找到这个任务说明任务被某个队列线程偷取了,会调用helpStealer方法去寻找这个小偷。在helpStealer中只会遍历奇数槽位的队列,因为也只有奇数槽位的队列才会有线程去偷取任务。如果小偷没有执行到自己队列的任务,会帮小偷执行任务。如果自己队列有任务没有执行完会退出方法,然后会进行一次补偿后阻塞线程等待任务完成唤醒。

 private void helpStealer(WorkQueue w, ForkJoinTask<?> task) {
        WorkQueue[] ws = workQueues;
        int oldSum = 0, checkSum, m;
        if (ws != null && (m = ws.length - 1) >= 0 && w != null &&
            task != null) {
            do {                                       // restart point
                checkSum = 0;                          // for stability check
                ForkJoinTask<?> subtask;
                WorkQueue j = w, v;                    // v is subtask stealer
                descent: for (subtask = task; subtask.status >= 0; ) {
                    //确保h为奇数,k每次增加2,h+k则是每次都为奇数
                    for (int h = j.hint | 1, k = 0, i; ; k += 2) {
                        if (k > m)                     // can't find stealer
                            break descent;
                        //从队列数组中取出一个
                        if ((v = ws[i = (h + k) & m]) != null) {
                            //如果这个队列正在处理subtask,说明任务被这个队列偷了。
                            if (v.currentSteal == subtask) {
                                //被偷取任务的就记住了这个小偷。然后跳出循环
                                j.hint = i;
                                break;
                            }
                            checkSum += v.base;
                        }
                    }
                    //v就是盗窃者
                    for (;;) {                         // help v or descend
                        ForkJoinTask<?>[] a; int b;
                        checkSum += (b = v.base);
                        ForkJoinTask<?> next = v.currentJoin;
                        //如果被偷的任务执行完了或者被偷任务的现在join的任务不是subsask
                        //或者偷窃者当前的偷窃任务不是subtask则退出循环
                        if (subtask.status < 0 || j.currentJoin != subtask ||
                            v.currentSteal != subtask) // stale
                            //退出循环
                            break descent;
                        //如果v的队列为空则将v队列任务正在join的任务
                        //设置为subTask,寻找v队列的任务的偷窃者
                        if (b - v.top >= 0 || (a = v.array) == null) {
                            if ((subtask = next) == null)
                                break descent;
                            j = v;
                            break;
                        }
                        //取出盗贼base偏移量的任务,
                        int i = (((a.length - 1) & b) << ASHIFT) + ABASE;
                        ForkJoinTask<?> t = ((ForkJoinTask<?>)
                                             U.getObjectVolatile(a, i));
                        if (v.base == b) {
                            if (t == null)             // stale
                                break descent;
                            //如果CAS操作base偏移量任务置为null成功,则
                            //将base+1,将w正在偷窃的任务修改为刚刚从base获取得到的任务
                            //然后执行w自己的任务
                            if (U.compareAndSwapObject(a, i, t, null)) {
                                v.base = b + 1;
                                ForkJoinTask<?> ps = w.currentSteal;
                                int top = w.top;
                                do {
                                    U.putOrderedObject(w, QCURRENTSTEAL, t);
                                    t.doExec();        // clear local tasks too
                                } while (task.status >= 0 &&
                                         w.top != top &&
                                         (t = w.pop()) != null);
                                //执行完自己的任务后就将当前偷窃的任务设置为先前的。
                                U.putOrderedObject(w, QCURRENTSTEAL, ps);
                                //如果自己队列还有任务则退出帮助
                                if (w.base != w.top)
                                    return;            // can't further help
                            }
                        }
                    }
                }
                //如果任务还么有执行,说明在join.且连续两次遍历整个ws是不同的。
            } while (task.status >= 0 && oldSum != (oldSum = checkSum));
        }
    }

介绍到这里ForkJoinPool的大概机制应该能了解清楚。ForkJoinPool的ManagedBlocker和补偿机制,ForkJoinTask对异常的记录没有做说明。ForkJoinPool设计比较复杂,想要完全弄清楚需要一定时间。

感谢阅读,希望对你又帮助。

参考资料:

jdk文档

java - ForkJoin框架之ForkJoinPool - 个人文章 - SegmentFault 思否

JUC源码分析-线程池篇(四):ForkJoinPool - 1 - 简书

JUC源码分析-线程池篇(五):ForkJoinPool - 2 - 简书

  • 3
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 8
    评论
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值