阿里的面试题带你认识ForkJoinPool

我相信大家都用过线程池,比如ExcutorService,比如ThreadPoolExcutor

今天来讲讲ForkJoinPool,它实现于ExcutorService,但又和我们常用的

ThreadPoolExcutor原理不同

前言

随着在硬件上多核处理器的发展和广泛使用,并发编程成为程序员必须掌握的一门技术,在面试中也经常考查面试者并发相关的知识。

今天,我们就从一道阿里的面试题来开始

题目:如何充分利用多核CPU,计算超大数组中所有整数的和?

解析开始

  • 1.单线程相加?

我们最容易想到就是单线程相加,一个for循环搞定。

  • 2.线程池相加?

如果进一步优化,我们会自然而然地想到使用线程池来分段相加,最后再把每个段的结果相加。

  • 3.其它?

Yes,就是我们今天的主角——ForkJoinPool,但是它要怎么实现呢?似乎没怎么用过哈^^

让我们看看上面是那种方法都如何实现

/**
 * 计算1亿个整数的和
 */
public class ForkJoinPoolTest01 {
    public static void main(String[] args) throws ExecutionException, InterruptedException {
        // 构造数据
        int length = 100000000;
        long[] arr = new long[length];
        for (int i = 0; i < length; i++) {
            arr[i] = ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE);
        }
        // 单线程
        singleThreadSum(arr);
        // ThreadPoolExecutor线程池
        multiThreadSum(arr);
        // ForkJoinPool线程池
        forkJoinSum(arr);

    }

    private static void singleThreadSum(long[] arr) {
        long start = System.currentTimeMillis();

        long sum = 0;
        for (int i = 0; i < arr.length; i++) {
            // 模拟耗时,本文由公从号“彤哥读源码”原创
            sum += (arr[i]/5*5/5*5/5*5/5*5/5*5);
        }

        System.out.println("sum: " + sum);
        System.out.println("single thread elapse: " + (System.currentTimeMillis() - start));

    }

    private static void multiThreadSum(long[] arr) throws ExecutionException, InterruptedException {
        long start = System.currentTimeMillis();

        int count = 8;
        ExecutorService threadPool = Executors.newFixedThreadPool(count);
        List<Future<Long>> list = new ArrayList<>();
        for (int i = 0; i < count; i++) {
            int num = i;
            // 分段提交任务
            Future<Long> future = threadPool.submit(() -> {
                long sum = 0;
                for (int j = arr.length / count * num; j < (arr.length / count * (num + 1)); j++) {
                    try {
                        // 模拟耗时
                        sum += (arr[j]/5*5/5*5/5*5/5*5/5*5);
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
                return sum;
            });
            list.add(future);
        }

        // 每个段结果相加
        long sum = 0;
        for (Future<Long> future : list) {
            sum += future.get();
        }

        System.out.println("sum: " + sum);
        System.out.println("multi thread elapse: " + (System.currentTimeMillis() - start));
    }

    private static void forkJoinSum(long[] arr) throws ExecutionException, InterruptedException {
        long start = System.currentTimeMillis();

        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        // 提交任务
        ForkJoinTask<Long> forkJoinTask = forkJoinPool.submit(new SumTask(arr, 0, arr.length));
        // 获取结果
        Long sum = forkJoinTask.get();

        forkJoinPool.shutdown();

        System.out.println("sum: " + sum);
        System.out.println("fork join elapse: " + (System.currentTimeMillis() - start));
    }

    private static class SumTask extends RecursiveTask<Long> {
        private long[] arr;
        private int from;
        private int to;

        public SumTask(long[] arr, int from, int to) {
            this.arr = arr;
            this.from = from;
            this.to = to;
        }

        @Override
        protected Long compute() {
            // 小于1000的时候直接相加,可灵活调整
            if (to - from <= 1000) {
                long sum = 0;
                for (int i = from; i < to; i++) {
                    // 模拟耗时
                    sum += (arr[i]/5*5/5*5/5*5/5*5/5*5);
                }
                return sum;
            }

            // 分成两段任务,本文由公从号“彤哥读源码”原创
            int middle = (from + to) / 2;
            SumTask left = new SumTask(arr, from, middle);
            SumTask right = new SumTask(arr, middle, to);

            // 提交左边的任务
            left.fork();
            // 右边的任务直接利用当前线程计算,节约开销
            Long rightResult = right.compute();
            // 等待左边计算完毕
            Long leftResult = left.join();
            // 返回结果
            return leftResult + rightResult;
        }
    }
}

~~Garnett偷偷地告诉你,实际上计算1亿个整数相加,单线程是最快的,我的电脑大概是100ms左右,使用线程池反而会变慢。~~

~~所以,为了演示ForkJoinPool的牛逼之处,我把每个数都/5*5/5*5/5*5/5*5/5*5了一顿操作,用来模拟计算耗时。~~

来看结果:

sum: 107352457433800662
single thread elapse: 789
sum: 107352457433800662
multi thread elapse: 228
sum: 107352457433800662
fork join elapse: 189

可以看到,ForkJoinPool相对普通线程池还是有很大提升的。

什么是ForkJoinPool?

谈到线程池,很多人会想到Executors提供的一些预设的线程池,比如单线程线程池SingleThreadExecutor,固定大小的线程池FixedThreadPool,但是很少有人会注意到其中还提供了一种特殊的线程池:WorkStealingPool,我们点进这个方法,会看到和其他方法不同的是,这种线程池并不是通过ThreadPoolExecutor来创建的,而是ForkJoinPool来创建的

public static ExecutorService newWorkStealingPool() {
        return new ForkJoinPool
            (Runtime.getRuntime().availableProcessors(),
             ForkJoinPool.defaultForkJoinWorkerThreadFactory,
             null, true);
    }

这两种线程池之间并不是继承关系,而是平级关系:

ThreadPoolExecutor应该都很了解了,就是一个基本的存储线程的线程池,需要执行任务的时候就从线程池中拿一个线程来执行。而ForkJoinPool则不仅仅是这么简单,同样也不是ThreadPoolExecutor的代替品,这种线程池是为了实现“分治法”这一思想而创建的,通过把大任务拆分成小任务,然后再把小任务的结果汇总起来就是最终的结果,和MapReduce的思想很类似

最核心的思想可以这样描述:

if(任务很小){
    直接计算得到结果
}else{
    分拆成N个子任务
    调用子任务的fork()进行计算
    调用子任务的join()合并计算结果
}
  • 1.fork()

fork()方法类似于线程的Thread.start()方法,但是它不是真的启动一个线程,而是将任务放入到工作队列中。

  • 2.join()

join()方法类似于线程的Thread.join()方法,但是它不是简单地阻塞线程,而是利用工作线程运行其它任务。当一个工作线程中调用了join()方法,它将处理其它任务,直到注意到目标子任务已经完成了。

ForkJoinPool内部原理-工作窃取

work-stealing(工作窃取)算法

ForkJoinPool 的另一个特性是它使用了work-stealing(工作窃取)算法

线程池内的所有工作线程都尝试找到并执行已经提交的任务,或者是被其他活动任务创建的子任务(如果不存在就阻塞等待)。这种特性使得 ForkJoinPool 在运行多个可以产生子任务的任务,或者是提交的许多小任务时效率更高。尤其是构建异步模型的 ForkJoinPool 时,对不需要合并(join)的事件类型任务也非常适用

在 ForkJoinPool 中,线程池中每个工作线程(ForkJoinWorkerThread)都对应一个任务队列(WorkQueue),工作线程优先处理来自自身队列的任务(LIFO或FIFO顺序,参数 mode 决定),然后以FIFO的顺序随机窃取其他队列中的任务。

ForkJoinPool中的任务

ForkJoinPool 中的任务分为两种:

一种是本地提交的任务(Submission task,如 execute、submit 提交的任务);

另外一种是 fork 出的子任务(Worker task)。

两种任务都会存放在 WorkQueue 数组中,但是这两种任务并不会混合在同一个队列里,ForkJoinPool 内部使用了一种随机哈希算法(有点类似 ConcurrentHashMap 的桶随机算法)将工作队列与对应的工作线程关联起来,Submission 任务存放在 WorkQueue 数组的偶数索引位置,Worker 任务存放在奇数索引位。

实质上,Submission 与 Worker 一样,只不过它被限制只能执行它们提交的本地任务,在后面的源码解析中,我们统一称之为“Worker”。


任务的分布情况如下图:

ForkJoinPool原理

初始化ForkJoinPool

ForkJoinPool pool = ForkJoinPool.commonPool()

public static ForkJoinPool commonPool() {
    // assert common != null : "static init error";
    return common;
}

获取ForkJoinPool很简单,直接调用commonPool()。注意,这个方法是jdk1.8才加的,也是推荐的方法,满足大部分场景。

static{
    //...
    common = java.security.AccessController.doPrivileged
            (new java.security.PrivilegedAction<ForkJoinPool>() {
                public ForkJoinPool run() { return makeCommonPool(); }});
    //...
}

private static ForkJoinPool makeCommonPool() {
    //...
    return new ForkJoinPool(parallelism, factory, handler, LIFO_QUEUE,"ForkJoinPool.commonPool-worker-");
}

common在static{}里创建,调用的是makeCommonPool(),最终调用ForkJoinPool的构造函数。

private ForkJoinPool(int parallelism,
                     ForkJoinWorkerThreadFactory factory,
                     UncaughtExceptionHandler handler,
                     int mode,
                     String workerNamePrefix) {
    this.workerNamePrefix = workerNamePrefix;
    this.factory = factory;
    this.ueh = handler;
    this.config = (parallelism & SMASK) | mode;
    long np = (long)(-parallelism); // offset ctl counts
    this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
}

parallelism默认是cpu核心数,ForkJoinPool里线程数量依据于它,但不表示最大线程数,不要等同于ThreadPoolExecutor里的corePoolSize或者maximumPoolSize。

factory是线程工程,不是新东西了,默认实现是

DefaultForkJoinWorkerThreadFactory。

workerNamePrefix是其中线程名称的前缀,默认使用“ForkJoinPool-*”

config保存不变的参数,包括了parallelism和mode,供后续读取。mode可选FIFO_QUEUELIFO_QUEUE,默认是LIFO_QUEUE,具体用哪种,就要看业务。

ctl是ForkJoinPool中最重要的控制字段,将下面信息按16bit为一组封装在一个long中。

  • AC: 活动的worker数量;

  • TC: 总共的worker数量;

  • SS: WorkQueue状态,第一位表示active的还是inactive,其余十五位表示版本号(对付ABA);

  • ID:  这里保存了一个WorkQueue在WorkQueue[]的下标,和其他worker通过字段stackPred组成一个TreiberStack。后文讲的栈顶,指这里下标所在的WorkQueue。

TreiberStack:这个栈的pull和pop使用了CAS,所以支持并发下的无锁操作。

AC和TC初始化时取的是parallelism负数,后续代码可以直接判断正负,为负代表还没有达到目标数量。另外ctl低32位有个技巧可以直接用sp=(int)ctl取得,为负代表存在空闲worker。

线程池缺不了状态的变化,记录字段是runState,具体介绍在后面的“ForkJoinPool状态修改”。

任务ForkJoinTask

ForkJoinPool执行任务的对象是ForkJoinTask,它是一个抽象类,有两个具体实现类RecursiveAction和RecursiveTask。

public abstract class RecursiveAction extends ForkJoinTask<Void> {
    protected abstract void compute();

    public final Void getRawResult() { return null; }

    protected final void setRawResult(Void mustBeNull) { }

    protected final boolean exec() {
        compute();
        return true;
    }
}

public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
    V result;

    protected abstract V compute();

    public final V getRawResult() {
        return result;
    }

    protected final void setRawResult(V value) {
        result = value;
    }

    protected final boolean exec() {
        result = compute();
        return true;
    }
}

ForkJoinTask的抽象方法exec由RecursiveAction和RecursiveTask实现,它被定义为final,具体的执行步骤compute延迟到子类实现。很容易看出RecursiveAction和RecursiveTask的区别,前者没有result,getRawResult返回空,它们对应不需要返回结果和需要返回结果两种场景。

ForkJoinTask里很重要的字段是它的状态status,默认是0,当得出结果时变更为负数,有三种结果:

  • NORMAL

  • CANCELLED

  • EXCEPTIONAL

除此之外,在得出结果之前,任务状态能够被设置为SIGNAL,表示有线程等待这个任务的结果,执行完成后需要notify通知,具体看后文的join。

ForkJoinTask在触发执行后,并不支持其他什么特别操作,只能等待任务执行完成。CountedCompleter是ForkJoinTask的子类,它在子任务协作方面扩展了更多操作。我们聚焦ForkJoinPool主线流程,CountedCompleter相关内容另文再介绍。

WorkQueue

WorkQueue是一个双端队列,它定义在ForkJoinPool类里。

scanState描述WorkQueue当前状态:

  • 偶数表示RUNNING

  • 奇数表示SCANNING

  • 负数表示inactive

stackPred是WorkQueue组成TreiberStack时,保存前者的字段。

ForkJoinPool状态修改

  • STARTED

  • STOP

  • TERMINATED

  • SHUTDOWN

  • RSLOCK‍‍‍‍

  • RSIGNAL

runState记录了ForkJoinPool的运行状态,除了SHUTDOWN是负数,其他都是正数。前面四种不用说了,线程池标准状态流转。在多线程环境修改runState,不能简单想改就改,需要先获取锁,RSLOCK和RSIGNAL就用在这里。

private int lockRunState() {
    int rs;
    return ((((rs = runState) & RSLOCK) != 0 ||
             !U.compareAndSwapInt(this, RUNSTATE, rs, rs |= RSLOCK)) ?
            awaitRunStateLock() : rs);
}

修改前调用lockRunState锁定,检查当前状态,尝试一次使用CAS修改runState为RSLOCK。需要状态变化的机会很少,大多数时间一次就能成功,但不能排除少几率的竞争,这时候进入awaitRunStateLock。

private int awaitRunStateLock() {
    Object lock;
    boolean wasInterrupted = false;
    for (int spins = SPINS, r = 0, rs, ns;;) {
        //1
        if (((rs = runState) & RSLOCK) == 0) {
            if (U.compareAndSwapInt(this, RUNSTATE, rs, ns = rs | RSLOCK)) {
                if (wasInterrupted) {
                    try {
                        Thread.currentThread().interrupt();
                    } catch (SecurityException ignore) {
                    }
                }
                return ns;
            }
        }
        else if (r == 0)
            r = ThreadLocalRandom.nextSecondarySeed();
        else if (spins > 0) {
            r ^= r << 6; r ^= r >>> 21; r ^= r << 7; // xorshift
            if (r >= 0)
                --spins;
        }
        //2
        else if ((rs & STARTED) == 0 || (lock = stealCounter) == null)
            Thread.yield();   // initialization race
        //3
        else if (U.compareAndSwapInt(this, RUNSTATE, rs, rs | RSIGNAL)) {
            synchronized (lock) {
                if ((runState & RSIGNAL) != 0) {
                    try {
                        lock.wait();
                    } catch (InterruptedException ie) {
                        if (!(Thread.currentThread() instanceof
                              ForkJoinWorkerThread))
                            wasInterrupted = true;
                    }
                }
                else
                    lock.notifyAll();
            }
        }
    }
}

在自旋中,第一步,mark1再次尝试修改runState为RSLOCK,成功直接返回。

mark2检查ForkJoinPool初始化情况,这里没有额外多写个变量做锁,直接利用了stealCounter这个原子变量。因为初始化时(后文的externalSubmit),才会对stealCounter赋值。所以当状态不是STARTED或者stealCounter为空时,让出线程等待。

mark3处,线程不会无限制自旋尝试,会利用wait/notify进入阻塞等待。RSIGNAL代替原状态,表示有线程进入了等待,解锁时要处理。在高并发下,这不是一个好的设计,但进入这里的几率很低,作为兜底还是可以的。

private void unlockRunState(int oldRunState, int newRunState) {
    if (!U.compareAndSwapInt(this, RUNSTATE, oldRunState, newRunState)) {
        Object lock = stealCounter;
        runState = newRunState;              // clears RSIGNAL bit
        if (lock != null)
            synchronized (lock) { lock.notifyAll(); }
    }
}

解锁的逻辑就比较简单,如果顺利将状态修改为目标状态,成功大吉。否则表示有别的线程进入了wait,需要调用notifyAll唤醒,重新尝试竞争。

ForkJoinPool最佳实践

(1)最适合的是计算密集型任务

(2)在需要阻塞工作线程时,可以使用ManagedBlocker;

(3)不应该在RecursiveTask的内部使用ForkJoinPool.invoke()/invokeAll();

总结

(1)ForkJoinPool特别适合于“分而治之”算法的实现;

(2)ForkJoinPool和ThreadPoolExecutor是互补的,不是谁替代谁的关系,二者适用的场景不同;

(3)ForkJoinTask有两个核心方法——fork()和join(),有三个重要子类——RecursiveAction、RecursiveTask和CountedCompleter;

(4)ForkjoinPool内部基于“工作窃取”算法实现;

(5)每个线程有自己的工作队列,它是一个双端队列,自己从队列头存取任务,其它线程从尾部窃取任务;

(6)ForkJoinPool最适合于计算密集型任务,但也可以使用ManagedBlocker以便用于阻塞型任务;

(7)RecursiveTask内部可以少调用一次fork(),利用当前线程处理,这是一种技巧;

Garnett还会不断的分享技术干货的,希望你们是我最好的观众!

乐于输出干货的Java技术公众号:Garnett的Java之路。公众号内有大量的技术文章、海量视频资源、精美脑图,不妨来关注一下!回复【资料】领取大量学习资源和免费书籍!

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值