JDK从1.5版本开始提供CountDownLatch工具类,它能使一个县城等待其他线程各自完成工作后再执行,CountDownLatch内部是通过一个计数器实现的,计数器的初始值是批量任务初始线程的数量,每当一个线程完成任务后,计数器的值就会减1,当计数器的值为0时,唤醒所有被阻塞的线程。
一、CountDownLatch使用
首先,看下CountDownLatch的几个主要方法:
public class CountDownLatch {
...
// 阻塞当前线程,直到count值为0或者线程被中断
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
// 阻塞当前线程,直到count值为0或者线程被中断或者超出等待时间
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
// count值减1,如果count为0则唤醒所有被阻塞的线程
public void countDown() {
sync.releaseShared(1);
}
// 返回当前的count值
public long getCount() {
return sync.getCount();
}
...
}
CountDownLatch使用范例如下:
public class CountDownLatchTest {
public static void main(String[] args) throws InterruptedException {
int jobSize = 5;
CountDownLatch startLatch = new CountDownLatch(1);
CountDownLatch endLatch = new CountDownLatch(jobSize);
ExecutorService exec = Executors.newCachedThreadPool();
for(int i=0; i < jobSize; i++) {
exec.submit(new Runnable() {
@Override
public void run() {
try {
startLatch.await();
} catch (Exception e) {
e.printStackTrace();
}
try {
Thread.sleep(1000);
System.out.println(Thread.currentThread().getName());
} catch (Exception e){
e.printStackTrace();
} finally {
endLatch.countDown();
}
}
});
}
long startTime = System.currentTimeMillis();
startLatch.countDown();
endLatch.await(2, TimeUnit.SECONDS);
long endTime = System.currentTimeMillis();
System.out.println("cost time : " + (endTime - startTime));
exec.shutdown();
}
}
---
pool-1-thread-5
pool-1-thread-4
pool-1-thread-3
pool-1-thread-1
pool-1-thread-2
cost time : 1004
二、CountDownLatch源码解析
CountDownLatch的成员变量、构造器和内部类实现如下,其结构非常简单,其唯一成员变量是一个同步器,继承自AbstractQueuedSynchronizer。
public class CountDownLatch {
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
private final Sync sync;
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
public void countDown() {
sync.releaseShared(1);
}
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
...
}
CountDownLatch调用countDown()方法时会调用的releaseShared方法将count值减1;调用await()方法时会调用acquireSharedInterruptibly,其核心逻辑如下:当CountDownLatch同步器count不等于0时,会调用doAcquireSharedInterruptibly方法,将当前线程封装为Node并阻塞,直到同步器count为0或被中断。
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
if (tryAcquireShared(arg) < 0) // count!=0
doAcquireSharedInterruptibly(arg);
}
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
调用await(long timeout, TimeUnit unit)方法时会调用tryAcquireSharedNanos方法,在CountDownLatch的count不等于0时调用doAcquireSharedNanos方法,该方法和doAcquireSharedInterruptibly的逻辑基本一致,只是多了截止时间,进入方法后首先计算等待截止时间的判断,当前时间超过截止时间时直接返回false,值得注意的是这一句代码: nanosTimeout > spinForTimeoutThreshold
,其含义是当前时间距离截止时间小于spinForTimeoutThreshold时不阻塞线程,让线程在程序中自旋,自旋时间spinForTimeoutThreshold被默认是指为1000ns。
public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
return tryAcquireShared(arg) >= 0 ||
doAcquireSharedNanos(arg, nanosTimeout);
}
private boolean doAcquireSharedNanos(int arg, long nanosTimeout)
throws InterruptedException {
if (nanosTimeout <= 0L)
return false;
final long deadline = System.nanoTime() + nanosTimeout;
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return true;
}
}
nanosTimeout = deadline - System.nanoTime();
if (nanosTimeout <= 0L)
return false;
if (shouldParkAfterFailedAcquire(p, node) &&
nanosTimeout > spinForTimeoutThreshold)
LockSupport.parkNanos(this, nanosTimeout);
if (Thread.interrupted())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
static final long spinForTimeoutThreshold = 1000L;
具体AbstractQueuedSynchronizer共享式获取或释放解析见: AbstractQueuedSynchronizer共享式获取或释放。
三、CountDownLatch与Thread.join区别
在线程中调用Thread的join()方法,也可以实现阻塞当前线程的效果,Thread的join()方法调用后不断检查线程是否存活,如果存活则继续阻塞。join方法核心代码如下:
public final synchronized void join(long millis)
throws InterruptedException {
long base = System.currentTimeMillis();
long now = 0;
if (millis < 0) {
throw new IllegalArgumentException("timeout value is negative");
}
if (millis == 0) {
while (isAlive()) {
wait(0);
}
} else {
while (isAlive()) {
long delay = millis - now;
if (delay <= 0) {
break;
}
wait(delay);
now = System.currentTimeMillis() - base;
}
}
}
用join()方法实现类似CountDownLatch测试范例的逻辑,示例如下:
public class ThreadJoinTest {
public static void main(String[] args) {
long startTime = System.currentTimeMillis();
List<MyTask> taskList = new ArrayList<>();
for(int i=0; i < 5; i++) {
MyTask task = new MyTask();
taskList.add(task);
task.start();
}
for(MyTask task : taskList) {
try {
task.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
long endTime = System.currentTimeMillis();
System.out.println("cost time : " + (endTime - startTime));
}
public static class MyTask extends Thread {
@Override
public void run() {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread().getName());
}
}
}
---
Thread-0
Thread-2
Thread-1
Thread-3
Thread-4
cost time : 1005
可见,Thread的join()方法也可以实现类似逻辑,那CountDownLatch与Thread.join()方法的区别是什么?
区别:Thread.join()方法依赖于线程的存活情况,等所线程执行完毕时才能往下执行,而CountDownLatch提供计数器的功能,更加灵活,只需监测计数器count值为0就可继续往下执行,与线程执行情况可以解耦。