countDownLatch可以理解为一个倒序计数器,在初始化时给计数器赋值,每次操作计数器减1,等计数器为0后才可继续向下执行。通过该特性可以在特定任务执行前使用多线程处理,例如多线程读取文件,然后对文件拼接整合。
使用案例
以下代码演示初始CountDownLatch为2的执行流程,当new CountDownLatch(3)时程序会一直阻塞。
public static void main(String[] args) throws InterruptedException {
final CountDownLatch countDownLatch = new CountDownLatch(2);
// 为3时一直等待
// final CountDownLatch countDownLatch = new CountDownLatch(3);
new Thread("thread-1") {
@Override
public void run() {
System.out.println(Thread.currentThread().getName() + "读取文件需要5秒");
try {
sleep(5000);
} catch (InterruptedException e) {
e.printStackTrace();
}
countDownLatch.countDown();
System.out.println(Thread.currentThread().getName() + "读取文件结束");
}
}.start();
new Thread("thread-2") {
@Override
public void run() {
System.out.println(Thread.currentThread().getName() + "读取文件");
countDownLatch.countDown();
System.out.println(Thread.currentThread().getName() + "读取文件结束");
}
}.start();
countDownLatch.await();
System.out.println("所有任务执行完");
// countDownLatch.await设置超时时间,超过此主线程直接执行
// countDownLatch.await(1, TimeUnit.SECONDS);
// System.out.println("不等了,主线程执行");
}
}
源码解析
主要原理为countDownLatch内部使用AQS锁,将AQS中state字段作为计数器使用,每次调用countDown()方法对state字段减1。当state为0时唤醒所有线程。
初始化
设置计数器起始值
/**
* new CountDownLatch()时对AQS中state进行赋值
*
*/
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
Sync(int count) {
setState(count);
}
阻塞线程
调用countDownLatch.wait()的时候,会创建一个节点,加入到AQS阻塞队列,并同时把当前线程挂起
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
/**
* 判断计数器是否计数完毕,没完毕则将线程放入阻塞队列
*/
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
/**
* 构建阻塞队列的双向链表,挂起当前线程
*/
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
计数器递减
对计数器进行递减操作,为0时唤醒所有线程
/**
* 对state减1
*/
public void countDown() {
sync.releaseShared(1);
}
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) { // CountDownLatch的Sync内部类中对计数器减1
doReleaseShared(); // 唤醒所有阻塞队列中线程
return true;
}
return false;
}