Fork/Join 框架 有时也称为 分解/合并框架,Fork/Join 框架采用分而治之将问题拆分成小问题。在一个任务中,先检查要解决问题的大小,如果大于设定,那就将问题拆分成可以通过框架来执行的小任务,如果问题的大小比设定的大小要小就直接在任务里解决这个问题,然后根据需要返回结果
Fork/Join 框架基于以下两种操作 :
分解(Fork) : 当需要讲一个任务拆分成更小的多个任务时,在框架中执行这些任务
合并(Join) : 当一个主任务等待其创建的多个子任务的完成执行
Fork/Join 框架执行任务的限制 :
1> 任务只能使用 fork() 和 join() 操作当前同步机制,如果使用其它同步机制,工作者线程就不能执行其他任务,当然这些任务是在同步操作里时。比如,如果在 Fork/Join 框架 中将一个任务休眠,正在执行这个任务的工作者线程在休眠期内不能执行另外一个任务
2> 任务不能执行 I/O操作,比如文件数据的读取与写入
3> 任务不能抛出非运行时异常,必须在代码中处理掉这些异常
Fork/Join 框架两个核心组成类
1> ForkJoinPool : 该类实现 ExecutorService接口和工作窃取算法,它管理工作者线程,并提供任务的状态信息,以及任务的执行信息
2> ForkJoinTask : 是在 ForkJoinPool 中执行任务的基类
在实现 Fork/Join 框架任务,通常需要实现以下两个类之一的子类
1> RecursiveAction : 用于没有返回结果的场景,继承 ForkJoinTask 类
2> RecursiveTask : 用于任务有返回结果的场景,继承 ForkJoinTask 类
在 ForkJoinPool 中执行 ForkJoinTask 时,可以采用同步或者异步。采用同步方法执行时,发送任务给 Fork/Join 线程池的方法直接到任务执行完成后才会返回结果。而采用异步方法执行时,发送任务给执行器的方法将立即返回结果,但任务仍能继续执行
ForkJoinPool 类
该类看作成一个特殊的 Executor执行器类型,该类实现 ExecutorService接口和工作窃取算法
public ForkJoinPool() : 使用无参构造器创建对象,将使用默认配置,创建的线程数等于计算机CPU数目的线程池,创建好 ForkJoinPool对象 之后线程也就准备好了,在线程池中等待任务的到达,然后开始执行
public void execute(ForkJoinTask<?> task) : 安排提供异步执行任务,只要一设置就开始执行这个任务
public void execute(Runnable task) : 使用 Runnable 对象执行任务时,就不采用窃取算法执行任务
public <T> T invoke(ForkJoinTask<T> task) : execute方法 时异步调用,而 invoke方法 则是同步调用直到传递进来的任务执行结束后才会返回
public int getActiveThreadCount() : 返回一个预估正在窃取或执行任务的线程数量,这个方法也许估计高于活跃线程数
public long getStealCount() : 从另外一个线程工作队列中返回一个估计全部窃取任务数量。当线程池不沉寂时,返回的被低估实际窃取总数的值。这个值也许对监控有用而且协调Fork/Join 框架程序。通常,窃取数量应该足够高来保持线程繁忙,但足够低以避免线程的开销和争用
public int getParallelism() : 返回此池的目标并行级别
public long getQueuedTaskCount() : 返回工作线程当前在队列中执行的任务总数(但不包括提交到尚未执行的池的任务)。这个值只是一个近似,通过遍历池中的所有线程获得。该方法可用于优化任务粒度
public void shutdown() : 可能会初始化一个有序开关,在此之前提交的任务将被执行,但不会有新任务被接受。如果这是 commonPool(),调用对执行状态没有影响,如果已经关闭,则没有额外的影响。在此方法过程中同时提交的任务可能或不可能被拒绝
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException : 在 shutdown() 方法执行后阻塞线程直到所有任务完成执行或者当前线程被阻塞,又或是到达指定时间。当申请公共池时,直到程序停止 commonPool() 才会终止,这个方法相当于 awaitQuiescence 方法但总是返回 false
ForkJoinTask 抽象类
public final boolean isDone() : 如果任务完成,返回 true
public final boolean isCompletedNormally() : 如果完成任务没有抛出异常且没有取消,返回
public final V get() throws InterruptedException, ExecutionException : 如果有必要等待计算完成,然后获取结果,这个方法返回 RecursiveTask 的 compute() 计算结果
public final V get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException : 如果有必要等待到在最大给定的时间内完成计算,然后如果可以则获取结果。如果超时计算还未完成,那就返回 null 值
public static void invokeAll(ForkJoinTask<?>... tasks)
public static void invokeAll(ForkJoinTask<?> t1, ForkJoinTask<?> t2) : invokeAll方法执行一个主任务所创建的多个子任务,这是一个同步调用,这个任务将等待子任务完成然后继续执行也可能结束。当一个主任务等待它的子任务时,执行这个主任务的工作者线程接收另一个等待执行的任务并开始执行
分叉指定的任务,当 isDone方法 保存每个任务或遇到一个 (未检查的)异常时返回,在这种情况下,异常被重新抛出。如果超过一个任务遇到异常,此时这个方法抛出这些异常中任意一个。如果任何任务遇到异常,其他任务可能被取消。可是,每个任务的执行状态在异常返回时无法保证。可以使用 getException() 和相关方法获得每个任务的状态,检查它们是否被取消、正常完成或异常,或未处理
public static <T extends ForkJoinTask<?>> Collection<T> invokeAll(Collection<T> tasks) : 在特定集合中分开所有任务,当 isDone方法 保存每个任务或遇到一个 (未检查的)异常时返回,在这种情况下,异常被重新抛出。如果超过一个任务遇到异常,此时这个方法抛出这些异常中任意一个。如果任何任务遇到异常,其他任务可能被取消。可是,每个任务的执行状态在异常返回时无法保证。可以使用 getException() 和相关方法获得每个任务的状态,检查它们是否被取消、正常完成或异常,或未处理
public final ForkJoinTask<V> fork() : 在当前任务的池中,如果可获取,则在当前任务中执行该任务,或者使用 commonPool()方法如果不能获取 。虽然它不是强制执行的,但如果它已经完成并被重新初始化,那么它将不止一次地使用一个任务。随后对该任务状态的修改或其操作的任何数据,在执行它之前的任何线程中都不一定能始终观察到,除非在调用 join方法 或相关方法之前,或者调用 isDone() 返回 true
说白了就是若有空闲工作者线程或者创建新线程来执行任务,执行任务不会阻塞当前线程会直接返回 ForkJoinTask 来存放执行结果,此时住线程可运行其它任务
public final V join() : 当 isDone() 方法完成时返回计算结果,否则会一直等待计算完成。这个方法不同于 get() 方法 因为 在RuntimeException 非正常完成结果 或 没有ExecutionException 的 Error,此时 调用线程的中断不会因为抛出 InterruptedException 而导致方法突然返回
join() 方法不能被中断,若中断调用 join()方法的线程,方法将会抛出 InterruptedException
public final Throwable getException() : 返回在计算过程中抛出的异常,或者如果任务取消返回 CancellationException,或者如果 没有异常或者任务没有完成返回 null
public final boolean isCompletedNormally() : 如果任务完成没有抛出异常且没有取消,返回 true
public final boolean isCompletedAbnormally() : 如果任务抛出异常或取消,返回 true
public boolean cancel(boolean mayInterruptIfRunning) : 尝试取消执行的这个任务。如果任务已经完成或者某些原因不能取消尝试将会失败。如果取消任务成功而且当取消任务时这个任务还未开始,这个执行的任务将会停止。在这个方法返回成功后,除非调用 reinitialize() 随后调用 isCancelled() isDone() 且返回 true 而且调用 join 相关方法结果会是 CancellationException。这个方法可能在子类中被覆盖,但如果是这样的话,仍然必须确保这些属性保持不变。特别是,cancel方法本身不能抛出异常。这个方法设计的目的就是被其他任务调用。对于计算方法,为终止当前任务,你可以返回或者抛出 unchecked exception 或者 调用 completeExceptionally(Throwable) 方法
简单说,取消一个没有执行的任务,如果已经开始执行则无法取消,参数 mayInterruptIfRunning 为 true 表示这个任务即是在运行也将被取消
Fork/Join 框架的局限在于 ForkJoinPool线程中的任务不允许被取消
RecursiveAction 将任务分解
public class
Product {
private
String
name
;
private double
price
;
public
String getName() {
return
name
;
}
public void
setName(String name) {
this
.
name
= name;
}
public double
getPrice() {
return
price
;
}
public void
setPrice(
double
price) {
this
.
price
= price;
}
}
public class
Task
extends
RecursiveAction {
private static final long
serialVersionUID
= -
3009073490385728376L
;
private
List<Product>
products
;
private int
first
;
private int
last
;
private double
increment
;
public
Task(List<Product> products,
int
first,
int
last,
double
increment) {
this
.
products
= products;
this
.
first
= first;
this
.
last
= last;
this
.
increment
= increment;
}
@Override
protected void
compute() {
if
(
last
-
first
<
10
) {
updatePrices();
}
else
{
int
middle = (
last
+
first
) /
2
;
System.
out
.printf(
"Task: Pending tasks:%s
\n
"
,
getQueuedTaskCount
());
Task t1 =
new
Task(
products
,
first
, middle +
1
,
increment
);
Task t2 =
new
Task(
products
, middle +
1
,
last
,
increment
);
invokeAll
(t1, t2);
}
}
private void
updatePrices() {
for
(
int
i =
first
; i <
last
; i++) {
Product product =
products
.get(i);
product.setPrice(product.getPrice() * (i +
increment
));
}
}
public static
List<Product> generate(
int
size) {
List<Product> ret =
new
ArrayList<>();
for
(
int
i =
0
; i < size; i++) {
Product product =
new
Product();
product.setName(
"Product "
+ i);
product.setPrice(
10
);
ret.add(product);
}
return
ret;
}
public static void
main(String[] args) {
List<Product> products = generate(
10000
);
Task task =
new
Task(products,
0
, products.size(),
0.2
);
ForkJoinPool forkJoinPool =
new
ForkJoinPool();
forkJoinPool.execute(task);
try
{
do
{
System.
out
.printf(
"Main: Thread Count: %d
\n
"
, forkJoinPool.getActiveThreadCount());
System.
out
.printf(
"Main: Thread Steal: %d
\n
"
, forkJoinPool.getStealCount());
System.
out
.printf(
"Main: Thread Parallelism: %d
\n
"
, forkJoinPool.getParallelism());
TimeUnit.
MILLISECONDS
.sleep(
5
);
}
while
(!task.isDone());
}
catch
(InterruptedException e) {
e.printStackTrace();
}
forkJoinPool.shutdown();
if
(task.isCompletedNormally()) {
System.
out
.printf(
"Main: The process has completed normally.
\n
"
);
}
for
(
int
i =
0
; i < products.size(); i++) {
Product product = products.get(i);
if
(product.getPrice() !=
12
) {
System.
out
.printf(
"Produt %s: %f
\n
"
, product.getName(), product.getPrice());
}
}
System.
out
.printf(
"Main: End of the program.
\n
"
);
}
}
合并任务结果
public class
Document {
private
String
words
[] = {
"the"
,
"hello"
,
"goodbye"
,
"packt"
,
"java"
,
"thread"
,
"pool"
,
"random"
,
"class"
,
"main"
};
public
String[][] generateDocument(
int
numLines,
int
numWords, String word) {
int
counter =
0
;
String document[][] =
new
String[numLines][numWords];
Random random =
new
Random();
for
(
int
i =
0
; i < numLines; i++) {
for
(
int
j =
0
; j < numWords; j++) {
int
index = random.nextInt(
words
.
length
);
document[i][j] =
words
[index];
if
(document[i][j].equals(word)) {
counter++;
}
}
}
System.
out
.println(
"DocumentMock: The word appears "
+ counter +
" times in the document"
);
return
document;
}
}
public class
LineTask
extends
RecursiveTask<Integer> {
private static final long
serialVersionUID
=
8771633587730430578L
;
private
String
line
[];
private int
start
,
end
;
private
String
word
;
public
LineTask(String[] line,
int
start,
int
end, String word) {
this
.
line
= line;
this
.
start
= start;
this
.
end
= end;
this
.
word
= word;
}
@Override
protected
Integer compute() {
Integer result =
null
;
if
(
end
-
start
<
100
) {
result = count(
line
,
start
,
end
,
word
);
}
else
{
int
mid = (
start
+
end
) /
2
;
LineTask task1 =
new
LineTask(
line
,
start
, mid,
word
);
LineTask task2 =
new
LineTask(
line
, mid,
end
,
word
);
invokeAll
(task1, task2);
try
{
result = groupResults(task1.get(), task2.get());
}
catch
(InterruptedException | ExecutionException e) {
e.printStackTrace();
}
}
return
result;
}
private
Integer count(String[] line,
int
start,
int
end, String word) {
int
counter =
0
;
for
(
int
i = start; i < end; i++) {
if
(line[i].equals(word)) {
counter++;
}
}
try
{
Thread.
sleep
(
10
);
}
catch
(InterruptedException e) {
e.printStackTrace();
}
return
counter;
}
private
Integer groupResults(Integer number1, Integer number2) {
return
number1 + number2;
}
}
public class
DocumentTask
extends
RecursiveTask<Integer> {
private
String
document
[][];
private int
start
,
end
;
private
String
word
;
public
DocumentTask(String[][] document,
int
start,
int
end, String word) {
this
.
document
= document;
this
.
start
= start;
this
.
end
= end;
this
.
word
= word;
}
@Override
protected
Integer compute() {
int
result =
0
;
if
(
end
-
start
<
10
) {
result = processLines(
document
,
start
,
end
,
word
);
}
else
{
int
mid = (
start
+
end
) /
2
;
DocumentTask task1 =
new
DocumentTask(
document
,
start
, mid,
word
);
DocumentTask task2 =
new
DocumentTask(
document
, mid,
end
,
word
);
invokeAll
(task1, task2);
try
{
result = groupResults(task1.get(), task2.get());
}
catch
(InterruptedException | ExecutionException e) {
e.printStackTrace();
}
}
return
result;
}
private
Integer processLines(String[][] document,
int
start,
int
end, String word) {
List<LineTask> tasks =
new
ArrayList<>();
for
(
int
i = start; i < end; i++) {
LineTask lineTask =
new
LineTask(document[i],
0
, document[i].
length
, word);
tasks.add(lineTask);
}
invokeAll
(tasks);
int
result =
0
;
try
{
for
(LineTask lineTask : tasks) {
result = result + lineTask.get();
}
}
catch
(InterruptedException | ExecutionException e) {
e.printStackTrace();
}
return
result;
}
private
Integer groupResults(Integer number1, Integer number2) {
return
number1 + number2;
}
public static void
main(String[] args) {
Document mock =
new
Document();
String[][] document = mock.generateDocument(
100
,
1000
,
"the"
);
DocumentTask task =
new
DocumentTask(document,
0
,
100
,
"the"
);
ForkJoinPool pool =
new
ForkJoinPool();
pool.execute(task);
do
{
System.
out
.printf(
"Main: Thread Parallelism: %d
\n
"
, pool.getParallelism());
System.
out
.printf(
"Main: Thread Active Threads: %d
\n
"
, pool.getActiveThreadCount());
System.
out
.printf(
"Main: Thread Task Count: %d
\n
"
, pool.getQueuedTaskCount());
System.
out
.printf(
"Main: Thread Steal Count: %d
\n
"
, pool.getStealCount());
try
{
TimeUnit.
MILLISECONDS
.sleep(
5
);
}
catch
(InterruptedException e) {
e.printStackTrace();
}
}
while
(!task.isDone());
pool.shutdown();
try
{
pool.awaitTermination(
1
, TimeUnit.
DAYS
);
System.
out
.printf(
"Main: The word appears %d in the document"
, task.get());
}
catch
(InterruptedException | ExecutionException e) {
e.printStackTrace();
}
}
}
异步运行任务
public class
FolderProcessor
extends
RecursiveTask<List<String>> {
private static final long
serialVersionUID
=
8145504519186061283L
;
private
String
path
;
private
String
extension
;
public
FolderProcessor(String path, String extension) {
this
.
path
= path;
this
.
extension
= extension;
}
@Override
protected
List<String> compute() {
List<String> filePathList =
new
ArrayList<>();
List<FolderProcessor> tasks =
new
ArrayList<>();
File file =
new
File(
path
);
File content[] = file.listFiles();
if
(content !=
null
) {
for
(
int
i =
0
; i < content.
length
; i++) {
if
(content[i].isDirectory()) {
FolderProcessor task =
new
FolderProcessor(content[i].getAbsolutePath(),
extension
);
task.fork();
tasks.add(task);
}
else
{
if
(checkFile(content[i].getName())) {
filePathList.add(content[i].getAbsolutePath());
}
}
}
if
(tasks.size() >
50
) {
System.
out
.printf(
"%s: %d tasks ran.
\n
"
, file.getAbsolutePath(), tasks.size());
}
addResultsFromTasks(filePathList, tasks);
}
return
filePathList;
}
private void
addResultsFromTasks(List<String> filePathList, List<FolderProcessor> tasks) {
// tasks.forEach((item) -> list.addAll(item.join())); 使用 Lambda 一句话完成任务
for
(FolderProcessor folderProcessor : tasks) {
filePathList.addAll(folderProcessor.join());
}
}
private boolean
checkFile(String name) {
return
name.endsWith(
extension
);
}
public static void
main(String[] args) {
ForkJoinPool pool =
new
ForkJoinPool();
FolderProcessor system =
new
FolderProcessor(
"/Users/mew/Desktop/AllmyFile"
,
"log"
);
FolderProcessor apps =
new
FolderProcessor(
"/Users/mew/Desktop/leanote-master"
,
"log"
);
FolderProcessor documents =
new
FolderProcessor(
"/Users/mew/Desktop/Company_Workspace"
,
"log"
);
pool.execute(system);
pool.execute(apps);
pool.execute(documents);
do
{
System.
out
.printf(
"******************************
\n
"
);
System.
out
.printf(
"Main: Parallelism: %d
\n
"
, pool.getParallelism());
System.
out
.printf(
"Main: Active Threads: %d
\n
"
, pool.getActiveThreadCount());
System.
out
.printf(
"Main: Task Count: %d
\n
"
, pool.getQueuedTaskCount());
System.
out
.printf(
"Main: Steal Count: %d
\n
"
, pool.getStealCount());
try
{
TimeUnit.
SECONDS
.sleep(
1
);
}
catch
(InterruptedException e) {
e.printStackTrace();
}
}
while
((!system.isDone()) || (!apps.isDone()) || (!documents.isDone()));
pool.shutdown();
List<String> results = system.join();
System.
out
.printf(
"System: %d files found.
\n
"
, results.size());
results = apps.join();
System.
out
.printf(
"Apps: %d files found.
\n
"
, results.size());
results = documents.join();
System.
out
.printf(
"Documents: %d files found.
\n
"
, results.size());
}
}
取消任务
public class
TaskManager {
private
List<ForkJoinTask<Integer>>
tasks
;
public
TaskManager() {
tasks
=
new
ArrayList<>();
}
public void
addTask(ForkJoinTask<Integer> task) {
tasks
.add(task);
}
public void
cancleTasks(ForkJoinTask<Integer> cancelTask) {
for
(ForkJoinTask<Integer> task :
tasks
) {
if
(task != cancelTask) {
task.cancel(
true
);
((SearchNumberTask) task).writeCancelMessage();
}
}
}
}
public class
SearchNumberTask
extends
RecursiveTask<Integer> {
private int
[]
numbers
;
private int
start
,
end
;
private int
number
;
private
TaskManager
manager
;
private final static int
NOT_FOUND
= -
1
;
public
SearchNumberTask(
int
[] numbers,
int
start,
int
end,
int
number, TaskManager manager) {
this
.
numbers
= numbers;
this
.
start
= start;
this
.
end
= end;
this
.
number
= number;
this
.
manager
= manager;
}
@Override
protected
Integer compute() {
System.
out
.println(
"Task: "
+
start
+
":"
+
end
);
int
ret;
if
(
end
-
start
>
10
) {
ret = launchTasks();
}
else
{
ret = lookForNumber();
}
return
ret;
}
private int
lookForNumber() {
for
(
int
i =
start
; i <
end
; i++) {
if
(
numbers
[i] ==
number
) {
System.
out
.printf(
"Task: Number %d found in position %d
\n
"
,
number
, i);
manager
.cancleTasks(
this
);
return
i;
}
try
{
TimeUnit.
SECONDS
.sleep(
1
);
}
catch
(InterruptedException e) {
e.printStackTrace();
}
}
return
NOT_FOUND
;
}
private int
launchTasks() {
int
mid = (
start
+
end
) /
2
;
SearchNumberTask task1 =
new
SearchNumberTask(
numbers
,
start
, mid,
number
,
manager
);
SearchNumberTask task2 =
new
SearchNumberTask(
numbers
, mid,
end
,
number
,
manager
);
manager
.addTask(task1);
manager
.addTask(task2);
task1.fork();
task2.fork();
int
returnValue = task1.join();
if
(returnValue != -
1
) {
return
returnValue;
}
returnValue = task2.join();
return
returnValue;
}
public void
writeCancelMessage() {
System.
out
.printf(
"Task: Cancelled task from %d to %d"
,
start
,
end
);
}
public static int
[] generateArray(
int
size) {
int
[] array =
new int
[size];
Random random =
new
Random();
for
(
int
i =
0
; i < size; i++) {
array[i] = random.nextInt(
10
);
}
return
array;
}
public static void
main(String[] args) {
int
[] array =
generateArray
(
1000
);
TaskManager manager =
new
TaskManager();
ForkJoinPool pool =
new
ForkJoinPool();
SearchNumberTask task =
new
SearchNumberTask(array,
0
,
1000
,
5
, manager);
pool.execute(task);
pool.shutdown();
try
{
pool.awaitTermination(
1
, TimeUnit.
DAYS
);
}
catch
(InterruptedException e) {
e.printStackTrace();
}
}
}