1 案例介绍
在日常开发中经常会遇到需要在主线程中开启多个线程去并行执行任务, 并且主线程需要等待所有子线程执行完毕后再进行汇总的场景。在CountDownLatch 出现之前一般都使用线程的join( )方法来实现这一点,但是join 方法不够灵活, 不能够满足不同场景的需要,所以JDK 开发组提供了CountDownLatch 这个类。
1.1 Java中join()方法的理解
thread.Join把指定的线程加入到当前线程,可以将两个交替执行的线程合并为顺序执行的线程。比如在线程B中调用了线程A的Join()方法,直到线程A执行完毕后,才会继续执行线程B。
- t.join(); //调用join方法,等待线程t执行完毕
- t.join(1000); //等待 t 线程,等待时间是1000毫秒。
1.2 countDownLatch
package com.example.demo.typeHandler.demo2;
import java.util.concurrent.CountDownLatch;
/**
* @author wb-hll364276
* @date 2020/2/28.
*/
public class JoinCountDownLatch {
private static volatile CountDownLatch countDownLatch = new CountDownLatch(2);
public static void main(String[] args) throws InterruptedException {
Thread threadOne = new Thread(new Runnable() {
@Override
public void run() {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}finally {
countDownLatch.countDown();
}
System.out.println("one over");
}
});
Thread threadTwo = new Thread(new Runnable() {
@Override
public void run() {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}finally {
countDownLatch.countDown();
}
System.out.println("two over");
}
});
threadOne.start();
threadTwo.start();
System.out.println("wait all child over");
countDownLatch.await();
System.out.println("all child over");
}
}
在如上代码中,创建了一个CountDownLatch 实例,因为有两个子线程所以构造函数的传参为2 。主线程调用countDownLatch.await( )方法后会被阻塞。子线程执行完毕后调用countDownLatch.countDown( )方法让countDownLatch 内部的计数器减1 ,所有子线程执行完毕并调用countDown ( )方法后计数器会变为0 ,这时候主线程的await( )方法才会返回。
在项目实践中一般都避免直接操作线程,而是使用线程池来管理。使用ExecutorService 时传递的参数是Runable 或者Callable 对象,这时候你没有办法直接调用这些线程的join( )方法
1.3 countDownLatch和join的区别
调用一个子线程的join ()方法后,该线程会一直被阻塞直到子线程运行完毕,而CountDownLatch 则使用计数器来允许子线程运行完毕或者在运行中递减计数,也就是CountDownLatch 可以在子线程运行的任何时候让await 方法返回而不一定必须等到线程结束。
使用线程池来管理线程时一般都是直接添加Runable 到线程池,这时候就没有办法再调用线程的join 方法了,就
是说countDownLatch 相比join 方法让我们对线程同步有更灵活的控制。
2 原理探究
2.1 类图
从类图可以看出, CountDownLatch 是使用AQS 实现的。通过下面的构造函数,你会发现,实际上是把计数器的值赋给了AQS 的状态变量state 。
/**
* Constructs a {@code CountDownLatch} initialized with the given count.
*
* @param count the number of times {@link #countDown} must be invoked
* before threads can pass through {@link #await}
* @throws IllegalArgumentException if {@code count} is negative
*/
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
Sync(int count) {
setState(count);
}
2.2 重点方法
2.2.1 void await( )
当线程调用CountDownLatch 对象的await 方法后, 当前线程会被阻塞, 直到下面的情况之一发生才会返回: 当所有线程都调用了CountDownLatch 对象的countDown 方法后,也就是计数器的值为0 时;其他线程调用了当前线程的interrupt ( )方法中断了当前线程,当前线程就会抛出InterruptedException 异常, 然后返回。
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
//AQS获取共享资源、时可被中断的方法
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
- 该方法线程获取资源时可以被中断, 并且获取的资源是共享资源。
- acquireSharedInterruptibly首先判断当前线程是否己被中断, 若是则抛出异常。
- 否则调用sync 实现的tryAcquireShared 方法查看当前状态值( 计数器值)是否为0 , 是0则当前线程的await()方法直接返回。
- 否则调用AQS 的doAcquireSharedInterruptibly( )方法让当前线程阻塞。
- 另外可以看到,这里tryAcquireShared 传递的arg 参数没有被用到, 调用tryAcquireShared 的方法仅仅是为了检查当前状态值是不是为0 ,并没有调用CAS 让当前状态值减1 。
2.2.2 boolean await(long timeout, TimeUnit unit)方法
当线程调用了CountDownLatch 对象的该方法后, 当前线程会被阻塞, 直到下面的情况之一发生才会返回:
- 当所有线程者调用了CountDownLatch对象的countDown 方法后,也就是计数器值为0 时,会返回true ;
- 设置的timeout 时间到了,因为超时而返回false ;
- 其他线程调用了当前线程的interrupt()方法中断了当前线程,当前线程会抛出InterruptedException异常,然后返回。
2.2.3. void countDown()方法
线程调用该方法后,计数器的值递减, 递减后如果计数器值为0 则唤醒所有因调用await 方法而被阻塞的线程,否则什么都不做。下面看下countDown( )方法是如何调用AQS 的方法的。
public void countDown() {
//委托sync调用AQS的方法
this.sync.releaseShared(1);
}
/**
* Releases in shared mode. Implemented by unblocking one or more
* threads if {@link #tryReleaseShared} returns true.
*
* @param arg the release argument. This value is conveyed to
* {@link #tryReleaseShared} but is otherwise uninterpreted
* and can represent anything you like.
* @return the value returned from {@link #tryReleaseShared}
*/
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
//在如上代码中, releaseShared 首先调用了sync 实现的AQS 的tryReleaseShared 方法,其代码如下。
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0) (1)
return false;
int nextc = c-1; (2)
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
- 如上代码首先获取当前状态值(计数器值) 。
- 代码判断如果当前状态值为0 则直接返回false ,从而方法直接返回:
- 否则执行代码。使用CAS 将计数器值减1, CAS 失败则循环重试,否则如果当前计数器值为0 则返回true ,
- 返回true 说明是最后一个线程调用的countdown方法,那么该线程除了让计数器值减1外,还需要唤醒因调用CountDownLatch 的await 方法而被阻塞的线程,具体是调用AQS 的doReleaseShared方法来激活阻塞的线程。
- (1)处代码貌似是多余的,其实不然,之所以添加代码是为了防止当计数器值为0 后,其他线程又调用了countDown 方法,如果没有代码状态值就可能会变成负数。
3 总结
CountDownLatch,相比使用join方法来实现线程间同步,前者更具有灵活性和方便性。
CountDownLatch是使用AQS 实现的。使用AQS的state变量来存放计数器的值。先在初始化CountDownLatch 时设置状态值(计数器值),当多个线程调用countdown方法时实际是原子性递减AQS 的状态值。当线程调用await方法后当前线程会被放入AQS的阻塞队列等待计数器为0再返回。其他线程调用countdown方法让计数器值递减1,当计数器值变为0时,当前线程还要调用AQS 的doReleaseShared 方法来激活由于调用await()方法而被阻塞的线程。