1 引言
在实际开发中,经常遇到这样一种场景:有一组线程,每个线程去持行自己的任务,现在想实现等所有线程都持行完毕后,再向下执行代码。比如我们开启一组线程去同步服务器的数据,点赞记录,下载记录,收藏记录等等用户信息,所有数据同步完成之后继续向下执行。
2 CountDownLatch的概念
CountDownLatch
是一个同步工具类,用来协调多个线程之间的同步,用于开启一组线程,使一个线程在等待这组线程完成各自工作之后,再继续执行。CountDownLatch
内部基于一个计数器实现。计数器初始值大于等于线程的数量(因为每个线程至少执行countDown()
方法一次)。当每一个线程完成自己任务后,计数器的值就会减一。当计数器的值为0时,表示所有的线程都已经完成指定的任务,然后在CountDownLatch上等待的线程就可以继续执行接下来的任务。
3 CountDownLatch的方法
CountDownLatch
的主要方法包括:
//调用await()方法的线程会被挂起,它会等待直到count值为0才继续执行
public void await() throws InterruptedException { };
//和await()类似,只不过等待一定的时间后count值还没变为0的话就会继续执行
public boolean await(long timeout, TimeUnit unit) throws InterruptedException { };
//将count值减1
public void countDown() { };
4 CountDownLatch的示例
假设现在有这样的场景:两个阅卷老师同时对两个班级的学生进行评分,最后需要统计出两个班级所有学生的最高分和最低分。这里可以利用CountDownLatch
的思想,启动两个线程分别对两个班级分数进行排序,在两个班级都排好序之后,再综合比较得出最大值和最小值。(当然这种场景启动两个线程消耗更大,效率较低。实际上可以通过插入排序、快速排序实现,本例方案的合理性不做讨论,只是为了演示CountDownLatch
的用法)。
附实现代码:
/**
* @author Carson Chu, 1965704869@qq.com
* @date 2020/4/5 12:41
* @description
*/
public class Main {
public static void main(String[] args) {
int[] score0 = {2, 6, 7, 9, 3};
int[] score1 = {8, 6, 7, 1, 5};
/* 初始化CountDownLatch,包含两个计数器 */
CountDownLatch countDownLatch = new CountDownLatch(2);
ExecutorService executorService0 = Executors.newCachedThreadPool();
/* 线程1对第一个数组进行排序 */
executorService0.submit(() -> {
Arrays.sort(score0);
/* 执行完计数器减1 */
countDownLatch.countDown();
});
ExecutorService executorService1 = Executors.newCachedThreadPool();
/* 线程2对第二个数组进行排序 */
executorService1.submit(() -> {
Arrays.sort(score1);
/* 执行完计数器减1 */
countDownLatch.countDown();
});
try {
/* 阻塞当前线程--main线程 */
countDownLatch.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
int maxRes = score0[4] > score1[4] ? score0[4] : score1[4];
int minRes = score0[0] < score1[0] ? score0[0] : score1[0];
System.out.println("The maximum score is " + maxRes);
System.out.println("The minimum score is " + minRes);
}
}
5 CountDownLatch源码分析
CountDownLatch
底层是基于AbstractQueuedSynchronizer
(即大名鼎鼎的AQS)实现的,AQS主要有两种模式:共享模式(SHARED)和互斥模式(EXCLUSIVE),CountDownLatch
是基于共享模式的,内部封装了一个继承自AQS的类Sync
,附源码:
程序清单5-1 Sync源码
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
/* 调用await()方法的底层实现 */
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
/* 调用countDown()方法的底层实现 */
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
每次调用await()方法时候,会利用上述代码里的tryAcquireShared()
获取当前计数器数量,如果不为0,则执行如下代码,这段代码是干嘛的呢,就是把所有等待的线程放入一个CLH
队列中。
程序清单5-2 await()方法核心实现
/**
* Acquires in shared interruptible mode.
* @param arg the acquire argument
*/
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
每次调用countDown()方法时候,会利用程序清单5-1代码里的tryReleaseShared()
来判断当前计数器是否为0,如果为0则直接返回,如果不为0,则计数器减1,当计数器被减到0的时候,返回true,继续执行如下代码以唤醒线程(唤醒操作是基于LockSupport.unpark()实现的)。
关于LockSupport的机制及原理可参考我的另一篇博客,附传送门↓:
Java并发编程(二):LockSupport应用及原理分析
程序清单5-3 countDown()方法核心实现
private void doReleaseShared() {
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
//唤醒h线程
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}
6 利用CountDownLatch
实现两个线程分别输出100以内的奇数和偶数
/**
* @author Carson
* @date 2020/7/4 13:03
*/
public class Main {
private static ExecutorService executorService = new ThreadPoolExecutor(2, 5, 1000L,
TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(1024));
private static CountDownLatch countDownLatch = new CountDownLatch(2);
private static volatile boolean flag = true;
private static AtomicInteger num = new AtomicInteger();
private static final Integer MAX = 100;
public static void main(String[] args) throws InterruptedException {
executorService.submit(new Runnable() {
@Override
public void run() {
while (num.get() <= MAX - 1) {
if (!flag) {
System.out.println(Thread.currentThread().getName() + "-->" + num.getAndIncrement());
flag = true;
}
}
countDownLatch.countDown();
}
});
executorService.submit(new Runnable() {
@Override
public void run() {
while (num.get() <= MAX) {
if (flag) {
System.out.println(Thread.currentThread().getName() + "-->" + num.getAndIncrement());
flag = false;
}
}
countDownLatch.countDown();
}
});
countDownLatch.await();
}
}
7 小结
要实现线程同步,除了CountDownLatch
外,Java还提供了了一种内存屏障的机制CyclicBarrier
,两者的区别在于:
CountDownLatch
是一个计数器,线程完成一个记录一个,计数器递减,只能用一次CyclicBarrier
的计数器更像一个阀门,需要所有线程都到达,然后继续执行,计数器递增,提供reset功能,支持重置多次使用。