1、创建一个Fork/Join池
ForkJoinPool 逻辑
If (problem size > default size){
tasks=divide(task);
execute(tasks);
}
else {
resolve problem using another algorithm;
}
public class MyForkJoinPool1 {
    /**
     * Demo entry point for the "create a Fork/Join pool" recipe.
     *
     * <p>Generates 10,000 products, submits a {@link Task} that raises every
     * price by 20% to a dedicated {@link ForkJoinPool}, prints the pool's
     * statistics every 5 ms while the task runs, and finally lists every
     * product whose price is not 12.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        final List<Product> catalog = new ProductListGenerator().generate(10000);
        final Task priceUpdate = new Task(catalog, 0, catalog.size(), 0.20);

        final ForkJoinPool forkJoinPool = new ForkJoinPool();
        // execute() is asynchronous: it returns before the task has finished.
        forkJoinPool.execute(priceUpdate);

        // Poll the pool's statistics until the submitted task reports completion.
        do {
            System.out.printf("Main: Thread Count: %d\n", forkJoinPool.getActiveThreadCount());
            System.out.printf("Main: Thread Steal: %d\n", forkJoinPool.getStealCount());
            System.out.printf("Main: Parallelism: %d\n", forkJoinPool.getParallelism());
            try {
                TimeUnit.MILLISECONDS.sleep(5);
            } catch (InterruptedException interrupted) {
                interrupted.printStackTrace();
            }
        } while (!priceUpdate.isDone());

        forkJoinPool.shutdown();

        if (priceUpdate.isCompletedNormally()) {
            System.out.printf("Main: The process has completed normally.\n");
        }

        // Any product whose price is not exactly 12 is reported.
        for (Product item : catalog) {
            if (item.getPrice() != 12) {
                System.out.printf("Product %s: %f\n", item.getName(), item.getPrice());
            }
        }
        System.out.println("Main: End of the program.\n");
    }
}
class Product {
    /** Display name; {@code null} until {@link #setName(String)} is called. */
    private String name;
    /** Current price; {@code 0.0} until {@link #setPrice(double)} is called. */
    private double price;

    /**
     * Returns the product name.
     *
     * @return the name last set, or {@code null}
     */
    public String getName() {
        return this.name;
    }

    /**
     * Replaces the product name.
     *
     * @param name the new name
     */
    public void setName(String name) {
        this.name = name;
    }

    /**
     * Returns the product price.
     *
     * @return the price last set, or {@code 0.0}
     */
    public double getPrice() {
        return this.price;
    }

    /**
     * Replaces the product price.
     *
     * @param price the new price
     */
    public void setPrice(double price) {
        this.price = price;
    }
}
class ProductListGenerator {
    /**
     * Builds a fresh, mutable list of products named {@code "Product0"},
     * {@code "Product1"}, ... each with price 10.
     *
     * @param size number of products to create; zero or negative yields an empty list
     * @return the generated products, in index order
     */
    public List<Product> generate(int size) {
        List<Product> generated = new ArrayList<Product>();
        int index = 0;
        while (index < size) {
            Product next = new Product();
            next.setName("Product" + index);
            next.setPrice(10);
            generated.add(next);
            index++;
        }
        return generated;
    }
}
class Task extends RecursiveAction {
    private static final long serialVersionUID = 1L;

    /** List shared by all tasks; each task only updates indices in [first, last). */
    private List<Product> products;
    /** Inclusive start index of this task's range. */
    private int first;
    /** Exclusive end index of this task's range. */
    private int last;
    /** Fractional price increase, e.g. 0.20 for +20%. */
    private double increment;

    /**
     * Creates a task that multiplies the price of every product in
     * {@code products[first, last)} by {@code 1 + increment}.
     */
    public Task(List<Product> products, int first, int last, double increment) {
        this.products = products;
        this.first = first;
        this.last = last;
        this.increment = increment;
    }

    /**
     * Updates the range directly when it holds fewer than 10 elements;
     * otherwise splits it into two halves and runs both synchronously.
     */
    @Override
    protected void compute() {
        if (last - first >= 10) {
            int middle = (last + first) / 2;
            System.out.printf("Task: Pending tasks:%s\n", getQueuedTaskCount());
            // Halves are [first, middle + 1) and [middle + 1, last): disjoint and covering.
            invokeAll(new Task(products, first, middle + 1, increment),
                    new Task(products, middle + 1, last, increment));
            return;
        }
        updatePrices();
    }

    /** Applies the price increase to every product in this task's range. */
    private void updatePrices() {
        final double factor = 1 + increment;
        for (int idx = first; idx < last; idx++) {
            Product current = products.get(idx);
            current.setPrice(current.getPrice() * factor);
        }
    }
}
2. 合并任务的结果
Fork/Join框架提供了执行有返回结果的任务的能力。这类任务继承自RecursiveTask类。RecursiveTask类继承了ForkJoinTask类,并实现了执行器框架提供的Future接口。
If (problem size > size){
tasks=divide(task);
execute(tasks);
groupResults();
return result;
} else {
resolve problem;
return result;
}
如果这个任务必须解决一个超过预定义大小的问题,你应该将这个任务分解成更多的子任务,并且用Fork/Join框架来执行这些子任务。当这些子任务完成执行后,发起的任务将获得所有子任务产生的结果,对这些结果进行合并,并返回最终的结果。最终,当在池中执行的发起任务完成它的执行时,你将获得整个问题的最终结果。
class MyForkJoinPool2 {
    /**
     * Demo entry point for the "joining the results of the tasks" recipe.
     *
     * <p>Generates a random 100 x 1000 word document, counts the occurrences of
     * {@code "the"} with a {@link DocumentTask} on a dedicated
     * {@link ForkJoinPool}, prints pool statistics once per second while the
     * task runs, and finally prints the count exactly once.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        DocumentMock mock = new DocumentMock();
        String[][] document = mock.generateDocument(100, 1000, "the");
        // Cover every generated line instead of repeating the magic 100.
        DocumentTask task = new DocumentTask(document, 0, document.length, "the");
        ForkJoinPool pool = new ForkJoinPool();
        pool.execute(task);
        do {
            System.out.printf("******************************************\n");
            System.out.printf("Main: Parallelism: %d\n", pool.getParallelism());
            System.out.printf("Main: Active Threads: %d\n",
                    pool.getActiveThreadCount());
            System.out.printf("Main: Task Count: %d\n",
                    pool.getQueuedTaskCount());
            System.out.printf("Main: Steal Count: %d\n", pool.getStealCount());
            System.out.printf("******************************************\n");
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        } while (!task.isDone());
        pool.shutdown();
        // Fetch and print the final count once.
        try {
            System.out.printf("Main: The word appears %d times in the document\n",
                    task.get());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (ExecutionException e) {
            e.printStackTrace();
        }
    }
}
class DocumentMock {
    /** Vocabulary from which every generated word is drawn uniformly at random. */
    private String words[] = { "the", "hello", "goodbye", "packt", "java",
            "thread", "pool", "random", "class", "main" };

    /**
     * Builds a random document and prints how often {@code word} occurs in it.
     *
     * @param numLines number of lines (rows) to generate
     * @param numWords number of words per line
     * @param word     the word whose occurrences are counted and reported
     * @return a {@code numLines x numWords} array of words from the vocabulary
     */
    public String[][] generateDocument(int numLines, int numWords, String word) {
        int occurrences = 0;
        String[][] document = new String[numLines][numWords];
        Random rng = new Random();
        for (String[] line : document) {
            for (int col = 0; col < line.length; col++) {
                String picked = words[rng.nextInt(words.length)];
                line[col] = picked;
                if (picked.equals(word)) {
                    occurrences++;
                }
            }
        }
        System.out.println("DocumentMock: The word appears " + occurrences
                + " times in the document");
        return document;
    }
}
class DocumentTask extends RecursiveTask<Integer> {
    private static final long serialVersionUID = -7632107634821261866L;
    /** Whole document; this task only reads rows in [start, end). */
    private String document[][];
    /** Inclusive first row and exclusive last row handled by this task. */
    private int start, end;
    /** Word whose occurrences are counted. */
    private String word;

    /**
     * Creates a task counting {@code word} in rows {@code [start, end)} of
     * {@code document}.
     */
    public DocumentTask(String document[][], int start, int end, String word) {
        this.document = document;
        this.start = start;
        this.end = end;
        this.word = word;
    }

    /**
     * Counts the word in this task's rows: fewer than 10 rows are processed
     * with one {@link LineTask} per row, larger ranges are split in half.
     *
     * @return number of occurrences of {@code word} in rows [start, end)
     */
    @Override
    protected Integer compute() {
        if (end - start < 10) {
            return processLines(document, start, end, word);
        }
        int mid = (start + end) / 2;
        DocumentTask task1 = new DocumentTask(document, start, mid, word);
        DocumentTask task2 = new DocumentTask(document, mid, end, word);
        // invokeAll rethrows any subtask failure, so the joins below only run
        // after both halves completed normally. Using join() instead of a
        // swallowed get() means a failure cannot silently become a wrong count.
        invokeAll(task1, task2);
        return groupResults(task1.join(), task2.join());
    }

    /**
     * Runs one {@link LineTask} per row in [start, end) and sums their counts.
     */
    private Integer processLines(String[][] document, int start, int end,
            String word) {
        List<LineTask> tasks = new ArrayList<LineTask>();
        for (int i = start; i < end; i++) {
            tasks.add(new LineTask(document[i], 0, document[i].length, word));
        }
        invokeAll(tasks);
        int result = 0;
        for (LineTask task : tasks) {
            result += task.join();
        }
        // Autoboxing replaces the deprecated new Integer(int) constructor.
        return result;
    }

    /** Combines the counts of two halves. */
    private Integer groupResults(Integer number1, Integer number2) {
        return number1 + number2;
    }
}
class LineTask extends RecursiveTask<Integer> {
    private static final long serialVersionUID = 1L;
    /** Words of one line; this task only reads indices in [start, end). */
    private String line[];
    /** Inclusive start and exclusive end index within {@code line}. */
    private int start, end;
    /** Word whose occurrences are counted. */
    private String word;

    /**
     * Creates a task counting {@code word} in {@code line[start, end)}.
     */
    public LineTask(String line[], int start, int end, String word) {
        this.line = line;
        this.start = start;
        this.end = end;
        this.word = word;
    }

    /**
     * Counts directly when the range holds fewer than 100 words; otherwise
     * splits it in half and sums both halves.
     *
     * @return number of elements in line[start, end) equal to {@code word}
     */
    @Override
    protected Integer compute() {
        if (end - start < 100) {
            return count(line, start, end, word);
        }
        int mid = (start + end) / 2;
        LineTask task1 = new LineTask(line, start, mid, word);
        LineTask task2 = new LineTask(line, mid, end, word);
        // invokeAll rethrows subtask failures, so join() never yields a null
        // result here (a swallowed get() previously could, leading to an NPE
        // when the parent unboxed it).
        invokeAll(task1, task2);
        return groupResults(task1.join(), task2.join());
    }

    /** Counts matches of {@code word} in line[start, end), then sleeps 10 ms. */
    private Integer count(String[] line, int start, int end, String word) {
        int counter = 0;
        for (int i = start; i < end; i++) {
            if (line[i].equals(word)) {
                counter++;
            }
        }
        // Presumably simulates extra work so the demo's pool statistics are observable.
        try {
            Thread.sleep(10);
        } catch (InterruptedException e) {
            // Preserve the interrupt for the caller instead of swallowing it.
            Thread.currentThread().interrupt();
        }
        return counter;
    }

    /** Combines the counts of two halves. */
    private Integer groupResults(Integer number1, Integer number2) {
        return number1 + number2;
    }
}
3. 异步方式
当你在ForkJoinPool中执行ForkJoinTask时,你可以使用同步或异步方式来实现。当你使用同步方式时,提交任务给池的方法直到提交的任务完成它的执行,才会返回结果。当你使用异步方式时,提交任务给执行者的方法将立即返回,所以这个任务可以继续执行。
你应该意识到这两种方式有很大的区别:当你使用同步方法(比如:invokeAll()方法)时,调用这些方法的任务将被阻塞,直到提交给池的任务完成它的执行。这允许ForkJoinPool类使用work-stealing算法,把新的任务分配给执行这个被阻塞(等待中)任务的工作线程。反之,当你使用异步方法(比如:fork()方法)时,这个任务将继续它的执行,所以ForkJoinPool类不能使用work-stealing算法来提高应用程序的性能。在这种情况下,只有当你调用join()或get()方法来等待任务完成时,ForkJoinPool才能使用work-stealing算法。
public class MyForkJoinPool3 {
    /**
     * Demo entry point for the "asynchronous tasks" recipe.
     *
     * <p>Submits three {@link FolderProcessor} searches for {@code "log"} files
     * (under three hard-coded Windows directories) to one pool, prints pool
     * statistics once per second until all three are done, then prints how many
     * files each search returned.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool();
        FolderProcessor system = new FolderProcessor("C:\\Windows", "log");
        FolderProcessor apps = new FolderProcessor("C:\\Program Files", "log");
        FolderProcessor documents = new FolderProcessor(
                "C:\\Documents And Settings", "log");

        // execute() submits asynchronously, so all three searches run concurrently.
        for (FolderProcessor search : new FolderProcessor[] { system, apps, documents }) {
            pool.execute(search);
        }

        boolean allDone;
        do {
            printPoolStatus(pool);
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            allDone = system.isDone() && apps.isDone() && documents.isDone();
        } while (!allDone);

        pool.shutdown();

        System.out.printf("System: %d files found.\n", system.join().size());
        System.out.printf("Apps: %d files found.\n", apps.join().size());
        System.out.printf("Documents: %d files found.\n", documents.join().size());
    }

    /** Prints a framed snapshot of the pool's parallelism, activity and queue statistics. */
    private static void printPoolStatus(ForkJoinPool pool) {
        System.out.printf("******************************************\n");
        System.out.printf("Main: Parallelism: %d\n", pool.getParallelism());
        System.out.printf("Main: Active Threads: %d\n",
                pool.getActiveThreadCount());
        System.out.printf("Main: Task Count: %d\n",
                pool.getQueuedTaskCount());
        System.out.printf("Main: Steal Count: %d\n", pool.getStealCount());
        System.out.printf("******************************************\n");
    }
}
class FolderProcessor extends RecursiveTask<List<String>> {
    private static final long serialVersionUID = 1L;
    /** Directory searched by this task. */
    private String path;
    /** File extension to match, with or without a leading dot (e.g. "log" or ".log"). */
    private String extension;

    /**
     * Creates a task that collects the absolute paths of all files under
     * {@code path} whose name ends with the given extension.
     */
    public FolderProcessor(String path, String extension) {
        this.path = path;
        this.extension = extension;
    }

    /**
     * Lists this directory, forks one subtask per subdirectory, collects
     * matching files directly, then joins the subtasks' results.
     *
     * @return absolute paths of matching files; empty if {@code path} cannot be listed
     */
    @Override
    protected List<String> compute() {
        List<String> list = new ArrayList<>();
        List<FolderProcessor> tasks = new ArrayList<>();
        File file = new File(path);
        File content[] = file.listFiles();
        // listFiles() returns null when the path is not a directory or an I/O error occurs.
        if (content != null) {
            for (File entry : content) {
                if (entry.isDirectory()) {
                    // Fork asynchronously; results are joined after the scan.
                    FolderProcessor task = new FolderProcessor(
                            entry.getAbsolutePath(), extension);
                    task.fork();
                    tasks.add(task);
                } else if (checkFile(entry.getName())) {
                    list.add(entry.getAbsolutePath());
                }
            }
            if (tasks.size() > 50) {
                System.out.printf("%s: %d tasks ran.\n",
                        file.getAbsolutePath(), tasks.size());
            }
            addResultsFromTasks(list, tasks);
        }
        return list;
    }

    /** Appends every subtask's result to {@code list}, waiting for each in turn. */
    private void addResultsFromTasks(List<String> list,
            List<FolderProcessor> tasks) {
        for (FolderProcessor item : tasks) {
            list.addAll(item.join());
        }
    }

    /**
     * Returns whether {@code name} has the configured extension. A bare
     * {@code endsWith("log")} would also match names like "catalog", so the
     * dot is required; an empty extension matches every file.
     */
    private boolean checkFile(String name) {
        if (extension.isEmpty()) {
            return true;
        }
        String suffix = extension.startsWith(".") ? extension : "." + extension;
        return name.endsWith(suffix);
    }
}
4. 在任务中抛出异常
在ForkJoinTask类的compute()方法中,你不能抛出任何已检查异常,因为这个方法的声明中没有包含任何throws子句。你必须包含必要的代码来处理这些异常。但是,你可以抛出一个未检查异常(或者由它在内部调用的任何方法或对象抛出)。此时ForkJoinTask和ForkJoinPool类的行为与你可能期望的不同:程序不会结束执行,并且(只要你没有调用join()或get()等方法获取结果)你不会在控制台看到任何关于异常的信息。异常只是被吞没,就好像没有抛出一样。你可以使用ForkJoinTask类的一些方法(例如isCompletedAbnormally()和getException()),得知一个任务是否抛出了异常及其异常类型。
public class MyForkJoinPool4 {
    /**
     * Demo entry point for the "throwing exceptions in tasks" recipe.
     *
     * <p>Runs a {@link MyTask} over a 100-element array, waits for the pool to
     * terminate, and then reports either the task's exception or its result.
     *
     * <p>{@code join()} rethrows the task's exception when the task completed
     * abnormally, so it is only called when the task did not complete
     * abnormally; otherwise the demo would crash with an uncaught exception
     * right after reporting it.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int array[] = new int[100];
        MyTask task = new MyTask(array, 0, 100);
        ForkJoinPool pool = new ForkJoinPool();
        pool.execute(task);
        pool.shutdown();
        try {
            pool.awaitTermination(1, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            // Restore the interrupt flag; the task's state is unknown, so stop here.
            Thread.currentThread().interrupt();
            e.printStackTrace();
            return;
        }
        if (task.isCompletedAbnormally()) {
            System.out.printf("Main: An exception has ocurred\n");
            System.out.printf("Main: %s\n", task.getException());
        } else {
            System.out.printf("Main: Result: %d", task.join());
        }
    }
}
class MyTask extends RecursiveTask<Integer> {
    private static final long serialVersionUID = 1L;
    /** Index whose leaf task deliberately throws, to demonstrate exception handling. */
    private static final int FAILING_INDEX = 3;
    /** Array being "processed"; its contents are never read. */
    private int array[];
    /** Inclusive start and exclusive end of this task's range. */
    private int start, end;

    /** Creates a task over {@code array[start, end)}. */
    public MyTask(int array[], int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    /**
     * Splits ranges of 10 or more elements in half. A leaf range containing
     * {@link #FAILING_INDEX} throws a RuntimeException; every other leaf sleeps
     * for one second.
     *
     * @return always 0 when no exception is thrown
     */
    @Override
    protected Integer compute() {
        System.out.printf("Task: Start from %d to %d\n", start, end);
        if (end - start < 10) {
            // [start, end) contains FAILING_INDEX iff start <= FAILING_INDEX < end.
            if (start <= FAILING_INDEX && FAILING_INDEX < end) {
                throw new RuntimeException("This task throws an "
                        + "Exception: Task from " + start + " to " + end);
            }
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        } else {
            int mid = (end + start) / 2;
            MyTask task1 = new MyTask(array, start, mid);
            MyTask task2 = new MyTask(array, mid, end);
            invokeAll(task1, task2);
        }
        System.out.printf("Task: End from %d to %d\n", start, end);
        return 0;
    }
}
5.取消任务
当你在一个ForkJoinPool类中执行ForkJoinTask对象,在它们开始执行之前,你可以取消执行它们。ForkJoinTask类提供cancel()方法用于这个目的。当你想要取消一个任务时,有一些点你必须考虑一下,这些点如下:
- ForkJoinPool类并没有提供任何方法来取消正在池中运行或等待的所有任务。
- 你不能取消一个已经执行完成的任务(此时cancel()方法返回false)。
public final class MyForkJoinPool5 {
    /**
     * Demo entry point for the "cancelling tasks" recipe.
     *
     * <p>Generates 1000 random digits and searches them for the number 5 with a
     * {@link SearchNumberTask} that registers its subtasks in a
     * {@link TaskManager}, then waits for the pool to terminate.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        final int size = 1000;
        final int target = 5;

        int array[] = new ArrayGenerator().generateArray(size);
        TaskManager manager = new TaskManager();
        ForkJoinPool pool = new ForkJoinPool();
        SearchNumberTask search = new SearchNumberTask(array, 0, size, target, manager);

        pool.execute(search);
        // No further tasks; let the pool finish the running ones and stop.
        pool.shutdown();
        try {
            pool.awaitTermination(1, TimeUnit.DAYS);
        } catch (InterruptedException interrupted) {
            interrupted.printStackTrace();
        }
        System.out.printf("Main: The program has finished\n");
    }
}
class ArrayGenerator {
public int[] generateArray(int size) {
int array[] = new int[size];
Random random = new Random();
for (int i = 0; i < size; i++) {
array[i] = random.nextInt(10);
}
return array;
}
}
class TaskManager {
    /**
     * Registered subtasks. {@code addTask} is called from forked tasks running
     * on different pool worker threads, possibly while another worker iterates
     * in {@link #cancelTasks}. A plain ArrayList would risk lost additions and
     * ConcurrentModificationException. A copy-on-write list gives thread-safe
     * appends and snapshot iteration.
     */
    private final List<ForkJoinTask<Integer>> tasks;

    /** Creates an empty manager. */
    public TaskManager() {
        tasks = new java.util.concurrent.CopyOnWriteArrayList<>();
    }

    /**
     * Registers a task so it can later be cancelled.
     *
     * @param task the task to track
     */
    public void addTask(ForkJoinTask<Integer> task) {
        tasks.add(task);
    }

    /**
     * Cancels every registered task except {@code cancelTask}. A cancel
     * message is printed only for SearchNumberTasks whose cancellation took
     * effect. {@code cancel()} returns false for tasks that already completed.
     *
     * @param cancelTask the task to leave running (typically the caller)
     */
    public void cancelTasks(ForkJoinTask<Integer> cancelTask) {
        for (ForkJoinTask<Integer> task : tasks) {
            if (task == cancelTask) {
                continue;
            }
            if (task.cancel(true) && task instanceof SearchNumberTask) {
                ((SearchNumberTask) task).writeCancelMessage();
            }
        }
    }
}
class SearchNumberTask extends RecursiveTask<Integer> {
    private static final long serialVersionUID = 1L;
    /** Array being searched; this task only reads indices in [start, end). */
    private int numbers[];
    /** Inclusive start and exclusive end of this task's range. */
    private int start, end;
    /** Value being searched for. */
    private int number;
    /** Shared registry used to cancel sibling tasks once the number is found. */
    private TaskManager manager;
    /** Result returned when the number is not in this task's range. */
    private final static int NOT_FOUND = -1;

    /**
     * Creates a task searching {@code numbers[start, end)} for {@code number}.
     */
    public SearchNumberTask(int numbers[], int start, int end, int number,
            TaskManager manager) {
        this.numbers = numbers;
        this.start = start;
        this.end = end;
        this.number = number;
        this.manager = manager;
    }

    /**
     * Splits ranges larger than 10 elements into two subtasks. Smaller ranges
     * are scanned directly.
     *
     * @return index of a match, or {@link #NOT_FOUND}
     */
    @Override
    protected Integer compute() {
        System.out.println("Task: " + start + ":" + end);
        int ret;
        if (end - start > 10) {
            ret = launchTasks();
        } else {
            ret = lookForNumber();
        }
        return ret;
    }

    /**
     * Scans the range element by element, sleeping one second per element.
     * On a match, asks the manager to cancel all other registered tasks.
     */
    private int lookForNumber() {
        for (int i = start; i < end; i++) {
            // Once cancelled, the result is discarded anyway, so stop scanning.
            if (isCancelled()) {
                return NOT_FOUND;
            }
            if (numbers[i] == number) {
                System.out.printf("Task: Number %d found in position %d\n",
                        number, i);
                manager.cancelTasks(this);
                return i;
            }
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                // Preserve the interrupt and treat it as a request to stop searching.
                Thread.currentThread().interrupt();
                return NOT_FOUND;
            }
        }
        return NOT_FOUND;
    }

    /**
     * Forks both halves (registering them with the manager first) and returns
     * the first half's result if found, otherwise the second half's.
     * NOTE(review): join() on a subtask that TaskManager cancelled throws
     * CancellationException, which then fails this task as well.
     */
    private int launchTasks() {
        int mid = (start + end) / 2;
        SearchNumberTask task1 = new SearchNumberTask(numbers, start, mid,
                number, manager);
        SearchNumberTask task2 = new SearchNumberTask(numbers, mid, end,
                number, manager);
        manager.addTask(task1);
        manager.addTask(task2);
        task1.fork();
        task2.fork();
        int returnValue = task1.join();
        if (returnValue != NOT_FOUND) {
            return returnValue;
        }
        return task2.join();
    }

    /** Prints a line reporting that this task's range was cancelled. */
    public void writeCancelMessage() {
        System.out.printf("Task: Canceled task from %d to %d\n", start, end);
    }
}