前言
我们在开发中经常会涉及到线程的使用,特别是在一些高并发的场景中,如果只是以单线程去执行就会特别耗费时间。今天就来认识一下在开发中可能会使用到的ForkJoinPool类,作者本人也是最近看线程池代码时才发现了一个ForkJoinPool类。
一、ForkJoinPool是什么?
ForkJoinPool又叫分而治之,通俗来讲就是帮我们把一个任务分成许多小任务给不同的线程执行,然后通过join将多个线程处理的结果进行汇总返回。我们可以先看一下线程池的继承关系,顶层的Executor接口,提供了一个execute()方法,我们常用的ExecutorService接口也继承自Executor接口,其中定义了一些额外的方法,下面呢就是AbstractExecutorService类,在这个类中实现了三个submit方法,而创建线程池的ThreadPoolExecutor类就继承自AbstractExecutorService,而同时ForkJoinPool也继承了AbstractExecutorService,具体关系可以看下图。
二、如何使用
我们在提交任务时,一般不会直接继承ForkJoinTask,只要继承它的子类即可。两者都提供了抽象方法compute,我们可以重写compute方法进行任务的划分。
-
RecursiveAction:用于没有返回结果的任务(类似Runnable)
-
RecursiveTask:用于有返回结果的任务(类似Callable)
此外,ForkJoinPool采取工作窃取算法,以避免工作线程由于拆分了任务之后的join等待过程。这样处于空闲的工作线程将从其他工作线程的队列中主动去窃取任务来执行。那么什么是工作窃取呢?
工作窃取是指当某个线程的任务队列中没有可执行任务的时候,从其他线程的任务队列中窃取任务来执行,以充分利用工作线程的计算能力,减少线程由于获取不到任务而造成的空闲浪费。在ForkJoinpool中,工作任务的队列都采用双端队列Deque容器。我们知道,在通常使用队列的过程中,我们都在队尾插入,而在队头消费以实现FIFO。而为了实现工作窃取。一般我们会改成工作线程在工作队列上LIFO,而窃取其他线程的任务的时候,从队列头部取获取。
三、使用案例
我们先来看看ForkJoinPool类中的构造函数。
ForkJoinPool()
ForkJoinPool(int parallelism)
ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, boolean asyncMode)
-
parallelism:由几个线程来拆分任务,如果不填则更具CPU核数创建线程数
-
factory:创建工作线程的工厂实现
-
handler:线程因未知异常而终止的回调处理
-
asyncMode:是否异步,默认false
1.提交有返回值的任务
代码如下(示例):
package com.yd.data.mrs.core.task.service;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
import java.util.stream.IntStream;
/**
* @author hjh
* @description 提交有返回值的任务
* @since 2021-07-30 15:42
*/
public class ForkJoinRecursiveTask {
/**
* 最大计算数
*/
private static final int MAX_THRESHOLD = 100;
public static void main(String[] args) {
//创建ForkJoinPool
ForkJoinPool pool = new ForkJoinPool();
//异步提交RecursiveTask任务
ForkJoinTask<Integer> forkJoinTask = pool.submit(new CalculatedRecursiveTask(0, 1000));
try {
//根据返回类型获取返回值
Integer result = forkJoinTask.get();
System.out.println("结果为:" + result);
} catch (InterruptedException | ExecutionException e) {
e.printStackTrace();
} finally {
pool.shutdown();
}
}
private static class CalculatedRecursiveTask extends RecursiveTask<Integer> {
private final int start;
private final int end;
public CalculatedRecursiveTask(int start, int end) {
this.start = start;
this.end = end;
}
@Override
protected Integer compute() {
//判断计算范围,如果小于等于5,那么一个线程计算就够了,否则进行分割
if ((end - start) <= MAX_THRESHOLD) {
//返回[start,end]的总和
return IntStream.rangeClosed(start, end).sum();
} else {
//任务分割
int middle = (end + start) / 2;
CalculatedRecursiveTask task1 = new CalculatedRecursiveTask(start, middle);
CalculatedRecursiveTask task2 = new CalculatedRecursiveTask(middle + 1, end);
//执行
task1.fork();
task2.fork();
//等待返回结果
return task1.join() + task2.join();
}
}
}
}
2.提交无返回值的任务
代码如下(示例):
package com.yd.data.mrs.core.task.service;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
/**
* @author hjh
* @description 提交无返回值的任务
* @since 2021-07-30 15:53
*/
public class ForkJoinRecursiveAction {
/**
* 最大计算数
*/
private static final int MAX_THRESHOLD = 100;
private static final AtomicInteger SUM = new AtomicInteger(0);
public static void main(String[] args) throws InterruptedException {
//创建ForkJoinPool
ForkJoinPool pool = new ForkJoinPool();
//异步提交RecursiveAction任务
pool.submit(new CalculatedRecursiveTask(0, 1000));
//等待3秒后输出结果,因为计算需要时间
pool.awaitTermination(1, TimeUnit.SECONDS);
System.out.println("结果为:" + SUM);
pool.shutdown();
}
private static class CalculatedRecursiveTask extends RecursiveAction {
private final int start;
private final int end;
public CalculatedRecursiveTask(int start, int end) {
this.start = start;
this.end = end;
}
@Override
protected void compute() {
//判断计算范围,如果小于等于5,那么一个线程计算就够了,否则进行分割
if ((end - start) <= MAX_THRESHOLD) {
//因为没有返回值,所有这里如果我们要获取结果,需要存入公共的变量中
SUM.addAndGet(IntStream.rangeClosed(start, end).sum());
} else {
//任务分割
int middle = (end + start) / 2;
CalculatedRecursiveTask task1 = new CalculatedRecursiveTask(start, middle);
CalculatedRecursiveTask task2 = new CalculatedRecursiveTask(middle + 1, end);
//执行
task1.fork();
task2.fork();
}
}
}
}
虽然ForkJoin实际的代码非常复杂,但是通过这个例子我们应该了解到ForkJoinPool底层的分治算法和工作窃取原理。ForkJoin不仅在java8之后的stream中广泛使用。golang等其他语言的协程机制,也是采用类似的原理来实现的。
欢迎大家添加个人公众号,一起进步努力。