本篇文章主要介绍CountDownLatch 相关知识。通过阅读你会有如下收获:
1. 什么是 CountDownLatch ?
2. 如何通过CountDownLatch实现一个程序启动检查服务?
3. CountDownLatch源码如何实现的?
一.什么是CountDownLatch
CountDownLatch是一种同步辅助工具,它允许一个或多个线程等待,直到在其他线程中执行的一组操作完成为止。
其工作原理是使用线程数初始化计数器,每次线程完成执行时,计数器都会递减。当count达到零时,表示所有线程已完成其执行,并且等待的主线程将恢复执行。
CountDownLatch的伪代码可以这样编写:
//主线程启动
//为N个线程创建CountDownLatch
//创建并启动N个线程
//主线程在锁存器上等待
// N个线程完成任务返回
//主线程恢复执行
二. 实战
1. 需求
实现一个应用程序启动检查功能,启动N个线程,分别同时检查数据库,网络等外部系统,并报告检查结果给正在等待的启动类。
2. 分析
这里可以使用CountDownLatch模拟一个应用程序启动类,该类启动了N个线程,每个线程检查完对应的程序后调用countdown() 方法。主线程通过 await() 方法来等待检查结果。一旦验证并检查了所有服务,启动就会继续。
3 代码实现
完整代码Github地址:使用countdownlaunch实现程序启动健康检查
首先定义一个抽象基类,AbstractHealthCheck:实现 Runnable 接口,负责所有特定的外部服务健康检查的基类。
import java.util.concurrent.CountDownLatch;
/**
* @Description TODO
* @Author tr.wang
* @Date 2019/11/28 19:34
* @Version 1.0
*/
public abstract class AbstractHealthCheck implements Runnable{
private CountDownLatch latch;
private String serviceName;
private boolean serviceUp;
public AbstractHealthCheck(String serviceName, CountDownLatch latch)
{
super();
this.latch = latch;
this.serviceName = serviceName;
this.serviceUp = false;
}
@Override
public void run() {
try {
verifyService();
serviceUp = true;
} catch (Throwable t) {
t.printStackTrace(System.err);
serviceUp = false;
} finally {
if(latch != null) {
latch.countDown();
}
}
}
public String getServiceName() {
return serviceName;
}
public boolean isServiceUp() {
return serviceUp;
}
public abstract void verifyService();
}
- 以下三个类都继承自 AbstractHealthCheck,使用模板的设计方式,引用 CountDownLatch 实例,实现各自的 verifyService() 方法。
CacheHealthChecker
import java.util.concurrent.CountDownLatch;
/**
* @Description TODO
* @Author tr.wang
* @Date 2019/11/28 19:50
* @Version 1.0
*/
public class CacheHealthChecker extends AbstractHealthCheck
{
public CacheHealthChecker (CountDownLatch latch)
{
super("Cache Service", latch);
}
@Override
public void verifyService()
{
System.out.println("Checking " + this.getServiceName());
try {
Thread.sleep(5000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(this.getServiceName() + " is UP");
}
}
DatabaseHealthChecker
import java.util.concurrent.CountDownLatch;
/**
* @Description TODO
* @Author tr.wang
* @Date 2019/11/28 19:50
* @Version 1.0
*/
public class DatabaseHealthChecker extends AbstractHealthCheck
{
public DatabaseHealthChecker (CountDownLatch latch)
{
super("Database Service", latch);
}
@Override
public void verifyService()
{
System.out.println("Checking " + this.getServiceName());
try {
Thread.sleep(2000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(this.getServiceName() + " is UP");
}
}
NetworkHealthChecker
import java.util.concurrent.CountDownLatch;
/**
* @Description TODO
* @Author tr.wang
* @Date 2019/11/28 19:47
* @Version 1.0
*/
public class NetworkHealthChecker extends AbstractHealthCheck
{
public NetworkHealthChecker (CountDownLatch latch)
{
super("Network Service", latch);
}
@Override
public void verifyService()
{
System.out.println("Checking " + this.getServiceName());
try
{
Thread.sleep(7000);
}
catch (InterruptedException e)
{
e.printStackTrace();
}
System.out.println(this.getServiceName() + " is UP");
}
}
测试类 CountDownLanchTest
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
/**
* @Description TODO
* @Author tr.wang
* @Date 2019/11/29 10:10
* @Version 1.0
*/
public class CountDownLanchTest {
private static List<AbstractHealthCheck> services;
private static CountDownLatch latch;
public static boolean checkExternalServices() throws Exception
{
latch = new CountDownLatch(3);
services = new ArrayList<AbstractHealthCheck>();
services.add(new NetworkHealthChecker(latch));
services.add(new CacheHealthChecker(latch));
services.add(new DatabaseHealthChecker(latch));
Executor executor = Executors.newFixedThreadPool(services.size());
for(final AbstractHealthCheck v : services)
{
executor.execute(v);
}
latch.await();
for(final AbstractHealthCheck v : services)
{
if( ! v.isServiceUp())
{
return false;
}
}
return true;
}
public static void main(String[] args)
{
boolean result = false;
try {
result = checkExternalServices();
} catch (Exception e) {
e.printStackTrace();
}
System.out.println("External services validation completed !! Result was :: "+ result);
}
}
输出为:
Checking Network Service
Checking Cache Service
Checking Database Service
Database Service is UP
Cache Service is UP
Network Service is UP
External services validation completed !! Result was :: true
三.CountDownLatch源码分析
从上面的例子可以看出
- 初始化时,设置计数(count)值,也就是闭锁需要等待的线程数。
- 主线程必须在启动其他线程后立即调用 CountDownLatch.await() 方法,这样主线程的操作就会在这个方法上阻塞,直到其他线程完成各自的任务为止。
- 其他 N 个线程必须引用闭锁对象,因为它们如果完成了任务需要通过 CountDownLatch.countDown() 方法来通知CountDownLatch实例,每次调用计数减少 1。当所有 N 个线程都调用了这个方法时,计数将达到 0,主线程可以在 await() 方法之后继续执行。
注意:该同步组件实现过程中,只需在syn静态内部类中重写的tryAcquireShared () / tryReleaseShared () 方法,其他均为AQS抽象类中的现有实现。
1. 构造函数
我们按照这个顺序首先看其构造函数,构造函数将计数值(count)传递给 Sync,并且设置了 state。
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
......
Sync(int count) {
setState(count);
}
2. await 方法
调用AQS 的 acquireSharedInterruptibly()。该方法首先判断是否被中断,中断就抛出异常。接下来调用 tryAcquireShared(arg)尝试获取共享锁。返回 1 代表获取成功,返回 -1 代表获取失败。如果获取失败,需要调用 doAcquireSharedInterruptibly();进入AQS中的wait队列中,等待获取锁,线程在此阻塞。
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
......
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
//线程被中断则抛出异常
if (Thread.interrupted())
throw new InterruptedException();
//查看当前计数器千直是否为 0 , 为 0 直接返回, 否则进入AQS的队列等待
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
......
//state 状态变量,state 的值代表着待达到条件的线程数,
//比如初始化为 5,表示待达到条件的线程数为 5,每次调用 countDown() 函数都会减 1。
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
由如上代码可知, 该方法的特点是线程获取资源时可以被中断, 并且获取的资源是共享资源。acquireSharedInterruptibly 首先判断当前线程是否己被中断, 若是则抛出异常,否则调用sync 实现的 tryAcquireShared 方法查看当前状态值( 计数器值)是否为 0 , 是则当前线程的await() 方法直接返回, 否则调用 AQS 的 doAcquireSharedlnterruptibly 方法让当前线程阻塞。另外可以看到,这里tryAcquireShared 传递的 arg 参数没有被用到, 调用 try AcquireShared 的方法仅仅是为了检查当前状态值是不是为0 , 并没有调用CAS 让当前状态值减1 。
3. countDown方法
countDown 操作实际就是释放锁的操作,每调用一次,计数值减少 1。
public void countDown() {
sync.releaseShared(1);
}
......
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
//唤醒被await阻塞的线程
doReleaseShared();
return true;
}
return false;
}
......
/**
* 自旋方式加上 CAS 的方式保证 state 的减 1 操作,
* 当计数值等于 0,代表所有子线程都执行完毕
*/
protected boolean tryReleaseShared(int releases) {
for (;;) {
int c = getState();
//如果当前状态值为 0 直接返回( 1 )
if (c == 0)
return false;
//使用CAS让计数器减1 ( 2 )
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
如上代码首先获取当前状态值(计数器值) 。
代码 ( 1 ) 判断如果当前状态值为 0 则直接返回 false ,从而 countDown ( )方法直接返回;
否则执行代码 (2) 使用CAS 将计数器值减 1, CAS 失败则循环重试,否则如果当前计数器值为 0 则返回 true ,返回 true 说明是最后一个线程调用的 countdown 方法,那么该线程除了让计数器值减 1 外,还需要唤醒因调用 CountDownLatch 的 await 方法而被阻塞的线程,具体是调用 AQS 的 doReleaseShared方法来激活阻塞的线程。这里代码 ( 1 ) 貌似是多余的,其实不然,之所以添加代码 ( 1 )是为了防止当计数器值为 0 后,其他线程又调用了 countDown 方法,如果没有代码 ( 1 ) 状态值就可能会变成负数。
四. 小结
CountDownLatch是使用 AQS 实现的。使用 AQS 的状态变量来存放计数器的值。首先在初始化 CountDownLatch 时设置状态值(计数器值),当多个线程调用 countdown 方法时实际是原子性递减AQS 的状态值。当线程调用 await 方法后当前线程会被放入 AQS 的阻塞队列等待计数器为 0 再返回。其他线程调用 countdown 方法让计数器值递减 1,当计数器值变为 0 时, 当前线程还要调用 AQS 的doReleaseShared 方法来激活由于调用 await() 方法而被阻塞的线程。
下一篇:【多线程实战 三】-通过CyclicBarrier来优化对账程序的执行效率
参考资料
Java并发编程实战
Java并发编程的艺术