Java Parallel Stream 源码深入解析

提出疑问

  1. 如何转换流类型,怎么实现的?
  2. 如何切分任务的?
  3. 如何合并任务结果的?

任务切分

并行流的底层执行是基于ForkJoin框架的,了解的都知道ForkJoin框架要执行的任务逻辑需要使用者重写,重写内容包括如果切分任务,所以去看ForkJoinTask的实现子类如何实现Compute()方法就可以知道如何进行任务切分的。 在AbstractTask类中实现了核心方法。

其中最核心的方法是compute()方法,定义了计算算法的逻辑:

    /**
    * 决定是否进一步拆分一个任务或直接计算它。
    * 如果直接计算,调用doLeaf并将结果传递给setRawResult。
    * 否则就拆分子任务,分叉一个,继续作为另一个。
	* 该方法的结构是为了在一系列的使用中节约资源。循环在分裂时继续进行其中一个子任务,
    * 以避免深度递归。为了应对可能系统性地偏向于左重或右重拆分的拆分器,我们在循环中交替使用哪个子任务被分叉或继续。
    **/
    @Override
    public void compute() {
        Spliterator<P_IN> rs = spliterator, ls; // right, left spliterators
        //估计rs的元素数量
        long sizeEstimate = rs.estimateSize();
        //获取大小阈值
        long sizeThreshold = getTargetSize(sizeEstimate);
        //
        boolean forkRight = false;
        @SuppressWarnings("unchecked") K task = (K) this;
        // 判断条件
        // 1. 元素数量 > 阈值
        // 2. rs可以进行切割
        while (sizeEstimate > sizeThreshold && (ls = rs.trySplit()) != null) {
            K leftChild, rightChild, taskToFork;
            // 使用ls,rs构造两个新的任务
            task.leftChild  = leftChild = task.makeChild(ls);
            task.rightChild = rightChild = task.makeChild(rs);
            //设置等待数量为1
            task.setPendingCount(1);
            
            // 轮流执行左/右两个任务
            if (forkRight) {
                forkRight = false;
                rs = ls;
                task = leftChild;
                taskToFork = rightChild;// rightChild.fork()
            }
            else { 
                forkRight = true;
                task = rightChild;
                taskToFork = leftChild;//leftChild.fork()
            }
            // 递归
            taskToFork.fork();
            sizeEstimate = rs.estimateSize();
        }
        task.setLocalResult(task.doLeaf());
        task.tryComplete();
    }
复制代码

本质上理解,就是将一个Spliterator进行切分成两部分,然后两部分再进行切分,直到while()条件无法再满足,然后对该部分进行运算,将运算结果保存在节点的LocalResult字段上。这部分就是分支的切分阶段。

结果合并

测试代码:

long count = Stream.of(1, 2, 3, 4, 5).parallel().reduce((x, y) -> x + y + 10).get();
复制代码

AbstractPipeline.evaluate() 方法是并行流与串行流的分叉点:

final <R> R evaluate(TerminalOp<E_OUT, R> terminalOp) {
        assert getOutputShape() == terminalOp.inputShape();
        if (linkedOrConsumed)
            throw new IllegalStateException(MSG_STREAM_LINKED);
        linkedOrConsumed = true;

        return isParallel()
               ? terminalOp.evaluateParallel(this, sourceSpliterator(terminalOp.getOpFlags()))  //并行流执行点
               : terminalOp.evaluateSequential(this, sourceSpliterator(terminalOp.getOpFlags()));
    }
复制代码

ReduceOps.evaluateParallel() 方法是其实现之一

@Override
public <P_IN> R evaluateParallel(PipelineHelper<T> helper,
                                 Spliterator<P_IN> spliterator) {
    return new ReduceTask<>(this, helper, spliterator).invoke().get();
}
复制代码

主要做了两件事:

  • 创建一个ReduceTask任务
  • 任务调用invoke()执行

创建的逻辑没有额外的操作,就是将三个参数赋值到实例变量中。

invoke()是ForkJoinTask的方法,方法这里主要关注invoke() 方法的逻辑:

/**
* Commences performing this task, awaits its completion if
* necessary, and returns its result, or throws an (unchecked)
* {@code RuntimeException} or {@code Error} if the underlying
* computation did so.
*
* @return the computed result
*/
public final V invoke() {
    int s;
    //执行任务
    if ((s = doInvoke() & DONE_MASK) != NORMAL)
        reportException(s);

    // 这里放回的是最终结果
    return getRawResult();
}
复制代码
/**
* Implementation for invoke, quietlyInvoke.
*
* @return status upon completion
*/
private int doInvoke() {
    int s; Thread t; ForkJoinWorkerThread wt;
    return (s = doExec()) < 0 ? s :
    ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
        (wt = (ForkJoinWorkerThread)t).pool.
        awaitJoin(wt.workQueue, this, 0L) :
    externalAwaitDone();
}

==> JDK代码为了性能,牺牲了可读性,这里整理下:

private int doInvoke() {
    s = doExec();
    if (s < 0) return s;
    
	t = Thread.currentThread();
    if (t instanceof ForkJoinWorkerThread) {
    	ForkJoinWorkerThread wt = (ForkJoinWorkerThread)t.pool();
        return wt.awaitJoin(wt.workQueue,this,0L);
    }
    
    return externalAwaitDone();
}
复制代码
    /**
     * If the pending count is nonzero, decrements the count;
     * otherwise invokes {@link #onCompletion(CountedCompleter)}
     * and then similarly tries to complete this task's completer,
     * if one exists, else marks this task as complete.
     */
    public final void tryComplte() {
        CountedCompleter<?> a = this, s = a;
        for (int c;;) {
            if ((c = a.pending) == 0) {
                a.onCompletion(s);
                if ((a = (s = a).completer) == null) {
                    s.quietlyComplete();
                    return;
                }
            }
            else if (U.compareAndSwapInt(a, PENDING, c, c - 1))
                return;
        }
    }
复制代码

ReduceOps.onCompletion() 的重写方法中,如果是不是叶子节点,则合并两个子节点的结果:

@Override
public void onCompletion(CountedCompleter<?> caller) {
    if (!isLeaf()) {
        //如果不是叶子节点,将左右两个子节点的结果合并。
        S leftResult = leftChild.getLocalResult();
        leftResult.combine(rightChild.getLocalResult());
        setLocalResult(leftResult);
    }
    // GC spliterator, left and right child
    super.onCompletion(caller);
}
复制代码

关注里面3个方法:

  • isLeaf()
  • combine()
  • super.onCompletion()

isLeaf()判断是不是叶子节点,只有非叶子节点才有两个子节点,然后才能进行合并任务:

protected boolean isLeaf() {
    return leftChild == null;
}
复制代码

节点在执行完后调用super.onCompletionc(caller) 方法, 字段设为null,为了gc

@Override
public void onCompletion(CountedCompleter<?> caller) {
    spliterator = null;
    leftChild = rightChild = null;
}
复制代码

终点看下combine()方法,这个方法的作用是将结果组合

@Override
public void combine(ReducingSink other) {
    if (!other.empty)
        accept(other.state);
}

// 将本类的state和参数t进行一次运算,由于t是另一个部分运算的结果
// 这里的作用就是两部分进行运算求出结果
@Override
public void accept(T t) {
    if (empty) {
        empty = false;
        state = t;
    } else {
        state = operator.apply(state, t);
    }
}
复制代码

apply(state,t) 运算的逻辑是使用者定义的,对应我们的示例即 .reduce((x, y) -> x + y + 10)

通过这种方法,就可以对所有节点的结果进行两两组合,得出最终结果了。

疑问解答

  1. 如何转换流类型,怎么实现的?
    • 可以使用sequential() parallel() 转换流的类型,源码中是对 sourceStage.parallel = true; 进行标记,最后在启动时候根据和这个标记决定串行or并行。
  2. 如何切分任务的?
    • 底层利用ForkJoindPool框架实现对任务的切分和合并任务,重写compute() ,对Splitertor进行拆分,然后递归调用compute()方法进行不断切分,直到Splitertor无法再切分,执行该部分。
  3. 如何合并任务结果的?
    • 重写onCompletionc(),对每一个非叶子节点,都会将两个子节点的结果通过combine()方法进行组合,然后一直向上仿佛,知道所有节点结果都组合,最后返回根节点的结果。


作者:994🍔
链接:https://juejin.cn/post/7005899099186659335
来源:掘金
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值