开发中最常见的场景,在主线程中开启多线程并执行任务,主线程需要等待所有子线程执行完毕后再进行处理的场景。
在CountDownLatch出现之前,一般都是使用线程的join()方法来实现,但是join不够灵活,不能够满足不同场景的需要,所以JDK后来提供了CountDownLatch,用于同步。
一、CountDownLatch使用实例
public class JoinCountDownLatch
{
// 创建一个CountDownLatch实例
private static volatile CountDownLatch countDownLatch = new CountDownLatch(2);
public static void main(String[] args) throws InterruptedException
{
Thread threadOne = new Thread(new Runnable()
{
@Override
public void run()
{
try
{
Thread.sleep(1000);
System.out.println("child threadOne over!");
}
catch (InterruptedException e)
{
e.printStackTrace();
}
finally
{
countDownLatch.countDown();
}
}
});
Thread threadTwo = new Thread(new Runnable()
{
@Override
public void run()
{
try
{
Thread.sleep(1000);
System.out.println("child threadTwo over!");
}
catch (InterruptedException e)
{
e.printStackTrace();
}
finally
{
countDownLatch.countDown();
}
}
});
// 启动子线程
threadOne.start();
threadTwo.start();
System.out.println("wait all child thread over!");
// 等待子线程执行完毕,返回
countDownLatch.await();
System.out.println("all child thread over!");
}
}
运行输出结果:
wait all child thread over!
child threadTwo over!
child threadOne over!
all child thread over!
创建了CountDownLatch实例,因为有两个子线程,所以构造函数参数传递为2,主线程调用countDownLatch.await()方法后会被阻塞。
子线程执行完毕后调用countDownLatch.countDown()方法让countDownLatch内部的计数器减一,等所有子线程执行完毕调用countDown()后计数器会变为0,这时候主线程的await()才会返回。
二、CountDownLatch实现原理
CountDownLatch内部有个计数器,并且这个计数器是递减的。
public class CountDownLatch
{
/**
* AQS
*/
private static final class Sync extends AbstractQueuedSynchronizer
{
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count)
{
setState(count);
}
int getCount()
{
return getState();
}
protected int tryAcquireShared(int acquires)
{
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases)
{
// Decrement count; signal when transition to zero
for (; ; )
{
int c = getState();
if (c == 0)
{
return false;
}
int nextc = c - 1;
if (compareAndSetState(c, nextc))
{
return nextc == 0;
}
}
}
}
//基于AQS的同步器
private final Sync sync;
//等待计数为0
public void await() throws InterruptedException
{
sync.acquireSharedInterruptibly(1);
}
//计数递减1
public void countDown()
{
sync.releaseShared(1);
}
}
A、构造方法
public CountDownLatch(int count)
{
if (count < 0)
{
throw new IllegalArgumentException("count < 0");
}
this.sync = new Sync(count);
}
Sync(int count)
{
setState(count);
}
CountDownLatch内部是使用AQS实现的,通过构造函数初始化了计数器的值,可知实际上是把计数器的值赋值给了AQS的状态值state,也即用AQS的状态值来表示计数器值。
B、void await()方法
当线程调用了CountDownLatch对象的await方法后,当前线程会被阻塞,直到下面的情况之一发生才会返回:
- 当所有线程都调用了CountDownLatch对象的countDown方法后,也即计时器值为0时;
- 其它线程调用了当前线程的interrupt()方法中断了当前线程,当前线程会抛出InterruptedException异常后返回。
await()方法内部:
// AQS的获取共享资源时候可被中断的方法
public final void acquireSharedInterruptibly(int arg) throws InterruptedException
{
// 如果线程被中断则抛异常
if (Thread.interrupted())
{
throw new InterruptedException();
}
// 尝试看当前是否计数值为0,为0则直接返回,否者进入AQS的队列等待
if (tryAcquireShared(arg) < 0)
{
doAcquireSharedInterruptibly(arg);
}
}
// sync类实现的AQS的接口
protected int tryAcquireShared(int acquires)
{
return (getState() == 0) ? 1 : -1;
}
// CountDownLatch的await()方法
public void await() throws InterruptedException
{
sync.acquireSharedInterruptibly(1);
}
await()方法委托sync调用了AQS的acquireSharedInterruptibly方法。
acquireSharedInterruptibly方法的特点是线程获取资源的时候可以被中断,并且获取的资源是共享资源。方法内部首先判断,如果当前线程被中断了则抛出异常,否则调用sync实现的tryAcquireShared方法看当前状态值(计数器值)是否为0,是则当前线程的await()方法直接返回,否则调用AQS的doAcquireSharedInterruptibly让当前线程阻塞。
另外可知,tryAcquireShared传递的arg参数是没有用到的,调用tryAcquireShared的方法仅仅是检查当前状态值是不是0,并没有调用CAS让当前状态值减去1。
C、boolean await(long timeout, TimeUnit unit)方法
当线程调用了CountDownLatch对象的该方法后,当前线程会被阻塞,直到下面的情况之一发生才会返回:
- 当所有线程都调用了CountDownLatch对象的countDown方法后,也就是计时器值为0的时候,返回true;
- 设置的timeout时间到了,因为超时而返回false;
- 其他线程调用了当前线程的interrupt()方法中断了当前线程,当前线程会抛出InterruptedException异常后返回。
public boolean await(long timeout, TimeUnit unit) throws InterruptedException
{
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
D、void countDown()方法
当线程调用了该方法后,会递减计数器的值,递减后,如果计数器为0则会唤醒所有调用await方法而被阻塞的线程,否则什么都不做。
// syn的方法
protected boolean tryReleaseShared(int releases)
{
// 循环进行cas,直到当前线程成功完成cas使计数值(状态值state)减一并更新到state
for (; ; )
{
int c = getState();
// 如果当前状态值为0则直接返回A
if (c == 0)
{
return false;
}
// CAS设置计数值减一B
int nextc = c - 1;
if (compareAndSetState(c, nextc))
{
return nextc == 0;
}
}
}
// AQS的方法
public final boolean releaseShared(int arg)
{
// 调用sync实现的tryReleaseShared
if (tryReleaseShared(arg))
{
// AQS的释放资源方法
doReleaseShared();
return true;
}
return false;
}
// CountDownLatch的countDown()方法
public void countDown()
{
// 委托sync调用AQS的方法
sync.releaseShared(1);
}
CountDownLatch的countDown()方法是委托sync调用了AQS的releaseShared方法。releaseShared内部首先调用了sync实现的AQS的tryReleaseShared。
该方法获取当前状态值(计数器值),执行代码A,如果当前状态值为0则直接返回false,countDown()方法直接返回;否则执行代码B使用CAS设置计数器减一,CAS失败则循环重试,否则如果当前计数器为0则返回true,返回true后说明当前线程是最后一个线程调用的countdown方法,那么该线程除了让计数器值减1外,还需要唤醒调用CountDownLatch的await方法而被阻塞的线程,也就是AQS的doReleaseShared()方法。
添加代码A,是为了防止当计数器值为0后,其他线程又调用了countDown方法,如果没有代码A,状态值就会变成负数了。
E、long getCount()方法
获取当前计数器的值,也就是AQS的state的值,一般在debug测试的时候使用。
public long getCount()
{
return sync.getCount();
}
int getCount()
{
return getState();
}
内部调用AQS的getState方法来获取state的值。
三、使用CountDownLatch注意事项
使用CountDownLatch,要注意countDown()方法要在finally块内执行,避免抛异常后得不到执行。
private static volatile CountDownLatch countDownLatch = new CountDownLatch(2);
public static void main(String[] args) throws InterruptedException
{
Thread threadOne = new Thread(new Runnable()
{
@Override
public void run()
{
try
{
Thread.sleep(1000);
System.out.println("child threadOne over!");
throw new RuntimeException("error");
}
catch (InterruptedException e)
{
e.printStackTrace();
}
countDownLatch.countDown();
}
});
Thread threadTwo = new Thread(new Runnable()
{
@Override
public void run()
{
try
{
Thread.sleep(1000);
System.out.println("child threadTwo over!");
}
catch (InterruptedException e)
{
e.printStackTrace();
}
countDownLatch.countDown();
}
});
// 启动子线程
threadOne.start();
threadTwo.start();
System.out.println("wait all child thread over!");
// 等待子线程执行完毕,返回
countDownLatch.await();
System.out.println("all child thread over!");
}
代码执行结果:
wait all child thread over!
Exception in thread "Thread-0" child threadOne over!
child threadTwo over!
java.lang.RuntimeException: error
at top.cfish.java.ext.CountDownLatchDemo$1.run(CountDownLatchDemo.java:26)
at java.lang.Thread.run(Thread.java:748)
以上代码创建countDownLatch,计数器初始化为2,然后开启了两个线程,main线程调用了countDownLatch.await()会挂起,直到计数器为0时才返回。
其中threadTwo线程休眠1s后调用了countDownLatch.countDown()让计数器变为了1,而threadOne在执行countDownLatch.countDown()前抛出异常后退出了,没有让计数器变为0。所以main线程会一直阻塞到countDownLatch.await()不会返回。
所以规范规定调用countDownLatch.countDown()要在finally块内执行。
private static volatile CountDownLatch countDownLatch = new CountDownLatch(2);
public static void main(String[] args) throws InterruptedException
{
Thread threadOne = new Thread(new Runnable()
{
@Override
public void run()
{
try
{
Thread.sleep(1000);
System.out.println("child threadOne over!");
throw new RuntimeException("error");
}
catch (InterruptedException e)
{
e.printStackTrace();
}
finally
{
countDownLatch.countDown();
}
}
});
Thread threadTwo = new Thread(new Runnable()
{
@Override
public void run()
{
try
{
Thread.sleep(1000);
System.out.println("child threadTwo over!");
}
catch (InterruptedException e)
{
e.printStackTrace();
}
finally
{
countDownLatch.countDown();
}
}
});
// 启动子线程
threadOne.start();
threadTwo.start();
System.out.println("wait all child thread over!");
// 等待子线程执行完毕,返回
countDownLatch.await();
System.out.println("all child thread over!");
}
代码执行结果:
wait all child thread over!
Exception in thread "Thread-0" java.lang.RuntimeException: error
at top.cfish.java.ext.Main2$1.run(Main2.java:26)
at java.lang.Thread.run(Thread.java:748)
child threadTwo over!
child threadOne over!
all child thread over!
countDownLatch.countDown()被放到了finally块内,即使threadOne抛出了异常,由于finally块内代码总是会执行,所以运行上面的代码,main线程会正常退出。