一个线程await后被阻塞,直到n个线程countDown后才会被唤醒。
package java.util.concurrent;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
public class CountDownLatch {
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
/**
* 尝试获取资源。如果state为0则能够获取
*/
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
/**
* 尝试释放资源,state为0时代表完全释放资源
*/
protected boolean tryReleaseShared(int releases) {
for (;;) {
int c = getState();
// state为0时不可以再释放了
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
private final Sync sync;
/**
* 初始化
* @param count 等待执行的线程数
*/
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
/**
* 可中断的获取锁,state不为0则被阻塞
*/
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
/**
* 可中断地获取锁,超时放弃
* @param timeout 相对等待时间
* @param unit 单位
* @return 是否获取成功
* @throws InterruptedException 中断异常
*/
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
/**
* state减少
*/
public void countDown() {
sync.releaseShared(1);
}
/**
* 获取state
*/
public long getCount() {
return sync.getCount();
}
}