Fork/Join
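Introduction
Fork/Join is a thread-pool implementation added in JDK 1.7. It embodies the divide-and-conquer idea and is suited to CPU-intensive computations that can be split into subtasks.
Task splitting means breaking a large task into smaller tasks that are solved by the same algorithm, repeatedly, until a task is small enough to be solved directly. Many recursive computations, such as merge sort and the Fibonacci sequence, can be solved this way.
Fork/Join adds multithreading on top of divide and conquer: splitting a task and merging its results can be handed to different threads, which further improves computational efficiency.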
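As an illustration of the Fibonacci case mentioned above, here is a minimal sketch of a Fork/Join task modeled on the example in the `RecursiveTask` Javadoc; the class name `Fibonacci` and the `main` method are illustrative. It splits all the way down to `n <= 1`, which creates a very large number of tiny tasks, so it shows the fork/join structure rather than an efficient way to compute Fibonacci numbers.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Naive divide-and-conquer Fibonacci as a Fork/Join task (illustrative sketch).
class Fibonacci extends RecursiveTask<Integer> {
    private final int n;

    Fibonacci(int n) {
        this.n = n;
    }

    @Override
    protected Integer compute() {
        if (n <= 1) {
            return n; // base case: fib(0) = 0, fib(1) = 1
        }
        Fibonacci f1 = new Fibonacci(n - 1);
        f1.fork();                       // schedule fib(n - 1) asynchronously
        Fibonacci f2 = new Fibonacci(n - 2);
        return f2.compute() + f1.join(); // compute fib(n - 2) here, then wait for fib(n - 1)
    }

    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool();
        System.out.println(pool.invoke(new Fibonacci(20))); // prints 6765
        pool.shutdown();
    }
}
```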
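By default, Fork/Join creates a pool whose size matches the number of CPU cores: the no-argument constructor `new ForkJoinPool()` uses a parallelism level equal to `Runtime.getRuntime().availableProcessors()`.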
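A quick way to check this on a given machine is sketched below (the class name `DefaultParallelism` is illustrative). The shared `ForkJoinPool.commonPool()` is configured separately: by default its parallelism is typically one less than the number of processors, and it can be changed through the `java.util.concurrent.ForkJoinPool.common.parallelism` system property.

```java
import java.util.concurrent.ForkJoinPool;

// Compares the parallelism of a default-constructed ForkJoinPool with the CPU count.
class DefaultParallelism {
    // Parallelism of a pool created with the no-argument constructor
    static int defaultParallelism() {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.getParallelism();
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("availableProcessors(): " + Runtime.getRuntime().availableProcessors());
        System.out.println("new ForkJoinPool()   : " + defaultParallelism());
        // The common pool is sized separately from pools created with new ForkJoinPool()
        System.out.println("commonPool()         : " + ForkJoinPool.commonPool().getParallelism());
    }
}
```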
Basic usage
Defining the task object
A task submitted to a Fork/Join pool must extend `RecursiveTask` (if it returns a value) or `RecursiveAction` (if it does not).
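For the `RecursiveAction` case, a minimal sketch: a task with no return value that doubles every element of an array in place. The names `DoubleAction` and `THRESHOLD` are hypothetical, and the threshold of 4 is arbitrary, chosen only so that a small array still gets split. It uses the static `ForkJoinTask.invokeAll(t1, t2)`, which forks both subtasks and returns once both have completed.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// A RecursiveAction returns nothing: this sketch doubles every element of an array in place.
class DoubleAction extends RecursiveAction {
    private static final int THRESHOLD = 4; // hypothetical cut-off for doing the work directly
    private final int[] data;
    private final int lo; // inclusive
    private final int hi; // exclusive

    DoubleAction(int[] data, int lo, int hi) {
        this.data = data;
        this.lo = lo;
        this.hi = hi;
    }

    @Override
    protected void compute() {
        if (hi - lo <= THRESHOLD) {
            for (int i = lo; i < hi; i++) {
                data[i] *= 2; // small enough: do the work directly
            }
            return;
        }
        int mid = lo + (hi - lo) / 2;
        // Run both halves and return only when both are done
        invokeAll(new DoubleAction(data, lo, mid), new DoubleAction(data, mid, hi));
    }

    public static void main(String[] args) {
        int[] data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
        new ForkJoinPool().invoke(new DoubleAction(data, 0, data.length));
        System.out.println(Arrays.toString(data)); // [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
    }
}
```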
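Example: a task that sums the integers from 1 to n.
- We implement it recursively: `AddTask(n)` forks a subtask for `n - 1` and adds `n` to that subtask's result.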
```java
package cn.knightzz.pool.fork_join;

import lombok.extern.slf4j.Slf4j;

import java.util.concurrent.RecursiveTask;

/**
 * @author 王天赐
 * @title: AddTask
 * @projectName hm-juc-codes
 * @description: Task object
 * @website <a href="http://knightzz.cn/">http://knightzz.cn/</a>
 * @github <a href="https://github.com/knightzz1998">https://github.com/knightzz1998</a>
 * @create: 2022-09-03 18:44
 */
@SuppressWarnings("all")
@Slf4j(topic = "c.AddTask")
public class AddTask extends RecursiveTask<Integer> {

    int number;

    public AddTask(int number) {
        this.number = number;
    }

    @Override
    public String toString() {
        return "{" +
                "number=" + number +
                '}';
    }

    @Override
    protected Integer compute() {
        // Base case: when number == 1, return it directly
        if (number == 1) {
            log.debug("join() {}", number);
            return number;
        }
        // Split the task: the subtask sums 1 .. number - 1
        AddTask task = new AddTask(number - 1);
        // Arrange to execute the subtask asynchronously in the pool the current task is running in (if applicable)
        task.fork();
        log.debug("fork() {} + {}", number, task);
        // Merge: wait for the subtask's result and add number to it
        int result = number + task.join();
        log.debug("join() {} + {} = {}", number, task, result);
        return result;
    }
}
```
Submitting to ForkJoinPool
```java
package cn.knightzz.pool.fork_join;

import lombok.extern.slf4j.Slf4j;

import java.util.concurrent.ForkJoinPool;

/**
 * @author 王天赐
 * @title: ForkJoinPoolTest
 * @projectName hm-juc-codes
 * @description:
 * @website <a href="http://knightzz.cn/">http://knightzz.cn/</a>
 * @github <a href="https://github.com/knightzz1998">https://github.com/knightzz1998</a>
 * @create: 2022-09-04 16:22
 */
@SuppressWarnings("all")
@Slf4j(topic = "c.ForkJoinPoolTest")
public class ForkJoinPoolTest {
    public static void main(String[] args) {
        // A pool with a parallelism level of 4
        ForkJoinPool pool = new ForkJoinPool(4);
        // invoke() runs the task and blocks until its result is available
        Integer result = pool.invoke(new AddTask(5));
        log.debug("result : {}", result);
    }
}
```
The output is as follows:
```text
16:26:42.217 [ForkJoinPool-1-worker-2] DEBUG c.AddTask - fork() 4 + {number=3}
16:26:42.217 [ForkJoinPool-1-worker-0] DEBUG c.AddTask - fork() 2 + {number=1}
16:26:42.217 [ForkJoinPool-1-worker-3] DEBUG c.AddTask - fork() 3 + {number=2}
16:26:42.217 [ForkJoinPool-1-worker-1] DEBUG c.AddTask - fork() 5 + {number=4}
16:26:42.221 [ForkJoinPool-1-worker-0] DEBUG c.AddTask - join() 1
16:26:42.222 [ForkJoinPool-1-worker-0] DEBUG c.AddTask - join() 2 + {number=1} = 3
16:26:42.222 [ForkJoinPool-1-worker-3] DEBUG c.AddTask - join() 3 + {number=2} = 6
16:26:42.222 [ForkJoinPool-1-worker-2] DEBUG c.AddTask - join() 4 + {number=3} = 10
16:26:42.222 [ForkJoinPool-1-worker-1] DEBUG c.AddTask - join() 5 + {number=4} = 15
16:26:42.222 [main] DEBUG c.ForkJoinPoolTest - result : 15
Process finished with exit code 0
```
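`ForkJoinPool.invoke`, used in the test above, blocks until the task finishes and returns its result. The pool also offers `submit`, which returns a `ForkJoinTask` (which is also a `Future`) without waiting, and the JVM provides a shared pool through `ForkJoinPool.commonPool()`. The sketch below runs the same task all three ways; `SubmitDemo` and the trivial task `Answer` are hypothetical names used only for illustration.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

// Three ways to run a task on a ForkJoinPool; Answer is a hypothetical, trivial task.
class SubmitDemo {
    static class Answer extends RecursiveTask<Integer> {
        @Override
        protected Integer compute() {
            return 42;
        }
    }

    static int[] runAll() {
        ForkJoinPool pool = new ForkJoinPool(4);
        try {
            // 1. invoke(): blocks until the task completes and returns its result
            int a = pool.invoke(new Answer());
            // 2. submit(): returns a ForkJoinTask (also a Future) without waiting
            ForkJoinTask<Integer> future = pool.submit(new Answer());
            int b = future.join(); // get() works too, but throws checked exceptions
            // 3. The shared common pool; it does not need to be shut down
            int c = ForkJoinPool.commonPool().invoke(new Answer());
            return new int[] {a, b, c};
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        int[] results = runAll();
        System.out.println(results[0] + " " + results[1] + " " + results[2]); // 42 42 42
    }
}
```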
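Optimizing the Fork/Join task
In `AddTask`, every task forks exactly one subtask and then waits for it, so the tasks form a chain (5 → 4 → 3 → 2 → 1) in which each level must wait for the one below it. Instead, we can split the range in the middle, so that each task produces two independent halves: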
![image-20220904164131079](https://i-blog.csdnimg.cn/blog_migrate/498a915c40fd353cc650f47fc303a9e7.png)
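Writing $S(a, b)$ for the sum of the integers from $a$ to $b$, the midpoint split for `begin = 1, end = 5` works out as follows; it matches the `fork()` and `join()` lines in the log output of the run in this section.

$$
\begin{aligned}
S(1,5) &= S(1,3) + S(4,5) && \text{mid} = \lfloor (1+5)/2 \rfloor = 3 \\
S(1,3) &= S(1,2) + S(3,3) && \text{mid} = \lfloor (1+3)/2 \rfloor = 2 \\
S(1,2) &= 1 + 2 = 3, \quad S(3,3) = 3, \quad S(4,5) = 4 + 5 = 9 \\
S(1,5) &= (3 + 3) + 9 = 15
\end{aligned}
$$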
The code is as follows:
```java
package cn.knightzz.pool.fork_join;

import lombok.extern.slf4j.Slf4j;

import java.util.concurrent.RecursiveTask;

/**
 * @author 王天赐
 * @title: AddTaskOptimization
 * @projectName hm-juc-codes
 * @description: Task object that splits the range [begin, end] in half
 * @website <a href="http://knightzz.cn/">http://knightzz.cn/</a>
 * @github <a href="https://github.com/knightzz1998">https://github.com/knightzz1998</a>
 * @create: 2022-09-03 18:44
 */
@SuppressWarnings("all")
@Slf4j(topic = "c.AddTaskOptimization")
public class AddTaskOptimization extends RecursiveTask<Integer> {

    int begin;
    int end;

    public AddTaskOptimization(int begin, int end) {
        this.begin = begin;
        this.end = end;
    }

    @Override
    public String toString() {
        return "{" +
                "begin=" + begin +
                ", end=" + end +
                '}';
    }

    @Override
    protected Integer compute() {
        // begin == end: a single number, which cannot be split any further
        if (begin == end) {
            log.debug("join() {}", begin);
            return begin;
        }
        // Two adjacent numbers: add them directly
        if (end - begin == 1) {
            log.debug("join() {} + {} = {}", begin, end, end + begin);
            return end + begin;
        }
        // Split at the midpoint, e.g. begin = 1, end = 5 -> mid = 3
        // (same as (begin + end) / 2 here, but avoids overflowing begin + end for large values)
        int mid = begin + (end - begin) / 2;
        AddTaskOptimization t1 = new AddTaskOptimization(begin, mid);   // e.g. [1, 3]
        t1.fork();
        AddTaskOptimization t2 = new AddTaskOptimization(mid + 1, end); // e.g. [4, 5]
        t2.fork();
        log.debug("fork() {} + {} = ?", t1, t2);
        // Wait for both halves and merge their results
        int result = t1.join() + t2.join();
        log.debug("join() {} + {} = {}", t1, t2, result);
        return result;
    }
}
```
Test code:
```java
package cn.knightzz.pool.fork_join;

import lombok.extern.slf4j.Slf4j;

import java.util.concurrent.ForkJoinPool;

/**
 * @author 王天赐
 * @title: ForkJoinPoolTest
 * @projectName hm-juc-codes
 * @description:
 * @website <a href="http://knightzz.cn/">http://knightzz.cn/</a>
 * @github <a href="https://github.com/knightzz1998">https://github.com/knightzz1998</a>
 * @create: 2022-09-04 16:22
 */
@SuppressWarnings("all")
@Slf4j(topic = "c.ForkJoinPoolTest")
public class ForkJoinPoolTest {
    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool(4);
        // Sum the integers in the range [1, 5]
        Integer result = pool.invoke(new AddTaskOptimization(1, 5));
        log.debug("result : {}", result);
    }
}
```
Output:
```text
17:04:56.451 [ForkJoinPool-1-worker-1] DEBUG c.AddTaskOptimization - fork() {begin=1, end=3} + {begin=4, end=5} = ?
17:04:56.451 [ForkJoinPool-1-worker-0] DEBUG c.AddTaskOptimization - join() 1 + 2 = 3
17:04:56.451 [ForkJoinPool-1-worker-2] DEBUG c.AddTaskOptimization - fork() {begin=1, end=2} + {begin=3, end=3} = ?
17:04:56.451 [ForkJoinPool-1-worker-3] DEBUG c.AddTaskOptimization - join() 4 + 5 = 9
17:04:56.454 [ForkJoinPool-1-worker-1] DEBUG c.AddTaskOptimization - join() 3
17:04:56.454 [ForkJoinPool-1-worker-2] DEBUG c.AddTaskOptimization - join() {begin=1, end=2} + {begin=3, end=3} = 6
17:04:56.454 [ForkJoinPool-1-worker-1] DEBUG c.AddTaskOptimization - join() {begin=1, end=3} + {begin=4, end=5} = 15
17:04:56.454 [main] DEBUG c.ForkJoinPoolTest - result : 15
```
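The optimized task still splits down to one or two numbers, so for a large range most tasks do almost no work. It also forks both halves, and its `Integer` result overflows once the sum exceeds `Integer.MAX_VALUE` (for 1..n, from n = 65536 on). Below is a minimal sketch addressing these points, under stated assumptions: the class `RangeSumTask` and its `THRESHOLD` of 10,000 are hypothetical and untuned. Below the threshold it uses a plain loop. Otherwise it forks the right half and computes the left half in the current thread, the same pattern as the Fibonacci example in the `RecursiveTask` Javadoc, and it uses `long` throughout.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Sums begin..end (inclusive) with a sequential cut-off; THRESHOLD is an arbitrary, untuned value.
class RangeSumTask extends RecursiveTask<Long> {
    static final long THRESHOLD = 10_000;
    private final long begin;
    private final long end;

    RangeSumTask(long begin, long end) {
        this.begin = begin;
        this.end = end;
    }

    @Override
    protected Long compute() {
        if (end - begin < THRESHOLD) {
            long sum = 0;
            for (long i = begin; i <= end; i++) {
                sum += i; // small range: a plain loop is cheaper than creating more tasks
            }
            return sum;
        }
        long mid = begin + (end - begin) / 2;
        RangeSumTask left = new RangeSumTask(begin, mid);
        RangeSumTask right = new RangeSumTask(mid + 1, end);
        right.fork();                  // make the right half available to other workers
        long leftSum = left.compute(); // compute the left half in this thread
        return leftSum + right.join(); // then wait for the right half
    }

    // Sums 1..n on a fresh pool with default parallelism
    static long sum(long n) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new RangeSumTask(1, n));
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        long n = 100_000_000L;
        System.out.println(sum(n) + " == " + n * (n + 1) / 2);
    }
}
```

The threshold trades task-creation overhead against parallelism; a suitable value depends on the workload and machine.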