23.CountDownLatch的应用和原理

1.基本使用

我们先看一下如何使用CountDownLatch。

public class CountDownLatchExample {
    public static void main(String[] args) throws InterruptedException {
        CountDownLatch countDownLatch=new CountDownLatch(2);
        new Thread(new RelationService(countDownLatch)).start();
        new Thread(new RelationService(countDownLatch)).start();
        countDownLatch.await();
    }
    static class RelationService implements Runnable{
        private CountDownLatch countDownLatch;
        public RelationService(CountDownLatch countDownLatch){
            this.countDownLatch=countDownLatch;
        }

        @Override
        public void run(){
            //doSomething
            System.out.println(Thread.currentThread().getName()+"->done");
            countDownLatch.countDown(); //当前线程执行结束后进行计数器递减
        }
    }
}

上面的代码构建了一个倒计时为2的countDownLatch实例。定义两个线程分别执行RelationService线程,在线程中调用countDownLatch.countDown()方法,表示对倒计时进行递减,其实也可以认为当前线程的某个任务执行完毕。最后在main()方法中调用countDownLatch.await()进行阻塞,当计数器为0时被唤醒。

该类的使用类似Thread.join(),但是比其更加灵活。我们通过一个图示进一步看一下上面的执行过程:

 这里其实就是在main()方法里设置了一个计数器,当计数器归零时就触发所有await()阻塞的线程。

CountDownLatch到底有啥用呢?我们假设一个场景,当我们启动一个应用时,希望能够检查依赖的第三方服务是否运行正常,一旦依赖的服务没有启动,那么当前应用在启动是就需要等待。

首先定义一个抽象的健康检查类来检测服务的启动状态:

public abstract class BaseHealthChecker implements Runnable {

    private CountDownLatch countDownLatch;
    private String serviceName;
    private boolean serviceUp;

    public abstract void verifyService();

    public String getServiceName() {
        return serviceName;
    }
    public boolean isServiceUp() {
        return serviceUp;
    }
    
    public BaseHealthChecker(CountDownLatch countDownLatch, String serviceName) {
        this.countDownLatch = countDownLatch;
        this.serviceName = serviceName;
    }

    @Override
    public void run() {
        try {
            verifyService();
            serviceUp = true;
        } catch (Throwable t) {
            t.printStackTrace();
            serviceUp = false;
        } finally {
            if (countDownLatch != null) {
                countDownLatch.countDown();
            }
        }
    }
}

然后定义一个缓存的健康检查类:

public class CacheHealthChecker extends BaseHealthChecker {
    public CacheHealthChecker(CountDownLatch countDownLatch) {
        super(countDownLatch, "cacheHealthChecker");
    }

    @Override
    public void verifyService() {
        System.out.println("checking" + this.getServiceName());
        try {
            Thread.sleep(3000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println(this.getServiceName() + "is up");
    }

}

再定义一个数据库的监控检查类:

public class DatabaseHealthChecker extends BaseHealthChecker {
    public DatabaseHealthChecker(CountDownLatch countDownLatch) {
        super(countDownLatch, "databaseHealthChecker");
    }

    @Override
    public void verifyService() {
        System.out.println("checking" + this.getServiceName());
        try {
            Thread.sleep(3000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println(this.getServiceName() + "is up");
    }

}

之后定义一个监控程序的类:

public class CountDownApp {
    //    检查所有要预检查的服务列表
    private static List<BaseHealthChecker> services = new ArrayList<>();
    private static CountDownLatch latch = new CountDownLatch(2);
    private final static CountDownApp instance = new CountDownApp();

    static {
        services.add(new CacheHealthChecker(latch));
        services.add(new DatabaseHealthChecker(latch));
    }

    private CountDownApp() {
    }

    public static CountDownApp getInstance() {
        return instance;
    }

    public static boolean checkServices() throws InterruptedException {
//        创建线程调度器
        Executor executor = Executors.newFixedThreadPool(services.size());
        for (final BaseHealthChecker v : services) {
            executor.execute(v);
        }
//进行定时器等待,直到检查所有服务都启动完成
        latch.await();
        for (final BaseHealthChecker v : services) {
            if (!v.isServiceUp()) {
                return false;
            }
        }
        return true;

    }
}

最后定义一个测试类:

public class CountDownTest {
    public static void main(String[] args) throws InterruptedException {
        boolean result = false;
        result = CountDownApp.checkServices();
        System.out.println("所有服务已经启动:" + result);
    }
}

这样我们就可以分别检查缓存服务器和数据库服务器的状态,都启动之后就会打印出最终的:

checkingdatabaseHealthChecker
checkingcacheHealthChecker
databaseHealthCheckeris up
cacheHealthCheckeris up
所有服务已经启动:true

2 CountDown的实现原理

根据前面的分析,我们大致能推测到CountDownLatch也应该使用了AQS的共享锁机制,因为让多个处于await()阻塞的多线程同时被唤醒,使用AQS的共享锁正好能实现,而看代码,我们发现事实也确实如此。

await也是CountDownLatch的入口,根据具体用法,可以阻塞一个或者多个线程。

public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}

doAcquireSharedInterruptibly()是AQS中共享锁的获取方法,而且根据名字可以判断这里是允许被中断的。不过在acquireSharedInterruptibly()中,先通过tryAcquireShared()方法判断返回结果。

  • 如果小于0,说明state字段的值不为0,需要调用doAcquireSharedInterruptibly()方法进行阻塞。

  • 如果大于或者等于0,则说明state已经为0,可以直接返回不需要阻塞。

接下来我们就详细看一下acquireSharedInterruptibly()方法做的事情。

既然state代表的计数器不为0,那么当前线程必然需要等待,所以doAcquireSharedInterruptibly()方法基本上可以猜测到是用来构建CLH队列并阻塞线程的,代码如下:

private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    //创建一个共享模式的结点添加到队列中
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
            //根据判断结果尝试获取锁
                int r = tryAcquireShared(arg);
                //表示获取了执行权限,这时因为state!=0,所以不会执行这段代码
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            //阻塞线程
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

我们可以看到这里调用addWaiter()方法构建一个双向链表,这就是AQS中的排他锁的实现 ,注意Node的mode是shared模式。然后利用tryAcquireShared()方法并通过for(;;)自旋循环抢占锁,这时候会返回一个状态r。判断r的值,如果r大于等于0,表示当前线程得到了执行权限,则调用setHeadAndPropagate()方法唤醒当前的线程。最后是shouldParkAfterFailedAcquire()方法和AQS排他锁中的方法是一样的,如果没抢占到锁,则判断是否需要挂起来。

这个可以看到,与AQS的排他锁整体实现基本是相同的,共享锁抢占到执行权限基本上就是判断state满足某个固定的值,并且允许多个线程同时获得执行权限,这是共享锁的特征。另外,获得执行权限后调用setHeadAndPropagate()方法不仅仅重置head结点,而且需要进行唤醒的传播。

接下来,我们通过一个示例来看一下CountDownLatch的基本过程:

假设有两个线程ThreadA和ThreadB,分别调用了await()方法,此时由于state锁表示的计数器不为0,所以添加到AQS的CLH队列中,如下图所示,与排他锁最大的区别是结点类型是SHARED。

3 countDown过程

在调用await()方法后,ThreadA和ThreadB两个线程会加入到CLH队列中并阻塞线程,他们需要等到一个倒计时信号,也就是countDown()方法对state进行递减,直到state为0,则唤醒处于同步队列中被阻塞的线程,代码如下:

public void countDown() {
    sync.releaseShared(1);
}

public final boolean releaseShared(int arg) {
//递减共享锁信号
    if (tryReleaseShared(arg)) {
    //唤醒线程
        doReleaseShared();
        return true;
    }
    return false;
}

    protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c-1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }
}

在tryReleaseShared()方法中,只有当state减为0的时候,tryReleaseShared()才会返回true,否则只是执行简单的state=state-1。如果state=0,则调用doReleaseShared()方法唤醒同步队列中的线程。

3.1 doReleaseShared()方法

private void doReleaseShared() {
    for (;;) {
    //每次循环时head都有变化,因为调用unparkSuccessor()方法会导致head结点发生变化
        Node h = head;
        //AQS队列中存在多个阻塞的结点
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            //如果结点的状态为SIGNAL,则表示可以被唤醒
            if (ws == Node.SIGNAL) {
            //如果此时失败说明有当前结点的线程状态被修改了,不需要被唤醒。继续下一次循环即可
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                unparkSuccessor(h);
            }
            //ws == 0 是初始状态,则修改该结点状态为PROPAGATE
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        if (h == head)                   // loop if head changed
            break;
    }
}

这个方法本身要做的是,从AQS的同步队列中唤醒head结点的下一个结点,所以只需要满足两个条件:

  • h != null && h != tail ,判断队列中是否有处于等待状态的线程。

  • h.waitStatus==Node.SIGNAL,表示结点状态正常。

满足以上条件就会调用unparkSuccessor()方法唤醒线程。

3.2 unparkSuccessor()方法

unparkSuccessor()方法主要用来唤醒head结点的下一个结点,代码如下:

private void unparkSuccessor(Node node) {
    int ws = node.waitStatus;
    if (ws < 0)
        compareAndSetWaitStatus(node, ws, 0);

    Node s = node.next;
    if (s == null || s.waitStatus > 0) {
        s = null;
        for (Node t = tail; t != null && t != node; t = t.prev)
            if (t.waitStatus <= 0)
                s = t;
    }
    if (s != null)
        LockSupport.unpark(s.thread);//唤醒指定结点
}

上述代码主要有两个逻辑,作为设计者来说需要考虑到:

  • 如果head结点的下一个结点s==null或者结点状态为取消,则不需要再唤醒。

  • 通过for (Node t = tail; t != null && t != node; t = t.prev)循环从tail尾部结点往head结点方向遍历找到距离head最近的一个有效结点,这与上一章重入锁的原因是一致的,最后对该结点通过LockSupport.unpark()方法进行唤醒。

4 线程被唤醒之后的工作

当处于CLH队列的head.next结点被唤醒后,继续从原本被阻塞的地方开始执行,因此我们回到doAcquireInterruptibly()方法中,代码如下:

 private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {//被唤醒的线程进入下一次循环继续判断
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; //把当前结点从AQS队列中移除
                        failed = false;
                        return;
                    }
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }

被唤醒的线程进入下一次循环,此时满足r>=0的条件(当r>=0时,说明state的值已经变成了0),因此执行setHeadAndPropagate(node, r)方法。

我们再来看一下setHeadAndPropagate()方法:

private void setHeadAndPropagate(Node node, int propagate) {
    Node h = head; // Record old head for check below
    setHead(node);
    if (propagate > 0 || h == null || h.waitStatus < 0 ||
        (h = head) == null || h.waitStatus < 0) {
        Node s = node.next;
        if (s == null || s.isShared())
            doReleaseShared();
    }
}

这段代码看似简单,但是实际处理的场景挺多。首先是调用setHead(node)方法将当前被唤醒的线程所在结点设置成head结点。当满足如下条件时继续调用doReleaseShared()方法唤醒后续的线程:

  • 情况1:propagate>0,表示当前是共享锁,需要进行唤醒传递。

  • 情况2:h == null和(h = head) == null ,这些条件是避免空指针的写法,这种情况可能出现的场景是原来的head结点正好从链表中断开,在临界的情况下满足该条件可能会出现。

  • 情况3:h.waitStatus < 0,可能为0,也可能是-1,propagate。

  • 情况4:s.isShared(),判断当前结点是否为共享模式。

分析到这里可以发现,doReleaseShared()方法调用了如下的两个方法:

  • 当计数器归零时调用countDown()方法。

  • 被阻塞的线程被唤醒之后,调用setHeadAndPropagate()

小结

当ThreadC调用countDown()方法之后,如果state=0,则会唤醒处于AQS队列中的线程,然后调用setHeadAndPropagate()方法,实现锁释放的传递,从而唤醒所有阻塞再await()方法中的线程。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

纵横千里,捭阖四方

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值