Java并发系列(12)——ForkJoin框架源码解析

本文详细介绍了Java并发编程中的ForkJoinPool框架,从基本概念到源码分析,包括ForkJoinPool的工作原理、任务提交、任务执行、线程状态、位运算细节、任务拆分与执行、线程创建与销毁等关键点,旨在帮助读者深入理解ForkJoinPool如何利用分治算法高效处理并发任务。
摘要由CSDN通过智能技术生成

接上一篇《Java并发系列(11)——ThreadPoolExecutor实现原理与手写

9.4 ForkJoinPool

ForkJoin 是 JDK 特地为分治算法实现的一个框架。

9.4.1 demo

先看一个 demo,纯数值计算,从 1 累加到 1000_0000_0000:

package per.lvjc.concurrent.pool.efficiency;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class ForkJoinPoolTest {
   

    private static long n = 1000_0000_0000L;

    private static ForkJoinPool forkJoinPool = new ForkJoinPool();

    public static void main(String[] args) throws ExecutionException, InterruptedException {
   
        long start = System.currentTimeMillis();
        long sum = forkJoinPool.submit(new SumTask(1, n)).get();
        long end = System.currentTimeMillis();
        System.out.println("fork join sum:" + sum + ", cost:" + (end - start));
    }

    private static class SumTask extends RecursiveTask<Long> {
   

        private static final long THRESHOLD = 10000;

        private long begin;
        private long end;

        public SumTask(long begin, long end) {
   
            this.begin = begin;
            this.end = end;
        }

        @Override
        protected Long compute() {
   
            long sum = 0;
            if (end - begin <= THRESHOLD) {
   
                for (long i = begin; i <= end; i++) {
   
                    sum += i;
                }
                return sum;
            }
            long mid = (end + begin) / 2;
            SumTask left = new SumTask(begin, mid);
            SumTask right = new SumTask(mid + 1, end);
            invokeAll(left, right);
            long leftSum = left.join();
            long rightSum = right.join();
            return leftSum + rightSum;
        }
    }
}

基于 ForkJoin 框架,使用分治算法就很简单,只需要做两件事:

  • 定义一个 ForkJoinTask(一般通过继承 RecursiveTask 或 RecursiveAction),实现任务的拆分;
  • 向 ForkJoinPool 提交一个 ForkJoinTask 实例。
9.4.2 初步认识

现在来探讨 ForkJoin 框架背后是怎么工作的。在此之前,先对其有个初步的认识。

ForkJoin 框架主要包含四个角色:

  • ForkJoinPool:ForkJoin 专用线程池;
  • ForkJoinTask:ForkJoinPool 唯一可以接受的任务类型,RecursiveTask 和 RecursiveAction 都是 ForkJoinTask 的子类,如果直接提交 Callable 或 Runnable 也会被自动包装成 ForkJoinTask;
  • WorkQueue:ForkJoin 专用工作队列;
  • ForkJoinWorkerThread:ForkJoin 专用工作线程。

这四个角色在运行时的关系如下图所示:

在这里插入图片描述

几个要点:

  • ForkJoinPool 里面有一个 WorkQueue 数组;
  • 每个 WorkQueue 里面都有一个 ForkJoinTask 数组,存放提交的任务;
  • WorkQueue 分为两类:一类有一个 ForkJoinWorkerThread,另一类没有,提交在这一类 WorkQueue 里的任务只能被其它 WorkQueue 窃取过去执行。

ForkJoin 有两大核心思想:

  • 分治算法;
  • 工作密取:为了充分利用 cpu 资源,一个工作线程执行完自己队列的任务之后,不会空闲,而是从其它队列里寻找任务。
9.4.3 工作线程状态

在 ForkJoinPool 中,工作线程有以下几种状态:

  • running,正在执行任务,并且没有因为任何原因阻塞,getRunningThreadCount 方法可以获得;
  • active,活跃状态,getActiveThreadCount 方法可以获得,细分为两种:
    • busy,正在处理自己队列的任务;
    • scan,正在扫描其它队列的任务;
  • inactive,自己队列的任务处理完了,扫描其它队列任务,仍然没有找到新任务,最终线程挂起。
9.4.4 位运算

ForkJoinPool 中大量使用了位运算,弄清楚这些位运算是读懂 ForkJoin 源码的必要条件。

9.4.4.1 成员变量

主要涉及到以下几个成员变量,int 或 long 变量的各 bit 分别什么作用如下图所示:

在这里插入图片描述

简单解释一下。

ForkJoinPool.config 变量:

  • 高 16 位存储 mode 信息,包括:
    • SHARED_QUEUE:共享队列,没有自己的工作线程,只能被其它线程窃取任务;
    • FIFO_QUEUE:先进先出队列;
    • LIFO_QUEUE:后进先出队列;
  • 低 16 位存储并发度,parallelism size,也可以理解为类似于 ThreadPoolExecutor 的 core pool size,但创建额外线程的逻辑不太一样。

ForkJoinPool.runState 变量:

  • 第 1 位的 1 表示 RSLOCK 状态,runState 变量被锁,其它线程暂时不可以修改;
  • 第 2 位的 1 表示 RSIGNAL 状态,线程阻塞等其它线程释放 RSLOCK;
  • 其它状态同 ThreadPoolExecutor。

WorkQueue.scanState 变量:

  • 最高位(符号位)的 1 表示 inactive,所以负数肯定是 inactive 状态;
  • 接下来 15 位无符号数表示版本号,线程每次从 inactive 状态被重新激活都会加 1;
  • 低 16 位存储了当前 WorkQueue 在 ForkJoinPool 的 WorkQueue[] 中的索引值;
  • 最低位的 1 同时还表示 SCANNING 状态,尽管如此,SCANNING 状态在 1 和 0 之间的变更并不会影响对索引值的获取产生影响,因为有自己工作线程的队列索引值一定是奇数,反之为偶数,所以在获取索引值的时候,其实不需要最低位。

ForkJoinPool.ctl 变量:

  • 最高 16 位存储 active 线程数信息;
  • 后面 16 位存储当前总线程数信息;
  • 低 32 位存储最后一个 inactive 的队列的 scanState,而队列在 inactive 之前都会把上一个 inactive 的队列的 scanState 存储到自己的 stackPred 成员变量,所以所有的 inactive 队列会形成一条链,当然,因为最后 inactive 的在最上面,所以实际上是一个栈。
9.4.4.2 方法

下面看几个方法,深化对以上这些成员变量的理解。

构造方法:

    private ForkJoinPool(int parallelism,
                         ForkJoinWorkerThreadFactory factory,
                         UncaughtExceptionHandler handler,
                         int mode,
                         String workerNamePrefix) {
   
        this.workerNamePrefix = workerNamePrefix;
        this.factory = factory;
        this.ueh = handler;
        // SMASK = 0xffff = 0000_0000_0000_0000_1111_1111_1111_1111
        // 所以并发度只占用了低 16 位,但也足够了,一般不可能有这么多线程
        this.config = (parallelism & SMASK) | mode;
        // 以并发度 15 为例,-15 = [32个1][16个1][1111_1111_1111_0001]
        long np = (long)(-parallelism); // offset ctl counts
        // AC_SHIFT = 48
        // AC_MASK = 0xffffL << 48 = [16个1][48个0]
        // 所以前半部分 = [1111_1111_1111_0001][48个0] & [16个1][48个0]
        //             = [1111_1111_1111_0001][48个0]
        // TC_SHIFT = 32
        // TC_MASK = 0xffffL << 32 = [16个0][16个1][32个0]
        // 所以后半部分 = [16个1][1111_1111_1111_0001][32个0] & [16个0][16个1][32个0]
        //             = [16个0][1111_1111_1111_0001][32个0]
        // 最后前半部分 | 后半部分 = [1111_1111_1111_0001][16个0][32个0]
        //                       | [16个0][1111_1111_1111_0001][32个0]
        //                       = [1111_1111_1111_0001][1111_1111_1111_0001][32个0]
        this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
    }

getPoolSize 方法:

    public int getPoolSize() {
   
        // config 低 16 位存储 parallelism size,
        // SMASK = 0xffff = 0000_0000_0000_0000_1111_1111_1111_1111
        // 所以前半部分 = parallelism size(不考虑超过 2^16-1)
        // TC_SHIFT = 32
        // 前面说过 ctl = [active线程数-parallelism size(16位)][当前总线程数-parallelism size(16位)][32位]
        // 所以后半部分 = ctl 右移 32 位再强转 short 去掉高 16 位
        //             = 当前总线程数-parallelism size
        // 最后返回的是:parallelism size + 当前总线程数-parallelism size = 当前总线程数
        return (config & SMASK) + (short)(ctl >>> TC_SHIFT);
    }

getActiveThreadCount 方法:

    public int getActiveThreadCount() {
   
        // AC_SHIFT = 48
        // 同上计算可知
        // r = parallelism size + 活跃线程数 - parallelism size = 活跃线程数
        int r = (config & SMASK) + (int)(ctl >> AC_SHIFT);
        // 因为这里没有对 ctl 加锁,所以如果读取脏数据可能导致负值
        return (r <= 0) ? 0 : r; // suppress momentarily negative values
    }

scan 方法,这个方法暂时只看涉及位运算的部分:

    private ForkJoinTask<?> scan(WorkQueue w, int r) {
   
         //...
                if ((k = (k + 1) & m) == origin) {
       // continue until stable
                    //ss = scanState
                    if ((ss >= 0 || (ss == (ss = w.scanState))) &&
                        oldSum == (oldSum = checkSum)) {
   
                        if (ss < 0 || w.qlock < 0)    // already inactive
                            break;
                        //这里的场景是:
                        //    当前队列已经没有任务了,于是线程去窃取其它队列的任务,
                        //    但是其它队列也没有任务,于是要把当前线程变成 inactive 状态
                        //    再把 active 线程数 - 1
                        // INACTIVE = 1 << 31 = 1[31个0]
                        // 位或得到新的 scanState = 原 scanState 最高位置为 1
                        int ns = ss | INACTIVE;       // try to inactivate
                        //这里是要计算新的 ctl
                        //SP_MASK = 0xffffffffL = [32个0][32个1]
                        //前半部分 = [32个0][32位scanState]
                        //UC_MASK = ~SP_MASK(位反运算) = [32个1][32个0]
                        //AC_UNIT = 0x0001L << 48 = [0000_0000_0000_0001][48个0]
                        //后半部分 = [32个1][32个0]
                        //        & [ctl高16位-1][ctl低48位]
                        //        = [ctl高16位-1][ctl中高16位][32个0]
                        //最后 nc = [32个0][32位scanState]
                        //        | [ctl高16位-1][ctl中高16位][32个0]
                        //        = [ctl高16位-1][ctl中高16位][32位scanState]
                        //这就是 [active线程数-1][总线程数不变][低32位换成当前队列的scanState]
                        long nc = ((SP_MASK & ns) |
                                   (UC_MASK & ((c = ctl) - AC_UNIT)));
                        w.stackPred = (int)c;         // hold prev stack top
                        U.putInt(w, QSCANSTATE, ns);
                        if (U.compareAndSwapLong(this, CTL, c, nc))
                            ss = ns;
                        else
                            w.scanState = ss;         // back out
                    }
                    checkSum = 0;
                }
        //...
    }

signalWork 方法:

final void signalWork(WorkQueue[] ws, WorkQueue q) {
   
        long c; int sp, i; WorkQueue v; Thread p;
        //为什么 ctl < 0 就表示 active 线程不够?
        //因为 ctl < 0,即 ctl 最高 16 位 < 0,
        //而 ctl 最高 16 位 = active 线程数 - parallelism size,
        //所以很明显 active 线程数 < parallelism size
        while ((c = ctl) < 0L) {
                          // too few active
            //为什么 ctl 低 32 位等于 0 就表示没有空闲线程?
            //因为 ctl 低 32 位存储的是上一个 inactive 的队列的 scanState
            //而 scanState 不可能是 0,
            //所以很明显不存在 inactive 线程
            if ((sp = (int)c) == 0) {
                     // no idle workers
                //ADD_WORKER = 1 << 47,即第 48 位上是 1,其它全 0,
                //(c & ADD_WORKER) != 0L 说明 ctl 第 48 位也是 1,
                //也就是说,第 33 ~ 48 这 16 位也是负数,
                //而这 16 位 = 当前总线程数 - parallelism size,
                //所以很明显当前总线程数还不够
                if ((c & ADD_WORKER) != 0L)            // too few workers
                    tryAddWorker(c);
                break;
            }
            if (ws == null)                            // unstarted/terminated
                break;
            //SMASK = 0xffff,sp & SMARK,即取 sp 低 16 位,
            //是 WorkQueue 在线程池中 WorkQueue[] 的索引值
            if (ws.length <= (i = sp & SMASK))         // terminated
                break;
            if ((v = ws[i]) == null)                   // terminating
                break;
            //这里的场景是:
            //    重新 active 一个 inactive 线程。
            //SS_SEQ = 1 << 16,
            //sp + SS_SEQ 也就是在 sp 的第 17 位上 +1,即版本号 +1,
            //INACTIVE = 1 << 31,那么 ~INACTIVE = 0[31个1]
            //最终 vs = 0[15位版本号][sp低16位]
            int vs = (sp + SS_SEQ) & ~INACTIVE;        // next scanState
            int d = sp - v.scanState;                  // screen CAS
            // (UC_MASK & (c + AC_UNIT)),即最高 16 位所表示的 active 线程数 +1,
            //前面说过,WorkQueue.stackPred 变量存储的是前一个 inactive 的队列的 scanState,
            //而 ctl 里面低 32 位存储的是最近 inactive 的队列的 scanState,
            //最近 inactive 的队列已经被 active 了,当然要把前一个 inactive 的队列存到 ctl
            long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
            if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
   
                v.scanState = vs;                      // activate v
                if ((p = v.parker) != null)
                    U.unpark(p);
                break;
            }
            if (q != null && q.base == q.top)          // no more work
                break;
        }
    }

以上,就是 ForkJoinPool 的源码里面主要流程中涉及到的一些位运算。把这些位运算弄清楚之后,再去看 ForkJoinPool 就容易很多了。

9.4.5 WorkQueue

WorkQueue 里面的成员变量有很多,这里我们只关注其中一部分。

scanState 前面已经讲过,32 位的 int 变量,记录了四个信息:

  • 线程的 inactive 状态;
  • 版本号;
  • 队列索引值;
  • scanning 状态。

ForkJoinTask 数组的存取如下图所示:

在这里插入图片描述

  • array 初始容量 8192;
  • 第一个任务放在 4096,似乎是因为操作系统内存的原因;
  • 8191 的位置放入任务之后,还是会回到 0 的位置;
  • 初始 base = top = 4096;
  • 从上面放入一个任务 top + 1,不会从下面放入任务;
  • LIFO 模式自己线程从上面取走任务 top - 1;
  • FIFO 模式自己线程从下面取走任务 base + 1;
  • 被其它线程从下面窃取任务,base + 1,其它线程不会从上面窃取任务;
  • 数组 size 由 top - base 获得;
  • 从 8191 回到 0 之后,top 和 base 会继续往上加,索引值通过取余获得。
9.4.6 外部提交任务

ForkJoinPool 中的任务有两个来源:

  • 外部提交的大任务;
  • 内部拆分的小任务。

当我们调用 ForkJoinPool 的 submit 方法向线程池中提交一个任务时,发生了什么?这个任务会立即分配一个空闲线程来执行还是会入队?

跟踪 submit 方法:

    public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
   
        if (task == null)
            throw new NullPointerException();
        externalPush(task);
        return task;
    }

这里没什么,主要是 externalPush:

   final void externalPush(ForkJoinTask<?> task) {
   
        WorkQueue[] ws; WorkQueue q; int m;
        int r = ThreadLocalRandom.getProbe();
        int rs = runState;
       //符合条件的直接往数组里面放
        if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
            (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
            U.compareAndSwapInt(q, QLOCK, 0, 1)) {
   
            ForkJoinTask<?>[] a; int am, n, s;
            if ((a = q.array) != null &&
                (am = a.length - 1) > (n = (s = q.top) - q.base)) {
   
                int j = ((am & s) << ASHIFT) + ABASE;
                U.putOrderedObject(a, j, task);
                U.putOrderedInt(q, QTOP
  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值