首先看内部结构
public class CountDownLatch {
/**
* Synchronization control For CountDownLatch.
* Uses AQS state to represent count.
*/
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
private final Sync sync;
/**
* Constructs a {@code CountDownLatch} initialized with the given count.
*
* @param count the number of times {@link #countDown} must be invoked
* before threads can pass through {@link #await}
* @throws IllegalArgumentException if {@code count} is negative
*/
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
}
自定义内部类 Sync继承自AQS,在其中实现了tryAcquireShared,tryReleaseShared方法。
使用过程
new CountDownLatch(1);
首先进入构造方法
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
Sync(int count) {
setState(count);
}
可以看到,构造方法其实是将state设定为参数值。
使用await()方法
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
调用AQS方法
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
从AQS再调用自定义实现中的tryAquireShare(arg)方法
CountDownLatch.class
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
很简单,就是判断state值是否为0。为0时则可以继续执行业务流程。否则进入以下方法
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
该方法将目前线程入同步队列,并挂起。
countDown方法
public void countDown() {
sync.releaseShared(1);
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
将state减一
AQS.class
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
通过调用子类tryReleaseShared方法,实现以下功能:将state值减一,不过state不为0,则什么都不做。如果为0,则激活同步队列中挂起的头节点,即await等待的线程。
总的来说,CountDownLock利用了AQS state作为计数,并利用了 acquireShare和releaseShare方法,await使用了AQS的aquireShare,会调用子类tryacquireShare进行判断,如果值为负,则将线程放入同步队列挂起,如果为正,则AQS中什么也不做,即该线程获取到执行权限。而CountDownLock正是实现了tryRelease方法,在其中判断state是否为0,不为0 则返回负值。从而将本线程入队。
count()方法,将使用tryReleaseShare 将state字段减一,如果为0,则调用AQS中doReleaseShare方法,将同步队列的节点唤醒。
使用示例
public class CountDownLatchTest {
static class MyThread extends Thread{
CountDownLatch countDownLatch;
public MyThread(CountDownLatch countDownLatch) {
this.countDownLatch=countDownLatch;
}
@Override
public void run() {
try {
Thread.sleep(1000);
System.out.println("thread sleep done");
countDownLatch.countDown();
}catch (InterruptedException e){
e.printStackTrace();
}
}
}
public static void main(String[] args) {
int count=5;
CountDownLatch countDownLatch=new CountDownLatch(count);
for (int i=0;i<5;i++){
new MyThread(countDownLatch).start();
}
try {
countDownLatch.await();
System.out.println("main thread run");
}catch (InterruptedException e){
e.printStackTrace();
}
}
}
运行结果
thread sleep done
thread sleep done
thread sleep done
thread sleep done
thread sleep done
main thread run
以上