CountDownLatch 一个同步辅助类,允许一个或多个线程等待,直到其它线程执行完成一组操作。它是 AQS 的共享模式的一种实现。
流程简介:CountDownLatch 必须通过数值 count 来初始化一个大于 0 的计数,任何线程调用 await 方法都会阻塞,直到其它线程调用 countDown 将计数从初始值减为 0,count 变为 0 时,所有阻塞在 await 方法的线程都会恢复运行。这个计数只能使用一次,如果需要循环使用考虑使用 CyclicBarrier 。
第一种用法示例:下面给出了两个类,其中一组 worker 线程使用了两个倒计数锁存器:
第一个类是一个启动信号,在 driver 为继续执行 worker 做好准备之前,它会阻止所有的 worker 继续执行。第二个类是一个完成信号,它允许 driver 在完成所有 worker 之前一直等待。
class Driver { // ...
void main() throws InterruptedException {
CountDownLatch startSignal = new CountDownLatch(1);
CountDownLatch doneSignal = new CountDownLatch(N);
for (int i = 0; i < N; ++i) // create and start threads
new Thread(new Worker(startSignal, doneSignal)).start();
doSomethingElse(); // don't let run yet
startSignal.countDown(); // let all threads proceed
doSomethingElse();
doneSignal.await(); // wait for all to finish
}
}
class Worker implements Runnable {
private final CountDownLatch startSignal;
private final CountDownLatch doneSignal;
Worker(CountDownLatch startSignal, CountDownLatch doneSignal) {
this.startSignal = startSignal;
this.doneSignal = doneSignal;
}
public void run() {
try {
startSignal.await();
doWork();
doneSignal.countDown();
} catch (InterruptedException ex) {} // return;
}
void doWork() { ... }
}
另一种典型用法是,将一个问题分成 N 个部分,用执行每个部分并让锁存器倒计数的 Runnable 来描述每个部分,然后将所有 Runnable 加入到 Executor 队列。当所有的子部分完成后,协调线程就能够通过 await。
class Driver2 { // ...
void main() throws InterruptedException {
CountDownLatch doneSignal = new CountDownLatch(N);
Executor e = ...
for (int i = 0; i < N; ++i) // create and start threads
e.execute(new WorkerRunnable(doneSignal, i));
doneSignal.await(); // wait for all to finish
}
}
class WorkerRunnable implements Runnable {
private final CountDownLatch doneSignal;
private final int i;
WorkerRunnable(CountDownLatch doneSignal, int i) {
this.doneSignal = doneSignal;
this.i = i;
}
public void run() {
try {
doWork(i);
doneSignal.countDown();
} catch (InterruptedException ex) {} // return;
}
void doWork() { ... }
}
看下 CountDownLatch 源码,它有一个实现了 AQS 的静态内部类 Sync 。
如何构造一个 CountDownLatch 。
CountDownLatch startSignal = new CountDownLatch(1);
/**
* Constructs a {@code CountDownLatch} initialized with the given count.
*
* @param count the number of times {@link #countDown} must be invoked
* before threads can pass through {@link #await}
* @throws IllegalArgumentException if {@code count} is negative
*/
public CountDownLatch(int count) {
//初始数值不能小于0
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
public class CountDownLatch {
/**
* Synchronization control For CountDownLatch.
* Uses AQS state to represent count.
*/
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
...
}
private final Sync sync;
}
举一个例子来分析:初始化一个计数为 10 的计数器 CountDownLatch ,然后启动十个线程,每个线程调用 await() 阻塞;然后启动五个线程,共调用 countDown() 十次释放掉计数,恢复前面启动的十个线程。
/**
* Created by Tangwz on 2019/6/25
*/
public class TestCountDownLatch {
private static CountDownLatch countDownLatch = new CountDownLatch(10);
private static class Thread1 extends Thread {
public Thread1(int i) {
super("Thread" + i);
}
@Override
public void run() {
try {
countDownLatch.await();
System.out.println(Thread.currentThread().getName() + "恢复运行");
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
private static class Thread2 extends Thread {
@Override
public void run() {
countDownLatch.countDown();
countDownLatch.countDown();
}
}
public static void main(String[] args) throws InterruptedException, IllegalAccessException,
NoSuchFieldException {
for (int i = 0; i < 10; i++) {
Thread1 thread1 = new Thread1(i);
thread1.start();
//保证线程按照0-9的序号入队列
TimeUnit.MILLISECONDS.sleep(100);
}
//打印同步队列中的节点名称
printThreads();
TimeUnit.SECONDS.sleep(1);
System.out.println("开始唤醒线程");
for (int i = 0; i < 5; i++) {
Thread2 thread2 = new Thread2();
thread2.start();
}
}
private static void printThreads() throws NoSuchFieldException, IllegalAccessException {
Field sync = CountDownLatch.class.getDeclaredField("sync");
sync.setAccessible(true);
AbstractQueuedSynchronizer aqs = (AbstractQueuedSynchronizer) sync.get(countDownLatch);
ArrayList<Thread> threads = new ArrayList<>(aqs.getQueuedThreads());
for (int i = threads.size(); i > 0; i--) {
System.out.println(threads.get(i - 1).getName());
}
}
}
线程0、1、2依次进入同步队列的状态变化
执行 countDownLatch.await() 是怎么被阻塞的呢?
注意:下面涉及到的一些 AQS 方法也被其他并发工具类使用,而 CountDownLatch 不一定用得上,故步骤分析暂只考虑本类使用到的情况。
/**
* Causes the current thread to wait until the latch has counted down to
* zero, unless the thread is {@linkplain Thread#interrupt interrupted}.
*/
public void await() throws InterruptedException {
//调用内部类sync,以共享方式获取锁,如果中断,中止
sync.acquireSharedInterruptibly(1);
}
/**
* Acquires in shared mode, aborting if interrupted. Implemented
* by first checking interrupt status, then invoking at least once
* {@link #tryAcquireShared}, returning on success. Otherwise the
* thread is queued, possibly repeatedly blocking and unblocking,
* invoking {@link #tryAcquireShared} until success or the thread
* is interrupted.
* @param arg the acquire argument.
* This value is conveyed to {@link #tryAcquireShared} but is
* otherwise uninterpreted and can represent anything
* you like.
* @throws InterruptedException if the current thread is interrupted
*/
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
//没有被 CountDown() 设置 state 为 0 前,所有线程进来都会获取锁失败
if (tryAcquireShared(arg) < 0)
//以一个共享可中断的节点获取锁
doAcquireSharedInterruptibly(arg);
}
protected int tryAcquireShared(int acquires) {
//只有状态为0的情况才返回1,其他都返回-1
return (getState() == 0) ? 1 : -1;
}
/**
* Acquires in shared interruptible mode.
* @param arg the acquire argument
*/
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
//十个线程都会进入这里,然后创建一个标识了共享的节点,添加到队尾
//这行流程按标记的序号进行
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
//1.线程0首次进来,获取不到锁,返回 r=-1
//3.线程0还是获取不到锁
//5.线程1-9等都会获取失败,然后依次加入队尾
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
//2.线程0将头节点状态设置为-1,返回循环
//4.线程0然后在这里会阻塞,等待被 countDown 唤醒
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
接下来分析启动的五个线程调用 countDownLatch.countDown() 是怎么一步步最终释放锁,唤醒共享节点的。
/**
* Decrements the count of the latch, releasing all waiting threads if
* the count reaches zero.
*
* <p>If the current count is greater than zero then it is decremented.
* If the new count is zero then all waiting threads are re-enabled for
* thread scheduling purposes.
*
* <p>If the current count equals zero then nothing happens.
*/
public void countDown() {
//每次执行一次只减一,当减至0时代表释放了锁,需要唤醒等待节点
sync.releaseShared(1);
}
/**
* Releases in shared mode. Implemented by unblocking one or more
* threads if {@link #tryReleaseShared} returns true.
*
* @param arg the release argument. This value is conveyed to
* {@link #tryReleaseShared} but is otherwise uninterpreted
* and can represent anything you like.
* @return the value returned from {@link #tryReleaseShared}
*/
public final boolean releaseShared(int arg) {
//调用 CountDownLatch.Sync.tryReleaseShared()
if (tryReleaseShared(arg)) {
//最终有且仅有一个线程能进入到这里执行
//就是最后执行 compareAndSetState(1, 0) 成功的那个线程
doReleaseShared();
return true;
}
return false;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
//可能会有多个线程同时释放锁,需要考虑并发
for (;;) {
int c = getState();
if (c == 0)
//状态不能小于0,之前的状态已经为0了需要返回释放锁失败
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
//只有状态被减至0才返回true
return nextc == 0;
}
}
/**
* Release action for shared mode -- signals successor and ensures
* propagation. (Note: For exclusive mode, release just amounts
* to calling unparkSuccessor of head if it needs signal.)
*/
private void doReleaseShared() {
/*
* Ensure that a release propagates, even if there are other
* in-progress acquires/releases. This proceeds in the usual
* way of trying to unparkSuccessor of head if it needs
* signal. But if it does not, status is set to PROPAGATE to
* ensure that upon release, propagation continues.
* Additionally, we must loop in case a new node is added
* while we are doing this. Also, unlike other uses of
* unparkSuccessor, we need to know if CAS to reset status
* fails, if so rechecking.
*/
//这个方法被用来唤醒下个节点,并传递状态 Node.PROPAGATE
//先按照本例子最简单的逻辑来分析,即最后一个线程执行完 countDown 然后依次唤醒 0-9 这10个节点
//实际上不会依次唤醒这十个节点,要是运行例子程序会发现输出的线程名称是乱的,不是从0-9打印,原因后面分析
for (;;) {
//h为线程为空的头结点
Node h = head;
//1.因为队列不为空,有十个节点,判断成功
//6.被唤醒的thread0也会进来执行,然后唤醒下个节点thread1,thread1再唤醒thread2...
if (h != null && h != tail) {
int ws = h.waitStatus;
//2.头节点状态为-1
if (ws == Node.SIGNAL) {
//3.设置头结点状态为0,CAS失败的情况后面分析
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
//4.唤醒下一个节点thread0
unparkSuccessor(h);
}
else if (ws == 0 &&
//CAS失败的情况,以及 Node.PROPAGATE 有啥用处参考后面的文章 Semaphore 源码分析
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
//5.退出循环
break;
}
}
所有阻塞在 countDownLatch.await() 的线程需要被唤醒
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
//2.前驱是头结点
if (p == head) {
//3.此时 status 已经减成了 0,这里 r=1
int r = tryAcquireShared(arg);
if (r >= 0) {
//3.获取共享锁成功后,设置头,唤醒下个节点
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
//1.thread0 首先被唤醒,没有中断
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
/**
* Sets head of queue, and checks if successor may be waiting
* in shared mode, if so propagating if either propagate > 0 or
* PROPAGATE status was set.
*
* @param node the node
* @param propagate the return value from a tryAcquireShared
*/
private void setHeadAndPropagate(Node node, int propagate) {
Node h = head; // Record old head for check below
//4.thread0 进来后,将自己设置为头节点
setHead(node);
/*
* Try to signal next queued node if:
* Propagation was indicated by caller,
* or was recorded (as h.waitStatus either before
* or after setHead) by a previous operation
* (note: this uses sign-check of waitStatus because
* PROPAGATE status may transition to SIGNAL.)
* and
* The next node is waiting in shared mode,
* or we don't know, because it appears null
*
* The conservatism in both of these checks may cause
* unnecessary wake-ups, but only when there are multiple
* racing acquires/releases, so most need signals now or soon
* anyway.
*/
//5.propagate = 1
if (propagate > 0 || h == null || h.waitStatus < 0 ||
(h = head) == null || h.waitStatus < 0) {
//6.thread0 的下个节点为 thread1
Node s = node.next;
//7.thread1 为共享节点
if (s == null || s.isShared())
//8.唤醒后续节点 thread1
doReleaseShared();
}
}
前面提到为何不会依次唤醒线程0-9,原因就在 doReleaseShared()
考虑这样一种可能的情况:线程N唤醒 thread0 成功之后,thread0 调用 setHeadAndPropagate() 后也会调用 doReleaseShared(),这个时候唤醒 thread0 的线程N就会和 thread0 产生竞争。
private void doReleaseShared() {
for (;;) {
//4.线程N 执行到这里,thread0 也执行到这里,h为 thread0
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
//5.thread0 先一步将 thread0 的状态改为 -1,线程N就会CAS失败
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
//6.线程N 继续循环
continue; // loop to recheck cases
//1.线程N 唤醒 thread0
//thread1 被唤醒
unparkSuccessor(h);
}
//7.线程N 发现 thread0 的状态为0
else if (ws == 0 &&
//8.修改 thread0 的状态为 Node.PROPAGATE
//为啥要设置成 PROPAGATE 呢,个人觉得 CountDownLatch 不设置忽略掉这一步也没问题
//因为 CountDownLatch 的状态不可以复用,不会有线程再去修改状态 status,这里根本就不会有竞争
//但是 Semaphore 会用到
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
//2.thread0 调用了 setHeadAndPropagate 中的 setHead(node),thread0 变为头节点
//3.线程N 判断 h不为原来的节点,不会退出循环
//9.若被唤醒的 thread1 也 调用了 setHeadAndPropagate 中的 setHead(node)
//9.那么线程T、thread0 和thread1 都会再次执行 doReleaseShared()
if (h == head) // loop if head changed
break;
}
}
CyclicBarrier 和 CountDownLatch 类似,具体区别请看 CyclicBarrier源码分析。