Java ForkJoinPool初探

目录

Fork/Join介绍

ForkJoinPool

common pool

ForkJoinTask

CountedCompleter

ManagedBlocker

工作窃取算法


Fork/Join介绍

Fork/Join框架是一种分而治之思想的多线程版实现,可以理解为单机下的map/reduce任务,将一个完整的任务taskA,按照某种条件切割成N个子任务,通过fork将任务分布到多个不同的Thread运行,并通过join来获取taskA的完整运行结果。

Java中的Fork/Join框架实现,是通过ForkJoinPool和ForkJoinTask实现的,ForkJoinPool是一个动态线程池,ForkJoinTask<T>代表一个task,T表示task的返回值类型,ForkJoinTask的实现类有三个:

  • java.util.concurrent.RecursiveTask<T>:对task的返回值敏感的话,可以使用这个实现类
  • java.util.concurrent.RecursiveAction:不需要获取task返回值的场景,可以使用这个实现类
  • java.util.concurrent.CountedCompleter<T>:这个是java8新增的ForkJoinTask的实现类,这个类相较于前两者,最大的特点是当stage complete时,可以触发相应的任务,例如taskA被切割成taskB、taskC,而taskB又被切割成taskB1、taskB2,那么当taskB1执行完,可以触发onCompletionTaskB1,当taskB1和taskB2都执行完时,可以触发onCompletionTaskB。

ForkJoinPool

ForkJoinPool是专门为Fork/Join框架设计的线程池,它与Executor不同,可以动态增减线程池中线程的数量。

ForkJoinPool内部有一个内部类java.util.concurrent.ForkJoinPool.WorkQueue,这个数据结构与work-stealing algorithm(工作窃取算法)是整个java fork/join框架的精髓所在。

在编码过程中,我们可以手动去创建一个ForkJoinPool来运行ForkJoinTask,但java自身也提供了一个静态的commonPool实现,当运行ForkJoinTask但不指定ForkJoinPool时,默认使用全局的commonPool来运行。

public ForkJoinPool(int parallelism,
                        ForkJoinWorkerThreadFactory factory,
                        UncaughtExceptionHandler handler,
                        boolean asyncMode) {
        this(checkParallelism(parallelism),
             checkFactory(factory),
             handler,
             asyncMode ? FIFO_QUEUE : LIFO_QUEUE,
             "ForkJoinPool-" + nextPoolId() + "-worker-");
        checkPermission();
    }
private static int checkParallelism(int parallelism) {
        if (parallelism <= 0 || parallelism > MAX_CAP) //max_cap=0x7fff
            throw new IllegalArgumentException();
        return parallelism;
    }

private ForkJoinPool(int parallelism,
                         ForkJoinWorkerThreadFactory factory,
                         UncaughtExceptionHandler handler,
                         int mode,
                         String workerNamePrefix) {
        this.workerNamePrefix = workerNamePrefix;
        this.factory = factory;
        this.ueh = handler;
        this.config = (parallelism & SMASK) | mode;
        long np = (long)(-parallelism); // offset ctl counts
        this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
    }


public static ForkJoinPool commonPool() {
        // assert common != null : "static init error";
        return common;
    }

ForkJoinPool构造函数的parameter解释:

  • parallelism:并行度,也就是pool中线程的数量,默认等于运行服务器的core数量,注意:并行度有上限,最大值不能超过2^15-1(32767),否则抛出异常
  • factory:用于创建fork/join thread的工厂类,需要注意:不同于Executor的threadFactory,这个factory返回的是java.util.concurrent.ForkJoinWorkerThread类,而不是java.lang.Thread
  • handler:thread的未捕获异常处理器的具体实现,默认是null
  • asyncMode:默认值是false,用于控制WorkQueue的数据结构类型,默认情况下是栈结构,当设置为true时,使用队列结构。

对于WorkQueue的栈结构和队列结构,各自的使用场景:

  1. 栈结构一般用于分而治之思想的任务处理,这种处理方式能够优先运行更细粒度的task,而不是等到所有粗粒度的task运行结束,才执行细粒度task,可以减少空间消耗。
  2. 队列结构一般用于处理event-style的任务,但这种模式下,使用Executor也可以实现,因此一般也不会主动去设置这个参数。

当ForkJoinPool shutdown后或者没有足够的资源运行新提交的task时,再提交task就会抛出RejectedExecutionException 异常。

common pool

common pool与new ForkJoinPool()不同的地方在于:

  1. common pool在调用shutdown或者shutdownNow时,不会修改pool state,也就是说,除非调用System.exit(0),否则common pool会一直接收ForkJoinTask并运行
  2. common pool会占用main线程作为其中一个thread worker

common pool可以通过System.setProperty来控制参数:

  1. java.util.concurrent.ForkJoinPool.common.parallelism:>=0,如果设置为0,表示禁用common pool,但这样做可能导致未执行任务永远不会执行
  2. java.util.concurrent.ForkJoinPool.common.threadFactory:ForkJoinPool.ForkJoinWorkerThreadFactory 的实现类名称,如果设置null,会禁用common pool
  3. java.util.concurrent.ForkJoinPool.common.exceptionHandler:Thread.UncaughtExceptionHandler的类名

 

ForkJoinTask

ForkJoinTask中最重要的方法就是exec()方法,ForkJoinTask在被调度运行时,触发的是doExec()方法

final int doExec() {
        int s; boolean completed;
        if ((s = status) >= 0) {
            try {
                completed = exec();
            } catch (Throwable rex) {
                return setExceptionalCompletion(rex);
            }
            if (completed)
                s = setCompletion(NORMAL);
        }
        return s;
    }

exec()方法负责运行计算任务,同时也负责异常处理和整个task的运行状态,这个方法的返回值是boolean类型,当为true时,task才会被设置为NORMAL结束标志,但需要注意:ForkJoinTask的实现类CountedCompleter的exec方法返回值一直是false,因此对于CountedCompleter来说,整个task的运行状态必须依赖tryComplete()方法来控制。

我们在使用ForkJoinTask的时候,一般都是继承ForkJoinTask的abstract subclass,也就是java.util.concurrent.RecursiveTask<T>、java.util.concurrent.RecursiveAction和java.util.concurrent.CountedCompleter<T>,并重写compute()方法,下面给出一个用于计算\sum_{i=1}^{400}i的累加和demo:

package com.forkjoin;

import com.sun.istack.internal.NotNull;

import java.util.concurrent.*;

/**
 * 研究forkjoinpool的内部源码用
 */
public class ForkJoinPoolAnalyze1 extends RecursiveTask<Integer> {
    private int[] intvals;

    public ForkJoinPoolAnalyze1(@NotNull int[] intvals) {
        this.intvals = intvals;
    }

    public static void main(String[] args) throws NoSuchFieldException {
        int[] intvals = new int[400];
        for (int i = 0; i < 400; i++) {
            intvals[i] = i;
        }
        ForkJoinPoolAnalyze1 forkJoinPoolAnalyze1 = new ForkJoinPoolAnalyze1(intvals);
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        long start = System.currentTimeMillis();
        try {
            ForkJoinTask<Integer> future = forkJoinPool.submit(forkJoinPoolAnalyze1);
            int result = future.join();
            System.out.println("计算结果:" + result + ",总共耗时:" + (System.currentTimeMillis() - start) + " ms");
        } catch (Exception e) {
            System.out.println("出错误啦!!!");
            e.printStackTrace();
        }

    }

    @Override
    protected Integer compute() {
        if (intvals.length <= 5) {
            int result = 0;
            for (int i : intvals) {
//                int x = 1/i; //1/0的问题
                try {
                    TimeUnit.MILLISECONDS.sleep(10);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                result += i;
            }
            System.out.println("ThreadName:" + Thread.currentThread().getName() + ",计算结果:" + result);
            return result;
        } else {
            int size = intvals.length / 5;
            int modval = intvals.length % 5;
            if (modval > 0)
                size += 1;
            ForkJoinPoolAnalyze1[] forkJoinPoolAnalyze1s = new ForkJoinPoolAnalyze1[size];
            for (int i = 0; i < size; i++) {
                int arrsize = 5;
                if (modval > 0 && i == size - 1) {
                    arrsize = modval;
                }
                int[] newIntArrays = new int[arrsize];
                System.arraycopy(this.intvals, i * 5, newIntArrays, 0, arrsize);
                ForkJoinPoolAnalyze1 forkJoinPoolAnalyze1 = new ForkJoinPoolAnalyze1(newIntArrays);
                forkJoinPoolAnalyze1s[i] = forkJoinPoolAnalyze1;
            }
            invokeAll(forkJoinPoolAnalyze1s);
            int result = 0;
            for (ForkJoinPoolAnalyze1 forkJoinPoolAnalyze1 : forkJoinPoolAnalyze1s) {
                result += forkJoinPoolAnalyze1.getRawResult();
            }
            return result;
        }
    }
}

CountedCompleter

这个实现类的最大特点是需要维护一个pending counter,用于记录每个task的当前阻塞情况,CountedCompleter类似于一个树形结构,Main Thread提交的CountedCompleter属于root,特点就是CountedCompleter.completer=null;而由其fork出来的task,comleter=root,以此类推。

protected CountedCompleter(CountedCompleter<?> completer,
                               int initialPendingCount) {
        this.completer = completer; //直接父节点
        this.pending = initialPendingCount;
    }

CountedCompleter中维护一个pendingCount,用于记录当前task的pending数量,只有当this.pendingCount归0后,才认为当前的task运行结束,才有机会触发parent node task的完成,例如root中存在两个fork task,root.pendingCount=2,只有当fork task都执行结束时,才会触发root task的的完成,并将完整的task设置成完成状态,tryComplete()方法的源码如下:

public final void tryComplete() {
        CountedCompleter<?> a = this, s = a;
        for (int c;;) {
            if ((c = a.pending) == 0) {
                a.onCompletion(s);
                if ((a = (s = a).completer) == null) {
                    s.quietlyComplete();
                    return;
                }
            }
            else if (U.compareAndSwapInt(a, PENDING, c, c - 1))
                return;
        }
    }

通过源码可以看到,只有root task执行结束,才认为整个forkjointask执行结束,而root task要执行结束,首先要求fork task要将pendingCount设置为0,并显示的触发一次tryComplete()方法。

注意:pendingCount可以手动设置,因此并不一定要与fork task的数量相等,而要完成一个task,也并非必须通过tryComplete方法,也可以通过complete方法直接完成task

/**
* 可以看到,任意的一个task都可以执行complete来直接完成整个countedcompleter task
*/
public void complete(T rawResult) {
        CountedCompleter<?> p;
        setRawResult(rawResult);
        onCompletion(this);
        quietlyComplete();
        if ((p = completer) != null)
            p.tryComplete();
    }

使用CountedCompleter计算\sum_{i=0}^{12}i的累加和demo:

package com.forkjoin;

import java.util.Random;
import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * @Author: wanglonglong
 * @Description: 使用CountedCompleter计算0~12的累加和
 */
public class CountedCompleterTest1 extends CountedCompleter<Integer> {
    public static AtomicInteger sum_val = new AtomicInteger(0);
    private final int[] intvals;
    private int totalNum = 0;
    private String threadName;

    public CountedCompleterTest1(int[] intvals,CountedCompleterTest1 parentCountedCompleter) {
        super(parentCountedCompleter);
        this.intvals = intvals;
    }

    @Override
    public Integer getRawResult() {
        return this.totalNum;
    }

    @Override
    public void onCompletion(CountedCompleter<?> caller) {
        CountedCompleterTest1 countedCompleterTest1 = (CountedCompleterTest1) caller;
        System.out.println("line[32]---"+this.threadName+".onCompletion("+countedCompleterTest1.threadName+")");
    }

    public static void main(String[] args) throws ExecutionException, InterruptedException {
        int[] intvals = new int[13];
        for (int i = 0; i < 13; i++) {
            intvals[i] = i;
        }
        CountedCompleter<Integer> countedCompleter = new CountedCompleterTest1(intvals,null);
        countedCompleter.invoke();
        System.out.println("line[42]---总和="+CountedCompleterTest1.sum_val.get());
    }

    @Override
    public void compute() {
        this.threadName = Thread.currentThread().getName();
        String caculateStr = "";
        if (intvals.length <= 5) {
            for (int i : intvals) {
                caculateStr += i+"+";
                try {
                    TimeUnit.MILLISECONDS.sleep(new Random().nextInt(10));
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                this.totalNum += i;
            }
            System.out.println("line[59]---"+"ThreadName="+this.threadName+"的计算表达式:"+caculateStr.substring(0,caculateStr.length()-1)+"="+this.totalNum);
            sum_val.addAndGet(this.totalNum);
        } else {
            int size = intvals.length / 5;
            int modval = intvals.length % 5;
            if (modval > 0)
                size += 1;
            for (int i = 0; i < size; i++) {
                int arrsize = 5;
                if (modval > 0 && i == size - 1) {
                    arrsize = modval;
                }
                int[] newIntArrays = new int[arrsize];
                System.arraycopy(this.intvals, i * 5, newIntArrays, 0, arrsize);
                CountedCompleterTest1 countedCompleterTest1 = new CountedCompleterTest1(newIntArrays,this);
                countedCompleterTest1.fork();
            }
            this.setPendingCount(size);
        }
        /*
        * 注意:getPendingCount()是瞬时结果,并不能准确反映真实情况
        * */
        /*System.out.println(this.threadName+" 尝试完成任务,当前pendingCount="+this.getPendingCount()+",parent pendingCount="
                +(this.getRoot()==null?"null":this.getRoot().getPendingCount()));*/
        tryComplete();
    }
}

 

ManagedBlocker

ManagedBlocker是ForkJoinPool的内部类,其源码如下:

public static void managedBlock(ManagedBlocker blocker)
    throws InterruptedException {
    ForkJoinPool p;
    ForkJoinWorkerThread wt;
    Thread t = Thread.currentThread();
    if ((t instanceof ForkJoinWorkerThread) &&
        (p = (wt = (ForkJoinWorkerThread)t).pool) != null) {
        WorkQueue w = wt.workQueue;
        while (!blocker.isReleasable()) {
            if (p.tryCompensate(w)) {
                try {
                    do {} while (!blocker.isReleasable() &&
                                 !blocker.block());
                } finally {
                    U.getAndAddLong(p, CTL, AC_UNIT);
                }
                break;
            }
        }
    }
    else {
        do {} while (!blocker.isReleasable() &&
                     !blocker.block());
    }

ManagedBlocker用于阻塞ForkJoinTask任务,相较于Thread.wait等操作,ManagedBlocker在阻塞时,提供一个补偿机制,会在阻塞的同时,尝试创建新的线程来弥补ForkJoinPool中被占用的线程。

ManagedBlocker的实现类需要实现isReleasable()和block()方法,isReleasable用于判断是否需要阻塞,block执行具体的阻塞和唤醒操作。

demo:

package com.forkjoin;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.TimeUnit;

/**
 * @Author: wanglonglong
 * @Description:
 */
public class BlockerTest1 extends RecursiveAction {
    private int[] intvals;
    private long start = System.currentTimeMillis();

    public BlockerTest1(int[] intvals) {
        this.intvals = intvals;
    }

    public static void main(String[] args) throws ExecutionException, InterruptedException {
        int[] intvals = new int[20];
        for (int i = 0; i < 20; i++) {
            intvals[i] = i;
        }
        BlockerTest1 blockerTest1 = new BlockerTest1(intvals);
        ForkJoinPool forkJoinPool = new ForkJoinPool(2);
        forkJoinPool.submit(blockerTest1);
        blockerTest1.get();
    }

    @Override
    protected void compute() {
        if (intvals.length <= 5) {
            int result = 0;
            for (int i : intvals) {
                /*try {
                    TimeUnit.MILLISECONDS.sleep(500);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }*/
                try {
                    ForkJoinPool.managedBlock(new Blocker());
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                result += i;
            }
            /**
             * 使用sleep暂停的输出结果:
             * ThreadName:ForkJoinPool-1-worker-1,计算结果:10,计算耗时:2504
             * ThreadName:ForkJoinPool-1-worker-0,计算结果:85,计算耗时:2504
             * ThreadName:ForkJoinPool-1-worker-1,计算结果:35,计算耗时:5007
             * ThreadName:ForkJoinPool-1-worker-0,计算结果:60,计算耗时:5007
             *
             * 使用blocker暂停的输出结果:
             * ThreadName:ForkJoinPool-1-worker-0,计算结果:85,计算耗时:2506
             * ThreadName:ForkJoinPool-1-worker-2,计算结果:60,计算耗时:2553
             * ThreadName:ForkJoinPool-1-worker-1,计算结果:10,计算耗时:2601
             * ThreadName:ForkJoinPool-1-worker-3,计算结果:35,计算耗时:2601
             *
             */
            System.out.println("ThreadName:" + Thread.currentThread().getName() + ",计算结果:" + result + ",计算耗时:" +(System.currentTimeMillis()-start));
        } else {
            int size = intvals.length / 5;
            int modval = intvals.length % 5;
            if (modval > 0)
                size += 1;
            BlockerTest1[] blockerTest1s = new BlockerTest1[size];
            for (int i = 0; i < size; i++) {
                int arrsize = 5;
                if (modval > 0 && i == size - 1) {
                    arrsize = modval;
                }
                int[] newIntArrays = new int[arrsize];
                System.arraycopy(this.intvals, i * 5, newIntArrays, 0, arrsize);
                BlockerTest1 blockerTest1 = new BlockerTest1(newIntArrays);
                blockerTest1s[i] = blockerTest1;
            }
            invokeAll(blockerTest1s);
        }
    }

    static class Blocker implements ForkJoinPool.ManagedBlocker {
        volatile boolean flag = false;
        @Override
        public boolean block() throws InterruptedException {
            TimeUnit.MILLISECONDS.sleep(500);
            this.flag = true;
            return true;
        }

        @Override
        public boolean isReleasable() {
            return flag;
        }
    }
}

工作窃取算法

work-stealing是fork/join的精髓,本人水平有限,算法的源码实在让人头秃,因此通过各方资料来分析算法的实现原理。

首先是来自国内网站的一篇文章:Java多线程进阶(四三)—— J.U.C之executors框架:Fork/Join框架(1) 原理 - SegmentFault 思否,这篇文章用图例的方式生动的展示了工作窃取算法的工作流程,但注意:文章中存在一些错误,因此只需要大概阅读以下,对工作窃取算法有个大概的认识即可

工作窃取算法是Doug Lea大神的作品,大神发表了一篇有关算法的论文,地址:http://gee.cs.oswego.edu/dl/papers/fj.pdf,这个是对工作窃取算法最权威的解释。

我根据上面两篇资料进行了整理,如果有错误的地方欢迎指正

fork/join框架的核心是work-stealing algorithm(任务窃取算法),算法的核心是Thread Worker Array和WorkQueue Array,其中Thread Worker由ForkJoinPool的defaultThreadFactory来创建,但需要注意:不同于其他的多线程框架,fork/join框架的worker thread是可以回收的

WorkQueue Array主要的作用就是记录所有的Task,task根据产生方式可以分为external task和inner task:

  • external task表示由外部代码提交的任务,例如由用户代码通过ForkJoinPool.submit(ForkJoinTask)提交的任务,此时的ForkJoinTask就是external task
  • inner task表示由ForkJoinPool中的workerThread提交的任务,例如在RecursiveTask.compute方法中,调用了(new MyRecursiveTask()).fork(),此时MyRecursiveTask就是inner task

workQueue Array的偶数位保存的是external task,而奇数位保存的是peer Thread Worker的inner task。

在ForkJoinPool.asyncMode=false(默认值)时,无论是external task还是inner task,都使用LIFO(也就是栈结构)来保存task;当asyncMode=true时,使用FIFO的算法保存,这样其实也很好理解,就是fork/join框架优先要处理粒度更小的task,用来防止workqueue中保存了大量的task才开始计算。

 

在工作窃取算法中,WorkQueue Array中存储着每个Thread独占的WorkQueue,而WorkQueue是一个由数组和其他信息维护的双向链表结构,核心的成员属性是int base和int top,其抽象结构可以表示成如下图:

对于WorkQueue,只有owner thread才有全系push task,但获取task的操作任何thread都可以,这也就是工作窃取名称的由来,但owner提取task的方式和other thread提取task的方式不同。

  1. push task的时候,base不变,top++
  2. owner thread pop task时,base不变,top--
  3. other thread poll(take) task时,top不变,base++,也就是说,任务窃取使用的队列结构

对于external workqueue,只会被poll出数据,由ForkJointWorkerThread从base开始窃取external task并运行;

对于inner task,可以被pop或者poll,其中ForkJointWorkerThread的主要工作就是从奇数位的workqueue(这个workqueue与ForkJointWorkerThread是绑定的,one-to-one)提取task并运行,只有当private innter task workqueue都执行结束后,才尝试使用poll从其他workqueue中poll task。

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值