Java多线程 - Fork/Join

1.什么是 Fork/Join?

Fork/Join 框架 也称为分解/合并框架,Fork/Join 框架的基本思想是分而治之。什么是分而治之?分而治之就是将一个复杂的计算,按照设定的阈值进行分解成多个计算,然后将各个计算结果进行汇总。相应的ForkJoin将复杂的计算当做一个任务。而分解的多个计算则是当做一个子任务。

2.ForkJoin的使用 

使用ForkJoin框架,需要创建一个ForkJoin的任务,而ForkJoinTask是一个抽象类,我们不需要去继承ForkJoinTask进行使用。因为ForkJoin框架为我们提供了RecursiveAction和RecursiveTask。我们只需要继承ForkJoin为我们提供的抽象类的其中一个并且实现compute方法。

class SumTask extends RecursiveTask<Long> {

    static final int THRESHOLD = 100;
    long[] array;
    int start;
    int end;

    SumTask(long[] array, int start, int end) {
    this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {
        if (end - start <= THRESHOLD) {
            // 如果任务足够小,直接计算:
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
            }
            System.out.println(String.format("compute %d~%d = %d", start, end, sum));
            return sum;
        }
        // 任务太大,一分为二:
        int middle = (end + start) / 2;
        System.out.println(String.format("split %d~%d ==> %d~%d, %d~%d", start, end, start, middle, middle, end));
        SumTask subtask1 = new SumTask(this.array, start, middle);
        SumTask subtask2 = new SumTask(this.array, middle, end);
        invokeAll(subtask1, subtask2);
        Long subresult1 = subtask1.join();
        Long subresult2 = subtask2.join();
        Long result = subresult1 + subresult2;
        System.out.println("result = " + subresult1 + " + " + subresult2 + " ==> " + result);
        return result;
    }
}

使用ForkJoinPool进行执行

public static void main(String[] args) throws Exception {
    // 创建随机数组成的数组:
    long[] array = new long[400];
    fillRandom(array);
    // fork/join task:
    ForkJoinPool fjp = new ForkJoinPool(4); // 最大并发数4
    ForkJoinTask<Long> task = new SumTask(array, 0, array.length);
    long startTime = System.currentTimeMillis();
    Long result = fjp.invoke(task);
    long endTime = System.currentTimeMillis();
    System.out.println("Fork/join sum: " + result + " in " + (endTime - startTime) + " ms.");
}

关键代码是fjp.invoke(task)来提交一个Fork/Join任务并发执行,然后获得异步执行的结果。

我们设置任务的最小阀值是100,当提交一个400大小的任务时,在4核CPU上执行,会一分为二,再二分为四,每个最小子任务的执行时间是1秒,由于是并发4个子任务执行,整个任务最终执行时间大约为1秒。

新手在编写Fork/Join任务时,往往用搜索引擎搜到一个例子,然后就照着例子写出了下面的代码:

protected Long compute() {
    if (任务足够小?) {
        return computeDirect();
    }
    // 任务太大,一分为二:
    SumTask subtask1 = new SumTask(...);
    SumTask subtask2 = new SumTask(...);
    // 分别对子任务调用fork():
    subtask1.fork();
    subtask2.fork();
    // 合并结果:
    Long subresult1 = subtask1.join();
    Long subresult2 = subtask2.join();
    return subresult1 + subresult2;
}

很遗憾,这种写法是**错!误!的!**这样写没有正确理解Fork/Join模型的任务执行逻辑。

JDK用来执行Fork/Join任务的工作线程池大小等于CPU核心数。在一个4核CPU上,最多可以同时执行4个子任务。对400个元素的数组求和,执行时间应该为1秒。但是,换成上面的代码,执行时间却是两秒。

这是因为执行compute()方法的线程本身也是一个Worker线程,当对两个子任务调用fork()时,这个Worker线程就会把任务分配给另外两个Worker,但是它自己却停下来等待不干活了!这样就白白浪费了Fork/Join线程池中的一个Worker线程,导致了4个子任务至少需要7个线程才能并发执行。

打个比方,假设一个酒店有400个房间,一共有4名清洁工,每个工人每天可以打扫100个房间,这样,4个工人满负荷工作时,400个房间全部打扫完正好需要1天。

Fork/Join的工作模式就像这样:首先,工人甲被分配了400个房间的任务,他一看任务太多了自己一个人不行,所以先把400个房间拆成两个200,然后叫来乙,把其中一个200分给乙。

紧接着,甲和乙再发现200也是个大任务,于是甲继续把200分成两个100,并把其中一个100分给丙,类似的,乙会把其中一个100分给丁,这样,最终4个人每人分到100个房间,并发执行正好是1天。

如果换一种写法:

// 分别对子任务调用fork():
subtask1.fork();
subtask2.fork();

这个任务就分!错!了!

比如甲把400分成两个200后,这种写法相当于甲把一个200分给乙,把另一个200分给丙,然后,甲成了监工,不干活,等乙和丙干完了他直接汇报工作。乙和丙在把200分拆成两个100的过程中,他俩又成了监工,这样,本来只需要4个工人的活,现在需要7个工人才能1天内完成,其中有3个是不干活的。

其实,我们查看JDK的invokeAll()方法的源码就可以发现,invokeAll的N个任务中,其中N-1个任务会使用fork()交给其它线程执行,但是,它还会留一个任务自己执行,这样,就充分利用了线程池,保证没有空闲的不干活的线程。

3.RecursiveTask和RecursiveAction区别

  • RecursiveTask
    通过源码的查看我们可以发现RecursiveTask在进行exec之后会使用一个result的变量进行接受返回的结果。而result返回结果类型是通过泛型进行传入。也就是说RecursiveTask执行后是有返回结果。
public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
    private static final long serialVersionUID = 5232453952276485270L;

    /**
     * The result of the computation.
     */
    V result;

    /**
     * The main computation performed by this task.
     * @return the result of the computation
     */
    protected abstract V compute();

    public final V getRawResult() {
        return result;
    }

    protected final void setRawResult(V value) {
        result = value;
    }

    /**
     * Implements execution conventions for RecursiveTask.
     */
    protected final boolean exec() {
        result = compute();
        return true;
    }

}
  • RecursiveAction
    RecursiveAction在exec后是不会保存返回结果,因此RecursiveAction与RecursiveTask区别在与RecursiveTask是有返回结果而RecursiveAction是没有返回结果。
public abstract class RecursiveAction extends ForkJoinTask<Void> {
    private static final long serialVersionUID = 5232453952276485070L;

    /**
     * The main computation performed by this task.
     */
    protected abstract void compute();

    /**
     * Always returns {@code null}.
     *
     * @return {@code null} always
     */
    public final Void getRawResult() { return null; }

    /**
     * Requires null completion value.
     */
    protected final void setRawResult(Void mustBeNull) { }

    /**
     * Implements execution conventions for RecursiveActions.
     */
    protected final boolean exec() {
        compute();
        return true;
    }

}

4.ForkJoin工作窃取(work-stealing) 

工作窃取(work-stealing)算法是指某个线程从其他队列里窃取任务来执行。工作窃取的运行流程图如下:

fj

那么为什么需要使用工作窃取算法呢?假如我们需要做一个比较大的任务,我们可以把这个任务分割为若干互不依赖的子任务,为了减少线程间的竞争,于是把这些子任务分别放到不同的队列里,并为每个队列创建一个单独的线程来执行队列里的任务,线程和队列一一对应,比如A线程负责处理A队列里的任务。但是有的线程会先把自己队列里的任务干完,而其他线程对应的队列里还有任务等待处理。干完活的线程与其等着,不如去帮其他线程干活,于是它就去其他线程的队列里窃取一个任务来执行。而在这时它们会访问同一个队列,所以为了减少窃取任务线程和被窃取任务线程之间的竞争,通常会使用双端队列,被窃取任务线程永远从双端队列的头部拿任务执行,而窃取任务的线程永远从双端队列的尾部拿任务执行。

工作窃取算法的优点是充分利用线程进行并行计算,并减少了线程间的竞争,其缺点是在某些情况下还是存在竞争,比如双端队列里只有一个任务时。并且消耗了更多的系统资源,比如创建多个线程和多个双端队列。

窃取的基本思路就是:当worker自己的任务队列里面没有任务时,就去scan别的线程的队列,把别人的任务拿过来执行。

//ForkJoinPool的成员变量
ForkJoinWorkerThread[] workers;  //worker thread集合
private ForkJoinTask<?>[] submissionQueue; //外部任务队列
private final ReentrantLock submissionLock; 
 
//ForkJoinWorkerThread的成员变量
ForkJoinTask<?>[] queue;   //每个worker线程自己的内部任务队列
 
//提交任务
public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {  
    if (task == null)  
        throw new NullPointerException();  
    forkOrSubmit(task);  
    return task;  
} 
 
private <T> void forkOrSubmit(ForkJoinTask<T> task) {  
    ForkJoinWorkerThread w;  
    Thread t = Thread.currentThread();  
    if (shutdown)  
        throw new RejectedExecutionException();  
    if ((t instanceof ForkJoinWorkerThread) &&   //如果当前是worker线程提交的任务,也就是worker执行过程中,分裂出来的子任务,放入worker自己的内部任务队列
        (w = (ForkJoinWorkerThread)t).pool == this)  
        w.pushTask(task);  
    else  
        addSubmission(task);  //外部任务,放入pool的全局队列
}   
 
//worker的run方法
public void run() {  
    Throwable exception = null;  
    try {  
        onStart();  
        pool.work(this);  
    } catch (Throwable ex) {  
        exception = ex;  
    } finally {  
        onTermination(exception);  
    }  
}  
 
final void work(ForkJoinWorkerThread w) {  
    boolean swept = false;                // true on empty scans  
    long c;  
    while (!w.terminate && (int)(c = ctl) >= 0) {  
        int a;                            // active count  
        if (!swept && (a = (int)(c >> AC_SHIFT)) <= 0)  
            swept = scan(w, a);   //核心代码都在这个scan函数里面
        else if (tryAwaitWork(w, c))  
            swept = false;  
    }  
}  
 
//scan的基本思路:从别人的任务队列里面抢,没有,再到pool的全局的任务队列里面去取。
private boolean scan(ForkJoinWorkerThread w, int a) {  
    int g = scanGuard;   
 
    int m = (parallelism == 1 - a && blockedCount == 0) ? 0 : g & SMASK;  
    ForkJoinWorkerThread[] ws = workers;  
    if (ws == null || ws.length <= m)         // 过期检测  
        return false;  
 
    for (int r = w.seed, k = r, j = -(m + m); j <= m + m; ++j) {  
        ForkJoinTask<?> t; ForkJoinTask<?>[] q; int b, i;  
        //随机选出一个牺牲者(工作线程)。  
        ForkJoinWorkerThread v = ws[k & m];  
        //一系列检查...  
        if (v != null && (b = v.queueBase) != v.queueTop &&  
            (q = v.queue) != null && (i = (q.length - 1) & b) >= 0) {  
            //如果这个牺牲者的任务队列中还有任务,尝试窃取这个任务。  
            long u = (i << ASHIFT) + ABASE;  
            if ((t = q[i]) != null && v.queueBase == b &&  
                UNSAFE.compareAndSwapObject(q, u, t, null)) {  
                //窃取成功后,调整queueBase  
                int d = (v.queueBase = b + 1) - v.queueTop;  
                //将牺牲者的stealHint设置为当前工作线程在pool中的下标。  
                v.stealHint = w.poolIndex;  
                if (d != 0)  
                    signalWork();             // 如果牺牲者的任务队列还有任务,继续唤醒(或创建)线程。  
                w.execTask(t); //执行窃取的任务。  
            }  
            //计算出下一个随机种子。  
            r ^= r << 13; r ^= r >>> 17; w.seed = r ^ (r << 5);  
            return false;                     // 返回false,表示不是一个空扫描。  
        }  
        //前2*m次,随机扫描。  
        else if (j < 0) {                     // xorshift  
            r ^= r << 13; r ^= r >>> 17; k = r ^= r << 5;  
        }  
        //后2*m次,顺序扫描。  
        else  
            ++k;  
    }  
    if (scanGuard != g)                       // staleness check  
        return false;  
    else {                                     
        //如果扫描完毕后没找到可窃取的任务,那么从Pool的提交任务队列中取一个任务来执行。  
        ForkJoinTask<?> t; ForkJoinTask<?>[] q; int b, i;  
        if ((b = queueBase) != queueTop &&  
            (q = submissionQueue) != null &&  
            (i = (q.length - 1) & b) >= 0) {  
            long u = (i << ASHIFT) + ABASE;  
            if ((t = q[i]) != null && queueBase == b &&  
                UNSAFE.compareAndSwapObject(q, u, t, null)) {  
                queueBase = b + 1;  
                w.execTask(t);  
            }  
            return false;  
        }  
        return true;                         // 如果所有的队列(工作线程的任务队列和pool的任务队列)都是空的,返回true。  
    }  
}  

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值