【并发编程】ForkJoin线程池

文章探讨了处理CPU密集型任务的不同方法,包括使用并行流、ExecutorService多线程以及ForkJoinPool的Fork/Join框架。通过示例代码展示了如何使用这些技术来提高计算效率,如计算1至10000000正整数之和。ForkJoinPool利用工作窃取算法,能有效地拆分和合并任务,尤其适合计算密集型任务。
摘要由CSDN通过智能技术生成

一、使用场景

用于CPU密集型的任务,通过把任务进行拆分,拆分成多个小任务去执行,然后小任务执行完毕后再把每个小任务执行的结果合并起来,这样就可以节省时间。

CPU密集型(CPU-bound):CPU密集型也叫计算密集型,指的是系统的硬盘、内存性能相对CPU要好很多,此时,系统运作大部分的状况是CPU Loading 100%,CPU要读/写I/O(硬盘/内存),I/O在很短的时间就可以完成,而CPU还有许多运算要处理,CPU Loading很高。
例如:大部份时间用来做计算、逻辑判断等CPU动作的程序称之CPU bound。
线程数一般设置为:线程数 = CPU核数+1 (现代CPU支持超线程)

IO密集型(I/O bound):IO密集型指的是系统的CPU性能相对硬盘、内存要好很多,此时,系统运作,大部分的状况是CPU在等I/O (硬盘/内存) 的读/写操作,此时CPU Loading并不高。
例如:读取本地文件、读取redis缓存、读取数据库等操作
线程数一般设置为:线程数 = ((线程等待时间+线程CPU时间)/线程CPU时间 )* CPU数目

二、简单使用

问题:计算1至10000000的正整数之和。
方案一:for循环解决

 public static void main(String[] args) {
        long sum = 0;
        long start = System.currentTimeMillis();
        for (int i = 1; i <= 10000000;i++) {
           sum += i;
        }
        System.out.println("结果为:" + sum);
        System.out.println("耗时为:" + (System.currentTimeMillis() - start));
}        

在这里插入图片描述

方案二:采用并行流(JDK8以后)

public static void main(String[] args) {
        long start = System.currentTimeMillis();
        long sum = LongStream.rangeClosed(0, 10000000L).parallel().reduce(0, Long::sum);
        System.out.println("结果为:" + sum);
        System.out.println("耗时为:" + (System.currentTimeMillis() - start));

    }
}    

在这里插入图片描述

方案三:ExecutorService多线程方式实现

package concurrency.threadpool;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
import java.util.stream.LongStream;

public class ExecutorTest {

    int availableProcessors = Runtime.getRuntime().availableProcessors();
    ExecutorService executorService = Executors.newFixedThreadPool(availableProcessors);

    public static void main(String[] args) {

       long[] nums = LongStream.rangeClosed(1, 10000000L).toArray();
       long start = System.currentTimeMillis();
       System.out.println("结果为:" + new ExecutorTest().sumVal(nums));
       System.out.println("耗时为:" + (System.currentTimeMillis() - start));
    }

    private static class SumTask implements Callable<Long> {

        private long[] nums;
        private int from;
        private int to;

        public SumTask(long[] nums, int from, int to) {
            this.nums = nums;
            this.from = from;
            this.to = to;
        }

        @Override
        public Long call() throws Exception {
            long sum = 0;
            for (int i = from; i <= to; i++) {
                sum += nums[i];
            }
            return sum;
        }
    }

    private long sumVal(long[] nums) {
        List<Future<Long>> results = new ArrayList<>();
        int part = nums.length / availableProcessors;
        for (int i = 0; i < availableProcessors; i++) {
            int from = i * part;
            int to = (i == availableProcessors - 1) ? nums.length - 1 : (i + 1) * part - 1;
            results.add(executorService.submit(new SumTask(nums, from, to)));
        }
        long sum = 0;
        for (Future<Long> future : results) {
            try {
                sum += future.get();
            } catch (InterruptedException e) {
                e.printStackTrace();
            } catch (ExecutionException e) {
                e.printStackTrace();
            }
        }
        return sum;
    }



}

在这里插入图片描述

方案四:采用ForkJoinPool(Fork/Join)

package concurrency.threadpool;

import java.util.concurrent.*;
import java.util.stream.LongStream;

public class ForkJoinTest {

    int availableProcessors = Runtime.getRuntime().availableProcessors();
    ForkJoinPool pool = new ForkJoinPool();

    public static void main(String[] args) {

        long[] nums = LongStream.rangeClosed(1, 10000000).toArray();
        long start = System.currentTimeMillis();
        System.out.println("结果为:" + new ForkJoinTest().sumVal(nums));
        System.out.println("耗时为:" + (System.currentTimeMillis() - start));
    }

    private static class SumTask extends RecursiveTask<Long> {

        private long[] nums;
        private int from;
        private int to;

        public SumTask(long[] nums, int from, int to) {
            this.nums = nums;
            this.from = from;
            this.to = to;
        }

        @Override
        protected Long compute() {
            // 当需要计算的数字个数小于6时,直接采用for loop方式计算结果
            if (to - from < 6) {
                long sum = 0;
                for (int i = from; i <= to; i++) {
                    sum += nums[i];
                }
                return sum;
            } else { // 否则,把任务一分为二,递归拆分(注意此处有递归)到底拆分成多少分 需要根据具体情况而定
                int middle = (from + to) / 2;
                SumTask taskLeft = new SumTask(nums, from, middle);
                SumTask taskRight = new SumTask(nums, middle + 1, to);
                taskLeft.fork();
                taskRight.fork();
                return taskLeft.join() + taskRight.join();
            }

        }
    }

    private long sumVal(long[] nums) {
        Long result = pool.invoke(new SumTask(nums, 0, nums.length - 1));
        pool.shutdown();
        return result;
    }

}

在这里插入图片描述
总结:
1.ForkJoinPool 不是为了替代 ExecutorService,而是它的补充,在某些应用场景下性能比 ExecutorService 更好(例:从数据库拉取了千亿万的数据到本地,然后进行排序 )。
2. ForkJoinPool 主要用于实现“分而治之”的算法,特别是分治之后递归调用的函数,例如 quick sort 等。
3. ForkJoinPool 最适合的是计算密集型的任务,如果存在 I/O,线程间同步,sleep() 等会造成线程长时间阻塞的情况时,最好配合使用 ManagedBlocker。

三、整体流程

1.任务入队
在这里插入图片描述

task1还能继续拆分,则调用fork方法进行拆分,
在这里插入图片描述

2.任务执行
worker-thread1的task1执行完成,出队
在这里插入图片描述
此时,worker-thread1会去问问worker-thread0是否需要帮忙,会从队头获取任务进行执行,而worker-thread0是从队尾获取任务进行执行。这就是“工作窃取算法”(工作窃取(work-stealing)算法是指某个线程从其他队列里窃取任务来执行。)。
在这里插入图片描述

四、源码解析

在这里插入图片描述
使用 ForkJoin 框架,必须首先创建一个 ForkJoin 任务。它提供在任务中执行 fork() 和 join() 操作的机制,通常情况下我们不需要直接继承 ForkJoinTask 类,而只需要继承它的子类,Fork/Join 框架提供了以下两个子类:
RecursiveAction:用于没有返回结果的任务。(比如写数据到磁盘,然后就退出了。 一个RecursiveAction可以把自己的工作分割成更小的几块, 这样它们可以由独立的线程或者CPU执行。 我们可以通过继承来实现一个RecursiveAction)
RecursiveTask :用于有返回结果的任务。(可以将自己的工作分割为若干更小任务,并将这些子任务的执行合并到一个集体结果。 可以有几个水平的分割和合并)

4-1 ForkJoinPool构造函数

private ForkJoinPool(int parallelism,
   ForkJoinWorkerThreadFactory factory,
   UncaughtExceptionHandler handler,
   int mode,
   String workerNamePrefix) {
     this.workerNamePrefix = workerNamePrefix;
     this.factory = factory;
     this.ueh = handler;
     this.config = (parallelism & SMASK) | mode;
     long np = (long)(‐parallelism); // offset ctl counts
     this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
}

参数解释:
1>parallelism:并行度( the parallelism level),默认情况下跟我们机器的cpu个数保持一致,使用 Runtime.getRuntime().availableProcessors()可以得到我们机器运行时可用的CPU个数。
2>factory:创建新线程的工厂( the factory for creating new threads)。默认情况下使用
ForkJoinWorkerThreadFactory defaultForkJoinWorkerThreadFactory。
3handler:线程异常情况下的处理器(Thread.UncaughtExceptionHandler handler),该处理器在线程执行任务时由于某些无法预料
到的错误而导致任务线程中断时进行一些处理,默认情况为null。
4>asyncMode:这个参数要注意,在ForkJoinPool中,每一个工作线程都有一个独立的任务队列,asyncMode表示工作线程内的任务队列是采用何种方式进行调度,可以是先进先出FIFO,也可以是后进先出LIFO。如果为true,则线程池中的工作线程则使用先进先出方式进行任务调度,默认情况下是false

4-2 ForkJoinTask fork 方法

将任务推入当前工作线程的工作队列中。

 public final ForkJoinTask<V> fork() {
        Thread var1;
        if ((var1 = Thread.currentThread()) instanceof ForkJoinWorkerThread) {
            ((ForkJoinWorkerThread)var1).workQueue.push(this);
        } else {
            ForkJoinPool.common.externalPush(this);
        }

        return this;
    }
package java.util.concurrent;

import java.lang.Thread.UncaughtExceptionHandler;
import java.security.AccessControlContext;
import java.security.CodeSource;
import java.security.PermissionCollection;
import java.security.ProtectionDomain;
import java.util.concurrent.ForkJoinPool.WorkQueue;
import sun.misc.Unsafe;

/**
* 线程池中的每个工作线程(ForkJoinWorkerThread)都有一个自己的任务队列(WorkQueue),工作线程优先处理自身队列中的任务(LIFO或FIFO顺序,由线程池构造时的参数  mode 决定),自身队列为空时,以FIFO的顺序随机窃取其它队列中的任务。
*/
public class ForkJoinWorkerThread extends Thread {
    final ForkJoinPool pool; // 该工作线程归属的线程池
    final WorkQueue workQueue; // 指定的队列
    private static final Unsafe U;
    private static final long THREADLOCALS;
    private static final long INHERITABLETHREADLOCALS;
    private static final long INHERITEDACCESSCONTROLCONTEXT;

    protected ForkJoinWorkerThread(ForkJoinPool var1) {
        super("aForkJoinWorkerThread"); // 指定工作线程名称
        this.pool = var1;
        this.workQueue = var1.registerWorker(this); // 将自己注册到线程池中
    }

    ForkJoinWorkerThread(ForkJoinPool var1, ThreadGroup var2, AccessControlContext var3) {
        super(var2, (Runnable)null, "aForkJoinWorkerThread");
        U.putOrderedObject(this, INHERITEDACCESSCONTROLCONTEXT, var3);
        this.eraseThreadLocals();
        this.pool = var1;
        this.workQueue = var1.registerWorker(this);
    }

    public ForkJoinPool getPool() {
        return this.pool;
    }

    public int getPoolIndex() {
        return this.workQueue.getPoolIndex();
    }

    protected void onStart() {
    }

    protected void onTermination(Throwable var1) {
    }

    public void run() {
        if (this.workQueue.array == null) {
            Throwable var1 = null;

            try {
                // 空方法,待用户自己实现
                this.onStart(); 
                // 执行队列中的task任务
                this.pool.runWorker(this.workQueue);
            } catch (Throwable var40) {
                var1 = var40;
            } finally {
                try {
                    // 空方法,待用户自己实现
                    this.onTermination(var1);
                } catch (Throwable var41) {
                    if (var1 == null) {
                        var1 = var41;
                    }
                } finally {
                    this.pool.deregisterWorker(this, var1);
                }

            }
        }

    }

    final void eraseThreadLocals() {
        U.putObject(this, THREADLOCALS, (Object)null);
        U.putObject(this, INHERITABLETHREADLOCALS, (Object)null);
    }

    void afterTopLevelExec() {
    }

    static {
        try {
            U = Unsafe.getUnsafe();
            Class var0 = Thread.class;
            THREADLOCALS = U.objectFieldOffset(var0.getDeclaredField("threadLocals"));
            INHERITABLETHREADLOCALS = U.objectFieldOffset(var0.getDeclaredField("inheritableThreadLocals"));
            INHERITEDACCESSCONTROLCONTEXT = U.objectFieldOffset(var0.getDeclaredField("inheritedAccessControlContext"));
        } catch (Exception var1) {
            throw new Error(var1);
        }
    }

    static final class InnocuousForkJoinWorkerThread extends ForkJoinWorkerThread {
        private static final ThreadGroup innocuousThreadGroup = createThreadGroup();
        private static final AccessControlContext INNOCUOUS_ACC = new AccessControlContext(new ProtectionDomain[]{new ProtectionDomain((CodeSource)null, (PermissionCollection)null)});

        InnocuousForkJoinWorkerThread(ForkJoinPool var1) {
            super(var1, innocuousThreadGroup, INNOCUOUS_ACC);
        }

        void afterTopLevelExec() {
            this.eraseThreadLocals();
        }

        public ClassLoader getContextClassLoader() {
            return ClassLoader.getSystemClassLoader();
        }

        public void setUncaughtExceptionHandler(UncaughtExceptionHandler var1) {
        }

        public void setContextClassLoader(ClassLoader var1) {
            throw new SecurityException("setContextClassLoader");
        }

        private static ThreadGroup createThreadGroup() {
            try {
                Unsafe var0 = Unsafe.getUnsafe();
                Class var1 = Thread.class;
                Class var2 = ThreadGroup.class;
                long var3 = var0.objectFieldOffset(var1.getDeclaredField("group"));
                long var5 = var0.objectFieldOffset(var2.getDeclaredField("parent"));

                ThreadGroup var8;
                for(ThreadGroup var7 = (ThreadGroup)var0.getObject(Thread.currentThread(), var3); var7 != null; var7 = var8) {
                    var8 = (ThreadGroup)var0.getObject(var7, var5);
                    if (var8 == null) {
                        return new ThreadGroup(var7, "InnocuousForkJoinWorkerThreadGroup");
                    }
                }
            } catch (Exception var9) {
                throw new Error(var9);
            }

            throw new Error("Cannot create ThreadGroup");
        }
    }
}

4-3 ForkJoinTask join 方法

 public final V join() {
        int var1;
        if ((var1 = this.doJoin() & -268435456) != -268435456) {
            this.reportException(var1);
        }
        return this.getRawResult();
    }

 private int doJoin() {
        int var1;
        Thread var2;
        ForkJoinWorkerThread var3;
        WorkQueue var4;
        return (var1 = this.status) < 0 ? var1 : ((var2 = Thread.currentThread()) instanceof ForkJoinWorkerThread ? ((var4 = (var3 = (ForkJoinWorkerThread)var2).workQueue).tryUnpush(this) && (var1 = this.doExec()) < 0 ? var1 : var3.pool.awaitJoin(var4, this, 0L)) : this.externalAwaitDone());
    }

 private void reportException(int var1) {
        if (var1 == -1073741824) {
            throw new CancellationException();
        } else {
            if (var1 == -2147483648) {
                rethrow(this.getThrowableException());
            }

        }
    }

工作流程:
1.检查调用 join() 的线程是否是 ForkJoinThread 线程。如果不是(例如 main 线程),则阻塞当前线程,等待任务完成。如果是,则不阻塞。
2. 查看任务的完成状态,如果已经完成,直接返回结果。
3. 如果任务尚未完成,但处于自己的工作队列内,则完成它。
4. 如果任务已经被其他的工作线程偷走,则窃取这个小偷的工作队列内的任务(以 FIFO 方式),执行,以期帮助它早日完成欲 join 的任务。
5. 如果偷走任务的小偷也已经把自己的任务全部做完,正在等待需要 join 的任务时,则找到小偷的小偷,帮助它完成它的任务。
6. 递归地执行第5步。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值