Java7 ForkJoinPool 的使用以及原理

在JDK7中新增了ForkJoinPool。ForkJoinPool采用分治+work-stealing的思想。可以让我们很方便地将一个大任务拆散成小任务并行地执行提高CPU的使用率

ForkJoinPool & ForkJoinTask 概述:

  • ForkJoinTask:我们要使用 ForkJoin 框架,必须首先创建一个 ForkJoin 任务。它提供在任务中执行 fork() 和 join() 操作的机制,通常情况下我们不需要直接继承 ForkJoinTask 类,而只需要继承它的子类,ForkJoin 框架提供了以下两个子类:
    • RecursiveAction:用于没有返回结果的任务。
    • RecursiveTask :用于有返回结果的任务。
  • ForkJoinPool :ForkJoinTask 需要通过 ForkJoinPool 来执行,任务分割出的子任务会添加到当前工作线程所维护的双端队列中,进入队列的头部。当一个工作线程的队列里暂时没有任务时,它会随机从其他工作线程的队列的尾部获取一个任务。

如何充分利用多核CPU,计算很大数组中所有整数的和?

剖析

  • 单线程相加?

    我们最容易想到就是单线程相加,一个for循环搞定。

  • 线程池相加?

    如果进一步优化,我们会自然而然地想到使用线程池来分段相加,最后再把每个段的结果相加。

  • 其它?

    Yes,就是我们今天的主角——ForkJoinPool,但是它要怎么实现呢?似乎没怎么用过哈^^

三种实现

OK,剖析完了,我们直接来看三种实现,不墨迹,直接上菜。

/**

 * 计算1亿个整数的和

 */

public class ForkJoinPoolTest01
{
    public static void main(String[] args) throws ExecutionException, InterruptedException {

// 构造数据

        int length = 100000000;
        long[] arr = new long[length];
        for (int i = 0; i < length; i++) {
            arr[i] = ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE);
        }
// 单线程
        singleThreadSum(arr);
// ThreadPoolExecutor线程池
        multiThreadSum(arr);
// ForkJoinPool线程池
        forkJoinSum(arr);
    }

    private static void singleThreadSum(long[] arr) {
        long start = System.currentTimeMillis();
        long sum = 0;
        for (int i = 0; i < arr.length; i++) {
// 模拟耗时
            sum += (arr[i]/ 3 * 3 / 3 * 3 / 3 * 3 / 3 * 3 / 3 * 3);
        }
        System.out.println("sum: " + sum);
        System.out.println("single thread elapse: " + (System.currentTimeMillis() - start));
    }

    private static void multiThreadSum(long[] arr) throws ExecutionException, InterruptedException {
        long start = System.currentTimeMillis();
        int count = 8;
        ExecutorService threadPool = Executors.newFixedThreadPool(count);
        List<Future<Long>> list = new ArrayList<>();
        for(int i = 0; i < count; i++) {
            int num = i;
// 分段提交任务
            Future<Long> future = threadPool.submit(() -> {
                long sum = 0;
                for (int j = arr.length / count * num; j < (arr.length / count * (num + 1)); j++) {
                    try {
// 模拟耗时
                        sum += (arr[j]/ 3 * 3 / 3 * 3 / 3 * 3 / 3 * 3 / 3 * 3);
                    }catch (Exception e) {
                        e.printStackTrace();
                    }
                }
                return sum;
            });
            list.add(future);
        }
// 每个段结果相加
        long sum = 0;
        for(Future<Long> future : list) {
            sum += future.get();
        }
        System.out.println("sum: " + sum);
        System.out.println("multi thread elapse: " + (System.currentTimeMillis() - start));
    }



    private static void forkJoinSum(long[] arr) throws ExecutionException, InterruptedException {
        long start = System.currentTimeMillis();
        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
// 提交任务
        ForkJoinTask<Long> forkJoinTask = forkJoinPool.submit(
                new SumTask(arr, 0, arr.length));
// 获取结果
        Long sum = forkJoinTask.get();
        forkJoinPool.shutdown();
        System.out.println("sum: " + sum);
        System.out.println("fork join elapse: " + (System.currentTimeMillis() - start));
    }

    private static class SumTask extends
            RecursiveTask<Long> {
        private long[] arr;
        private int from;
        private int to;
        
        public SumTask(long[] arr, int from, int to) {
            this.arr = arr;
            this.from = from;
            this.to = to;
        }
        
        @Override
        protected Long compute() {
// 小于1000的时候直接相加,可灵活调整
            if (to - from <= 1000) {
                long sum = 0;
                for (int i = from; i < to; i++) {
// 模拟耗时
                    sum += (arr[i]/ 3 * 3 / 3 * 3 / 3 * 3 / 3 * 3 / 3 * 3);

                }
                return sum;
            }
            
// 分成两段任务

            int middle = (from + to) / 2;
            SumTask left = new SumTask(arr, from, middle);
            SumTask right = new SumTask(arr, middle, to);
// 提交左边的任务
            left.fork();
// 右边的任务直接利用当前线程计算,节约开销
            Long rightResult = right.compute();
// 等待左边计算完毕
            Long leftResult = left.join();
// 返回结果
            return
                    leftResult + rightResult;
        }
    }
}

如果不加“都 /3*3/3*3/3*3/3*3/3*3了一顿操作” ,实际上计算1亿个整数相加,单线程是最快的,我的电脑大概是100ms左右,使用线程池反而会变慢。

所以,为了演示ForkJoinPool的牛逼之处,把每个数都 /3*3/3*3/3*3/3*3/3*3了一顿操作,用来模拟计算耗时。

来看结果:

sum: 107352457433800662

single thread elapse: 789

sum: 107352457433800662

multi thread elapse: 228

sum: 107352457433800662

fork join elapse: 189

 

可以看到,ForkJoinPool相对普通线程池还是有很大提升的。

分治法

  • 基本思想

    把一个规模大的问题划分为规模较小的子问题,然后分而治之,最后合并子问题的解得到原问题的解。

  • 步骤

    (1)分割原问题:

    (2)求解子问题:

    (3)合并子问题的解为原问题的解。

在分治法中,子问题一般是相互独立的,因此,经常通过递归调用算法来求解子问题。

  • 典型应用场景

    (1)二分搜索

    (2)大整数乘法

    (3)Strassen矩阵乘法

    (4)棋盘覆盖

    (5)归并排序

    (6)快速排序

    (7)线性时间选择

    (8)汉诺塔

ForkJoinPool继承体系

ForkJoinPool是 java 7 中新增的线程池类,它的继承体系如下:

ForkJoinPool和ThreadPoolExecutor都是继承自AbstractExecutorService抽象类,所以它和ThreadPoolExecutor的使用几乎没有多少区别,除了任务变成了ForkJoinTask以外。

这里又运用到了一种很重要的设计原则——开闭原则——对修改关闭,对扩展开放。

可见整个线程池体系一开始的接口设计就很好,新增一个线程池类,不会对原有的代码造成干扰,还能利用原有的特性。

ForkJoinTask

两个主要方法

  • fork()

    fork()方法类似于线程的Thread.start()方法,但是它不是真的启动一个线程,而是将任务放入到工作队列中。

  • join()

    join()方法类似于线程的Thread.join()方法,但是它不是简单地阻塞线程,而是利用工作线程运行其它任务。当一个工作线程中调用了join()方法,它将处理其它任务,直到注意到目标子任务已经完成了。

三个子类

  • RecursiveAction

    无返回值任务。

  • RecursiveTask

    有返回值任务。

  • CountedCompleter

    无返回值任务,完成任务后可以触发回调。

ForkJoinPool内部原理

ForkJoinPool内部使用的是“工作窃取”算法实现的。

(1)每个工作线程都有自己的工作队列WorkQueue;

(2)这是一个双端队列,它是线程私有的;

(3)ForkJoinTask中fork的子任务,将放入运行该任务的工作线程的队头,工作线程将以LIFO的顺序来处理工作队列中的任务;

(4)为了最大化地利用CPU,空闲的线程将从其它线程的队列中“窃取”任务来执行;

(5)从工作队列的尾部窃取任务,以减少竞争;

(6)双端队列的操作:push()/pop()仅在其所有者工作线程中调用,poll()是由其它线程窃取任务时调用的;

(7)当只剩下最后一个任务时,还是会存在竞争,是通过CAS来实现的;

ForkJoinTask fork 方法

fork() 做的工作只有一件事,既是把任务推入当前工作线程的工作队列里。可以参看以下的源代码:

    public final ForkJoinTask<V> fork() {
        Thread t;
        if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
            ((ForkJoinWorkerThread)t).workQueue.push(this);
        else
            ForkJoinPool.common.externalPush(this);
        return this;
    }

 

ForkJoinTask join 方法

join() 的工作则复杂得多,也是 join() 可以使得线程免于被阻塞的原因——不像同名的 Thread.join()

  1. 检查调用 join() 的线程是否是 ForkJoinThread 线程。如果不是(例如 main 线程),则阻塞当前线程,等待任务完成。如果是,则不阻塞。
  2. 查看任务的完成状态,如果已经完成,直接返回结果。
  3. 如果任务尚未完成,但处于自己的工作队列内,则完成它。
  4. 如果任务已经被其他的工作线程偷走,则窃取这个小偷的工作队列内的任务(以 FIFO 方式),执行,以期帮助它早日完成欲 join 的任务。
  5. 如果偷走任务的小偷也已经把自己的任务全部做完,正在等待需要 join 的任务时,则找到小偷的小偷,帮助它完成它的任务。
  6. 递归地执行第5步。

将上述流程画成序列图的话就是这个样子:

 

ForkJoinPool.submit 方法

    public static void main(String[] args) throws InterruptedException {
        // 创建包含Runtime.getRuntime().availableProcessors()返回值作为个数的并行线程的ForkJoinPool
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        // 提交可分解的PrintTask任务
        forkJoinPool.submit(new MyRecursiveAction(0, 1000));

        while (!forkJoinPool.isTerminated()) {
            forkJoinPool.awaitTermination(2, TimeUnit.SECONDS);
        }
        // 关闭线程池
        forkJoinPool.shutdown();
    }

其实除了前面介绍过的每个工作线程自己拥有的工作队列以外,ForkJoinPool 自身也拥有工作队列,这些工作队列的作用是用来接收由外部线程(非 ForkJoinThread 线程)提交过来的任务,而这些工作队列被称为 submitting queue 。

submit() 和 fork() 其实没有本质区别,只是提交对象变成了 submitting queue 而已(还有一些同步,初始化的操作)。submitting queue 和其他 work queue 一样,是工作线程”窃取“的对象,因此当其中的任务被一个工作线程成功窃取时,就意味着提交的任务真正开始进入执行阶段。

 

ForkJoinPool最佳实践

(1)最适合的是计算密集型任务;

(2)在需要阻塞工作线程时,可以使用ManagedBlocker;

(3)不应该在RecursiveTask的内部使用ForkJoinPool.invoke()/invokeAll();

总结

(1)ForkJoinPool特别适合于“分而治之”算法的实现;

(2)ForkJoinPool和ThreadPoolExecutor是互补的,不是谁替代谁的关系,二者适用的场景不同;

(3)ForkJoinTask有两个核心方法——fork()和join(),有三个重要子类——RecursiveAction、RecursiveTask和CountedCompleter;

(4)ForkjoinPool内部基于“工作窃取”算法实现;

(5)每个线程有自己的工作队列,它是一个双端队列,自己从队列头存取任务,其它线程从尾部窃取任务;

(6)ForkJoinPool最适合于计算密集型任务,但也可以使用ManagedBlocker以便用于阻塞型任务;

(7)RecursiveTask内部可以少调用一次fork(),利用当前线程处理,这是一种技巧;

 

使用场景补充:

 遍历系统所有文件,得到系统中文件的总数。

思路

通过递归的方法。任务在遍历中如果发现文件夹就创建新的任务让线程池执行,将返回的文件数加起来,如果发现文件则将计数加一,最终将该文件夹下的文件数返回。

代码实现

 

    CountingTask countingTask = new CountingTask(Environment.getExternalStorageDirectory());
    forkJoinPool.invoke(countingTask);

    class CountingTask extends RecursiveTask<Integer> {
        private File dir;

        public CountingTask(File dir) {
            this.dir = dir;
        }

        @Override
        protected Integer compute() {
            int count = 0;

            File files[] = dir.listFiles();
            if(files != null){
                for (File f : files){
                    if(f.isDirectory()){
                        // 对每个子目录都新建一个子任务。
                        CountingTask countingTask = new CountingTask(f);
                        countingTask.fork();
                        count += countingTask.join();

                    }else {
                        Log.d("tag" , "current path = "+f.getAbsolutePath());
                        count++;
                    }
                }
            }


            return count;
        }
    }         

上面的需求,如果我们用普通的线程池该如何完成?

如果我们使用newFixedThreadPool,当核心线程的路径下都有子文件夹时,它们会将路径下的子文件夹抛给任务队列,最终变成所有的核心线程都在等待子文件夹的返回结果,从而造成死锁。最终任务无法完成。

如果我们使用newCachedThreadPool,依然用上面的思路可以完成任务。但是每次子文件夹就会创建一个新的工作线程,这样消耗过大。

因此,在这样的情况下,ForkJoinPool的work-stealing的方式就体现出了优势。每个任务分配的子任务也由自己执行,只有自己的任务执行完成时,才会去执行别的工作线程的任务。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值