一、CountDownLatch基本介绍
CountDownLatch是JDK提供的一种同步工具,使用它的API可以方便的实现一个或者多个线程等待其他一个或多个线程执行某项任务完成的需求。
CountDownLatch通过给定的count
进行初始化,然后调用await
方法阻塞,直到有线程通过调用countDown
方法使count
的计数值达到零,之后被阻塞的线程就会重新进入可运行状态。
CountDownLatch的count
值只能赋值一次,不能重复使用,这是其于另一种工具类CyclicBarrier.
的不同之处。
二、主要用途
1.启动信号
当我们把count
值设置为1时,可以方便的实现开关控制,我们可以让所有的线程在await
处等待,当它们都已准备就绪时,再调用countDown
方法,这样所有准备就绪的线程就能一起同时进行,就好像是在模拟一种高并发的场景。
2.完成信号
当我们需要对一项任务完成N次时,可以同时开启N个线程,并把count
值也设置为N,然后主线程在await
处等待,之后每当一个线程完成任务后就调用countDown
,最后当所有线程都已完成任务,在await
处等待的线程即可对所有任务结果进行处理,得到想要的结果。
三、场景实现
1.启动信号、完成信号
startSignal
和doneSignal
分别演示了启动信号和完成信号的控制。
import java.util.concurrent.CountDownLatch;
public class CountDownLatchTest {
private static final int WORK_NUMBER = 5;
public static void main(String[] args) throws InterruptedException {
CountDownLatchTest countDownLatchTest = new CountDownLatchTest();
countDownLatchTest.test();
}
private void test() throws InterruptedException {
CountDownLatch startSignal = new CountDownLatch(1);
CountDownLatch doneSignal = new CountDownLatch(WORK_NUMBER);
for (int i = 0; i < WORK_NUMBER; ++i)
new Thread(new Worker(startSignal, doneSignal)).start();
prepare();
startSignal.countDown();
doneSignal.await();
after();
}
private void after() {
System.out.println("完成后执行!");
}
private void prepare() {
System.out.println("执行前准备!");
}
static class Worker implements Runnable {
private final CountDownLatch startSignal;
private final CountDownLatch doneSignal;
Worker(CountDownLatch startSignal, CountDownLatch doneSignal) {
this.startSignal = startSignal;
this.doneSignal = doneSignal;
}
public void run() {
try {
startSignal.await();
doWork();
doneSignal.countDown();
} catch (InterruptedException ignored) {
}
}
void doWork() {
System.out.println("do work!");
}
}
}
2.并行计算
利用CountDownLatch也可以方便的实现多线程并行计算。
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
public class CountDownLatchTest2 {
private static final int WORK_NUMBER = 5;
private static ConcurrentHashMap<Long, Integer> concurrentHashMap = new ConcurrentHashMap<>();
public static void main(String[] args) throws InterruptedException {
CountDownLatchTest2 countDownLatchTest = new CountDownLatchTest2();
countDownLatchTest.test();
}
private void test() throws InterruptedException {
CountDownLatch parallel = new CountDownLatch(WORK_NUMBER);
for (int i = 0; i < WORK_NUMBER; ++i) {
new Thread(new Worker(parallel)).start();
}
parallel.await();
Integer sum = concurrentHashMap.values().stream().reduce(Integer::sum).get();
System.out.println("汇总所有线程计算结果:" + sum);
}
static class Worker implements Runnable {
private final CountDownLatch parallel;
Worker(CountDownLatch parallel) {
this.parallel = parallel;
}
public void run() {
doWork(Thread.currentThread().getId());
parallel.countDown();
}
void doWork(long id) {
try {
// 假设每个线程都计算一部分,并将结果保存在value中
Thread.sleep(1000);
concurrentHashMap.put(id, 1);
} catch (InterruptedException ignored) {
}
}
}
}