并行计算
分支/合并框架的目的是以递归方式将可以并行的任务拆分成更小的任务,然后将每个子任务的结果合并起来生成整体结果。
大概步骤如下
- 任务拆分fork
- 线程自动拿去任务,处理执行compute方法
- 规约处理结果join
举例说明
打印10次结果
public class PrintTest {
/**
* <p>Title: main</p>
* <p>Description: </p>
* @param args
*/
public static void main(String[] args) {
for (int i = 0; i < 10; i++) {
System.out.println(String.format("%s : %s", Thread.currentThread().getName(), i));
}
}
}
处理结果:
main : 0
main : 1
main : 2
main : 3
main : 4
main : 5
main : 6
main : 7
main : 8
main : 9
引入fork/join框架计算,这里使用RecursiveAction
类处理,定义任务拆分类
public class PrintRecursiveAction extends RecursiveAction {
/** serialVersionUID*/
private static final long serialVersionUID = -3070198564230327332L;
private int start;
private int end;
/**
* <p>Title: </p>
* <p>Description: </p>
* @param times
*/
public PrintRecursiveAction(int start, int end) {
if (end - start <= 2) {
System.out.println(String.format("[start,end] : [%s,%s];", start, end));
}
this.start = start;
this.end = end;
}
@Override
protected void compute() {
if (end - start <= 2) {
for (int i = start; i < end; i++) {
System.out.println(String.format("%s : %s", Thread.currentThread().getName(), i));
}
} else {
// 二分
int middle = (start + end) / 2;
PrintRecursiveAction leftTask = new PrintRecursiveAction(start, middle);
PrintRecursiveAction rightTask = new PrintRecursiveAction(middle, end);
// 拆分任务
leftTask.fork();
rightTask.fork();
// 规约执行结果
// leftTask.join();
// rightTask.compute();
}
}
public static void main(String[] args) throws InterruptedException {
ForkJoinPool pool = new ForkJoinPool();
pool.invoke(new PrintRecursiveAction(0, 10));
System.out.println(String.format("%s", Thread.currentThread().getName()));
pool.shutdown();
// 如果不延迟10秒,这里只是拆分的任务,线程还未来得及处理
TimeUnit.SECONDS.sleep(10);
}
}
处理结果:
main
[start,end] : [0,2];
[start,end] : [2,3];
[start,end] : [3,5];
[start,end] : [5,7];
ForkJoinPool-1-worker-2 : 3
ForkJoinPool-1-worker-0 : 0
ForkJoinPool-1-worker-0 : 1
ForkJoinPool-1-worker-2 : 4
[start,end] : [7,8];
ForkJoinPool-1-worker-1 : 5
ForkJoinPool-1-worker-0 : 2
ForkJoinPool-1-worker-1 : 6
[start,end] : [8,10];
ForkJoinPool-1-worker-3 : 8
ForkJoinPool-1-worker-1 : 7
ForkJoinPool-1-worker-3 : 9
对于处理结果,我们可以看出,当前任务被拆除了[0,2]
、[2,3]
、[3,5]
、[5,7]
、[7,8]
、[8,10]
区间,对应的线程关系为下图
线程 | 处理空间 |
---|---|
ForkJoinPool-1-worker-1 | [0,2] 、[5,7] |
ForkJoinPool-1-worker-2 | [3,5] |
ForkJoinPool-1-worker-3 | [8,10] |
ForkJoinPool-1-worker-0 | main |
以上符合之前的图解拆分,这里面ForkJoinPool-1-worker-1
获取到两个任务。
工作窃取
上述是工作窃取思想
分出大量的小任务一般来说都是一个好的选择。这是因为,理想情况下,划分并行任务时,应该让每个任务都用完全相同的时间完成,让所有的CPU内核都同样繁忙。不幸的是,实际中,每个子任务所花的时间可能天差地别,要么是因为划分策略效率低,要么是有不可预知的原因,比如磁盘访问慢,或是需要和外部服务协调执行。
分支/合并框架工程用一种称为工作窃取(work stealing)的技术来解决这个问题。在实际应用中,这意味着这些任务差不多被平均分配到 ForkJoinPool 中的所有线程上。每个线程都为分配给它的任务保存一个双向链式队列,每完成一个任务,就会从队列头上取出下一个任务开始执行。基于前面所述的原因,某个线程可能早早完成了分配给它的所有任务,也就是它的队列已经空了,而其他的线程还很忙。这时,这个线程并没有闲下来,而是随机选了一个别的线程,从队列的尾巴上“偷走”一个任务。这个过程一直继续下去,直到所有的任务都执行完毕,所有的队列都清空。这就是为什么要划成许多小任务而不是少数几个大任务,这有助于更好地在工作线程之间平衡负载。
一般来说,这种工作窃取算法用于在池中的工作线程之间重新分配和平衡任务。展示了这个过程。当工作线程队列中有一个任务被分成两个子任务时,一个子任务就被闲置的工作线程“偷走”了。如前所述,这个过程可以不断递归,直到规定子任务应顺序执行的条件为真。
fork/join框架代码使用
主要是提供了两个抽象基类RecursiveAction
、RecursiveTask
。这两个类都是继承ForkJoinTask<V>
,都需要实现protected abstract V compute();
方法。
类 | 返回值 |
---|---|
RecursiveAction | 无返回值 |
RecursiveTask | 返回值V |
RecursiveAction使用
其中RecursiveAction
上面已经说明过了。这里再给出官网的例子即可,不再举例说明:https://docs.oracle.com/javase/8/docs/api/java/util/concurrent/RecursiveAction.html
RecursiveTask使用
计算:1+2+3+...+100000=?
,具体实现如下
public class SumRecursiveTask extends RecursiveTask<Long> {
/** serialVersionUID */
private static final long serialVersionUID = 5927901167765240121L;
private final long[] numbers;
private final int start;
private final int end;
final long THRESHOLD = 10_000;
// 统计拆分次数
private static AtomicInteger splitTime = new AtomicInteger();
public SumRecursiveTask(long[] numbers) {
this(numbers, 0, numbers.length);
}
private SumRecursiveTask(long[] numbers, int start, int end) {
this.numbers = numbers;
this.start = start;
this.end = end;
}
@Override
protected Long compute() {
int length = end - start;
if (length <= THRESHOLD) {
return LongStream.range(start, end).sum();
} else {
synchronized (SumRecursiveTask.class) {
splitTime.addAndGet(1);
System.out.println(String.format("%s -> [start, end]:[%s,%s], times: %s", Thread.currentThread().getName(), start, end, splitTime));
}
SumRecursiveTask leftTask = new SumRecursiveTask(numbers, start, start + length / 2);
SumRecursiveTask rightTask = new SumRecursiveTask(numbers, start + length / 2, end);
// 1. 拆分任务
leftTask.fork();
rightTask.fork();
// 3. 规约结果
Long leftResult = leftTask.join();
Long rightResult = rightTask.join();
return leftResult + rightResult;
}
}
public static void main(String[] args) {
ForkJoinPool pool = new ForkJoinPool();
Long result = pool.invoke(new SumRecursiveTask(LongStream.rangeClosed(0, 100_000).toArray()));
System.out.println(result);
pool.shutdown();
}
}
处理结果:
ForkJoinPool-1-worker-1 -> [start, end]:[0,100001], times: 1
ForkJoinPool-1-worker-2 -> [start, end]:[50000,100001], times: 2
ForkJoinPool-1-worker-3 -> [start, end]:[0,50000], times: 3
ForkJoinPool-1-worker-3 -> [start, end]:[0,25000], times: 4
ForkJoinPool-1-worker-3 -> [start, end]:[0,12500], times: 5
ForkJoinPool-1-worker-2 -> [start, end]:[50000,75000], times: 6
ForkJoinPool-1-worker-2 -> [start, end]:[50000,62500], times: 7
ForkJoinPool-1-worker-1 -> [start, end]:[25000,50000], times: 8
ForkJoinPool-1-worker-1 -> [start, end]:[25000,37500], times: 9
ForkJoinPool-1-worker-0 -> [start, end]:[75000,100001], times: 10
ForkJoinPool-1-worker-0 -> [start, end]:[75000,87500], times: 11
ForkJoinPool-1-worker-2 -> [start, end]:[62500,75000], times: 12
ForkJoinPool-1-worker-3 -> [start, end]:[12500,25000], times: 13
ForkJoinPool-1-worker-1 -> [start, end]:[37500,50000], times: 14
ForkJoinPool-1-worker-0 -> [start, end]:[87500,100001], times: 15
5000050000
我这面我们看到,也是一共四个线程(其中包括一个main主线程)
,为什么都是四个,而不是其他个呐?因为我的电脑是4核,最优线程数。测试代码如下:
public class CountTest {
/**
* <p>Title: main</p>
* <p>Description: </p>
* @param args
*/
public static void main(String[] args) {
int availableProcessors = Runtime.getRuntime().availableProcessors();
System.out.println(availableProcessors);
}
}
对上述例子优化一下,减少一个线程的拆分浪费,具体compute
方法优化如下:
@Override
protected Long compute() {
int length = end - start;
if (length <= THRESHOLD) {
return LongStream.range(start, end).sum();
} else {
synchronized (SumRecursiveTask.class) {
splitTime.addAndGet(1);
System.out.println(String.format("%s -> [start, end]:[%s,%s], times: %s", Thread.currentThread().getName(), start, end, splitTime));
}
SumRecursiveTask leftTask = new SumRecursiveTask(numbers, start, start + length / 2);
SumRecursiveTask rightTask = new SumRecursiveTask(numbers, start + length / 2, end);
// 1. leftTask拆分任务,rightTask直接执行compute方法,不再拆分成一个单独的任务
leftTask.fork();
// 3. 规约结果
Long leftResult = leftTask.join();
Long rightResult = rightTask.compute();
return leftResult + rightResult;
}
到这里基本上可以明白基本使用了