CountDownLatch 中文有的叫做计数器,也有翻译为计数锁,其最大的作用不是为了加锁,而是通过计数达到等待的功能,主要有两种形式的等待:
- 让一组线程在全部启动完成之后,再一起执行(先启动的线程需要阻塞等待后启动的线程,直到一组线程全部都启动完成后,再一起执行)
- 主线程等待另外一组线程都执行完成之后,再继续执行
在看具体的源码前,先来看一个简单的使用示例
1.应用示例
模拟100米赛跑,10名选手已经准备就绪,只等裁判一声令下。当所有人都到达终点时,比赛结束。我们和容易想到主线程模拟裁判,开启十个子线程模拟运动员,但是会有以下两个问题:
- 10 个子线程必须等到主线程发出号令(控制台打印Game Start)之后再运行
- 主线程必须等待 10 个子线程运行结束之后再退出
下面就看看如何通过 2 个 CountDownLatch 来解决这俩问题。
public class CountDownLatchTest {
public static void main(String[] args) throws InterruptedException {
// 开始的倒数锁。count设置为1 是因为主线程只用-1
final CountDownLatch begin = new CountDownLatch(1);
// 结束的倒数锁,count设置为10 是因为有10个子线程都要-1
final CountDownLatch end = new CountDownLatch(10);
// 通过线程池创造出十名选手 (十个线程)
final ExecutorService exec = Executors.newFixedThreadPool(10);
// 让这十个线程运行起来
for (int index = 0; index < 10; index++) {
final int NO = index + 1; // 编号【1,10】
Runnable run = new Runnable() {
public void run() {
try {
// 如果当前计数为零(主线程已就绪),则此方法立即返回。
// 如果当前计数不为0(主线程还未调用countDown),等待。
begin.await();
Thread.sleep((long) (Math.random() * 10000));
System.out.println("No." + NO + " arrived");
} catch (InterruptedException e) {
} finally {
// 每个选手到达终点(线程执行完毕)时,end就减一
end.countDown();
}
}
};
exec.submit(run);
}
System.out.println("Game Start");
// begin减一,开始游戏
begin.countDown();
// 主线程会阻塞在这里,等待end变为0,即所有选手到达终点
end.await();
System.out.println("Game Over");
exec.shutdown();
}
}
控制台输出结果如下:
Game Start
No.9 arrived
No.6 arrived
No.8 arrived
No.7 arrived
No.10 arrived
No.1 arrived
No.5 arrived
No.4 arrived
No.2 arrived
No.3 arrived
Game Over
2.源码分析
CountDownLatch 的核心成员变量及主要构造函数如下:
public class CountDownLatch {
// 从 Sync 的继承关系就可以看出,CountDownLatch也是基于AQS框架实现的
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
// 构造函数,直接设置state=count
Sync(int count) {
setState(count);
}
// 调用AQS方法获取state
int getCount() {
return getState();
}
// 能否获取到共享锁。如果当前同步器的状态是 0 的话,表示可获得锁
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1; // state!=0,就拿锁失败
}
// 对 state 进行递减,直到 state 变成 0;state 递减为 0 时,返回 true,其余返回 false
protected boolean tryReleaseShared(int releases) {
// 自旋保证 CAS 一定可以成功
for (;;) {
int c = getState();
// state 已经是 0 了,直接返回 false
if (c == 0)
return false;
// state--
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
private final Sync sync;
//-----------------------------构造函数------------------------------------
// 无空参构造,必须传入count,count相当于要等待的线程数
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
// 将count传给sync
this.sync = new Sync(count);
}
}
CountDownLatch 实质上就是利用了 AQS 的可重入性,并使用 AQS共享锁加锁模式
- 构造时传入 count,将 count 赋值给 AQS 的 state,相当于同一持锁线程多处重入
- await():将当前线程加入同步队列,休眠,等待 AQS 的 state=0 后被唤醒
- countDown():
state--
,当 state=0 时唤醒所有阻塞在 await() 的线程恢复运行
注:这里为什么不用独占锁?因为await的线程可能有多个,即state=0后所有 await() 的线程都需要都唤醒
2.1 await()
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
// 带有超时时间的,最终都会转化成毫秒
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
acquireSharedInterruptibly()
这个方法是属于 AQS 的,CountDownLatch 的内部类 Sync 并没有进行实现/重写
// AQS
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
// 判断当前线程是否已经中断
if (Thread.interrupted())
throw new InterruptedException();
// 判断能否拿到锁
if (tryAcquireShared(arg) < 0)
// 拿不到就放入阻塞队列
doAcquireSharedInterruptibly(arg);
}
tryAcquireShared()
这个方法在上面已经列出过了,是内部类 Sync 的方法,当 state == 0 才能获取到锁
// CountDownLatch.Sync
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1; // state!=0,就拿锁失败
}
doAcquireSharedInterruptibly()
// AQS
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
// 将当前线程封装为node(共享模式),并加到同步队列队尾
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
// 自旋,保证所有被唤醒的线程都能依次恢复运行
for (;;) {
final Node p = node.predecessor();
// 当前node前进到队二 && tryAcquire成功(state减到0),就可以执行了
if (p == head) {
// 判断是否能拿到锁
int r = tryAcquireShared(arg);
if (r >= 0) {
// setHeadAndPropagate 会调用 doReleaseShared 去唤醒后续 Shared 节点
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
// 其余线程阻塞(最后也是在此处醒来)
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
2.2 countDown()
线程调用countDown方法后,会将AQS的state-1,若state=0了,就会唤醒所有阻塞在await处的线程。
public void countDown() {
sync.releaseShared(1);
}
releaseShared()
// AQS
public final boolean releaseShared(int arg) {
// 将state-1,若state=0了,表示当前线程释放锁成功
if (tryReleaseShared(arg)) {
// 唤醒后续节点
doReleaseShared();
return true;
}
return false;
}
tryReleaseShared()
// CountDownLatch.Sync
protected boolean tryReleaseShared(int releases) {
// 自旋保证 CAS 一定可以成功
for (;;) {
int c = getState();
// state 已经是 0 了,直接返回 false
if (c == 0)
return false;
// state--
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
doReleaseShared()
// AQS
private void doReleaseShared() {
// 自旋,保证所有线程正常的线程都能被唤醒
for (;;) {
Node h = head;
// 还没有到队尾,此时队列中至少有两个节点
if (h != null && h != tail) {
int ws = h.waitStatus;
// 如果头结点状态是 SIGNAL ,说明后续节点都需要唤醒
if (ws == Node.SIGNAL) {
// CAS 保证只有一个节点可以运行唤醒的操作
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
// 进行唤醒操作
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
// 退出自旋条件 h==head,一般出现于以下两种情况
// 第一种情况,头节点没有发生移动,结束。
// 第二种情况,因为此方法可以被两处调用,一次是获得锁的地方,一处是释放锁的地方,
// 加上共享锁的特性就是可以多个线程获得锁,也可以释放锁,这就导致头节点可能会发生变化,
// 如果头节点发生了变化,就继续循环,一直循环到头节点不变化时,结束循环。
if (h == head) // loop if head changed
break;
}
}
3.模拟退款示例
- 小明在淘宝上买了一个商品,觉得不好,把这个商品退掉(商品还没有发货,只退钱),我们叫做单商品退款,单商品退款在后台系统中运行时,整体耗时 30 毫秒。
- 双 11,小明在淘宝上买了 40 个商品,生成了同一个订单(实际可能会生成多个订单,为了方便描述,我们说成一个),第二天小明发现其中 30 个商品是自己冲动消费的,需要把 30 个商品一起退掉。
// 单商品退款,耗时 30 毫秒,退款成功返回 true,失败返回 false
@Slf4j
public class RefundDemo {
/**
* 根据商品 ID 进行退款
* @param itemId
* @return
*/
public boolean refundByItem(Long itemId) {
try {
// 线程沉睡 30 毫秒,模拟单个商品退款过程
Thread.sleep(30);
log.info("refund success,itemId is {}", itemId);
return true;
} catch (Exception e) {
log.error("refundByItemError,itemId is {}", itemId);
return false;
}
}
}
@Slf4j
public class BatchRefundDemo {
// 定义线程池
public static final ExecutorService EXECUTOR_SERVICE =
new ThreadPoolExecutor(10, 10, 0L,
TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<>(20));
@Test
public void batchRefund() throws InterruptedException {
// state 初始化为 30
CountDownLatch countDownLatch = new CountDownLatch(30);
RefundDemo refundDemo = new RefundDemo();
// 准备 30 个商品
List<Long> items = Lists.newArrayListWithCapacity(30);
for (int i = 0; i < 30; i++) {
items.add(Long.valueOf(i+""));
}
// 准备开始批量退款
List<Future> futures = Lists.newArrayListWithCapacity(30);
for (Long item : items) {
// 使用 Callable,因为我们需要等到返回值
Future<Boolean> future = EXECUTOR_SERVICE.submit(new Callable<Boolean>() {
@Override
public Boolean call() throws Exception {
boolean result = refundDemo.refundByItem(item);
// 每个子线程都会执行 countDown,使 state -1 ,但只有最后一个才能真的唤醒主线程
countDownLatch.countDown();
return result;
}
});
// 收集批量退款的结果
futures.add(future);
}
log.info("30 个商品已经在退款中");
// 使主线程阻塞,一直等待 30 个商品都退款完成,才能继续执行
countDownLatch.await();
log.info("30 个商品已经退款完成");
// 拿到所有结果进行分析
List<Boolean> result = futures.stream().map(fu-> {
try {
// get 的超时时间设置的是 1 毫秒,是为了说明此时所有的子线程都已经执行完成了
return (Boolean) fu.get(1,TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
e.printStackTrace();
} catch (ExecutionException e) {
e.printStackTrace();
} catch (TimeoutException e) {
e.printStackTrace();
}
return false;
}).collect(Collectors.toList());
// 打印结果统计
long success = result.stream().filter(r->r.equals(true)).count();
log.info("执行结果成功{},失败{}",success,result.size()-success);
}
}
通过以上代码,30 个商品退款完成之后,整体耗时大概在 200 毫秒左右。而通过 for 循环单商品进行退款,大概耗时在 1 秒左右,前后性能相差 5 倍左右,for 循环退款的代码如下:
long begin1 = System.currentTimeMillis();
for (Long item : items) {
refundDemo.refundByItem(item);
}
log.info("for 循环单个退款耗时{}",System.currentTimeMillis()-begin1);
一个面试题
如果一个线程需要等待一组线程全部执行完之后再继续执行,有什么好的办法么?是如何实现的?
答:CountDownLatch 就提供了这样的机制,比如一组线程有 5 个,只需要在初始化 CountDownLatch 时,给同步器的 state 赋值为 5,主线程执行 CountDownLatch.await() ,子线程都执行 CountDownLatch.countDown() 即可。