1.例子
public void testCountDownLatch() throws InterruptedException {
CountDownLatch contDownLatch = new CountDownLatch( 10);
for(int i=0;i<10;i++){
String name = "这是第"+i+"个线程";
BiFunction<CountDownLatch,String,CountDownLatchThead> fn = CountDownLatchThead::new;
new Thread(fn.apply(contDownLatch,name)).start();
}
contDownLatch.await();
System.out.println("all done");
}
class CountDownLatchThead implements Runnable {
private CountDownLatch contDownLatch;
private String name;
CountDownLatchThead(CountDownLatch contDownLatch,String name){
this.contDownLatch = contDownLatch;
this.name =name;
}
@Override
public void run() {
synchronized (this){
try{
if(name.equalsIgnoreCase("这是第9个线程")){
Thread.sleep(100);
}
System.out.println(name+"执行完成了 num=");
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
contDownLatch.countDown();
}
}
}
}
2.概念
countDownLatch这个类使一个线程等待其他线程各自执行完毕后再执行。
是通过一个计数器来实现的,计数器的初始值是线程的数量。每当一个线程执行完毕后,计数器的值就-1,当计数器的值为0时,表示所有线程都执行完毕,然后在闭锁上等待的线程就可以恢复工作了。
3.源码
这个类中只提供了一个构造器
/**
* Constructs a {@code CountDownLatch} initialized with the given count.
*
* @param count the number of times {@link #countDown} must be invoked
* before threads can pass through {@link #await}
* @throws IllegalArgumentException if {@code count} is negative
*/
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
可以看到创建CountDownLatch对象就是创建Sync对象,先看下Sync对象的代码
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
可以看到其中有2个方法,尝试释放资源tryReleaseShared(arg)和尝试直接获取tryAcquireShared(arg)方法,后面我们对这2个方法进行解析。
在CountDownLatch中有3个方法是最重要的
//调用await()方法的线程会被挂起,它会等待直到count值为0才继续执行
public void await() throws InterruptedException { };
//和await()类似,只不过等待一定的时间后count值还没变为0的话就会继续执行
public boolean await(long timeout, TimeUnit unit) throws InterruptedException { };
//将count值减1
public void countDown() { };
1.await()
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
代码很简单,就一句话(注意acquireSharedInterruptibly()方法是抽象类:AbstractQueuedSynchronizer的一个方法,我们上面提到的Sync继承了它),我们跟踪源码,继续往下看:
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
源码也是非常简单的,首先判断了一下,当前线程是否有被中断,如果没有的话,就调用tryAcquireShared(int acquires)方法,判断一下当前线程是否还需要“阻塞”。其实这里调用的tryAcquireShared方法,就是我们上面提到的java.util.concurrent.CountDownLatch.Sync.tryAcquireShared(int)这个方法。
当然,在一开始我们没有调用过countDownLatch.countDown()方法时,这里tryAcquireShared方法肯定是会返回-1的,因为会进入到doAcquireSharedInterruptibly方法。
2. doAcquireSharedInterruptibly(int arg)
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);//添加节点
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();//获取前置节点
if (p == head) {//如果是头节点
int r = tryAcquireShared(arg);//尝试获取资源
if (r >= 0) {
setHeadAndPropagate(node, r);//唤醒下一个节点
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())//尝试清理waitStatus>0的node或中断线程
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
步骤
1.创建countDownLatch对象并设置栏的数量
2.判断线程是否中断,中断了抛出异常
3.尝前计数器的值(AQS中的status),是否为0了,如果为0的话返回1
4.若不为0,开启一个循环调用
5.新建node,等待的链表
6.自旋判断计数器的值是否为0,若为0,唤醒下一个节点并跳出循环
7.若不为0,挂起线程或中断线程
countDown()
public void countDown() {
sync.releaseShared(1);
}
代码简单就是调用sync中的releaseShared方法 (这个也是AQS中的方法)
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {//对计数器数值-1返回是否为0
doReleaseShared(); //自旋阻塞的node
return true;
}
return false;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();//获取计数器值
if (c == 0)
return false;
int nextc = c-1; //计数器值-1
if (compareAndSetState(c, nextc))//修改计数器中(CAS)
return nextc == 0;
}
}
private void doReleaseShared() {
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}
在自旋的阶段,每一次循环的过程都是首先获得头结点,如果头结点不为空且不为尾结点(阻塞队列里面只有一个结点),那么先获得该节点的状态,如果是SIGNAL的状态,则代表它需要有后继结点去唤醒,首先将其的状态变为0,因为是要释放资源了,它也不需要做什么了,所以转变为初始状态,然后去唤醒后继结点unparkSuccessor(h),如果结点状态一开始就是0,那么就给他转换成PROPAGATE状态,保证在后续获取资源的时候,还能够向后面传播
步骤
1.计数器减1
2.若计数器不为0,则自旋
3.首先获得头结点
4.如果头结点不为空且不为尾结点(阻塞队列里面只有一个结点),那么先获得该节点的状态,如果是SIGNAL的状态,则代表它需要有后继结点去唤醒