从CountDownLatch来看AQS
大背景
在java1.5中引入了 CyclicBarrier、Semaphore,CountDownLatch等并发工具类,而此类工具类的实现其实都是基于大名鼎鼎的AbstractQueuedSynchronizer来实现
直接来看aqs可能较为抽象,我们可以先看一个简单的实现吧,这里就用countdownlatch来开始入手
一、CountDownLatch的概念
- countDownLatch这个类使一个线程等待其他线程各自执行完毕后再执行。
- 是通过一个计数器来实现的,计数器的初始值是线程的数量。每当一个线程执行完毕后,计数器的值就-1,当计数器的值为0时,表示所有线程都执行完毕,然后在闭锁上等待的线程就可以恢复工作了。
构造方法
//参数count为计数值
public CountDownLatch(int count) { };
类的API
//调用await()方法的线程会被挂起,它会等待直到count值为0才继续执行
public void await() throws InterruptedException { };
//和await()类似,只不过等待一定的时间后count值还没变为0的话就会继续执行
public boolean await(long timeout, TimeUnit unit) throws InterruptedException { };
//将count值减1
public void countDown() { };
示例代码如下:
@Slf4j
public class CountDownLatchExample1 {
private static final Logger log = LoggerFactory.getLogger(CountDownLatchExample1.class);
private final static int threadCount=10;
public static void main(String[] args) throws InterruptedException {
ExecutorService executorService = Executors.newCachedThreadPool();
final CountDownLatch countDownLatch=new CountDownLatch(threadCount);
// final MyCountDownLatch countDownLatch = new MyCountDownLatch(threadCount);
for (int i = 0; i < threadCount; i++) {
final int threadNum=i;
executorService.execute(()->{
try {
test(threadNum);
} catch (InterruptedException e) {
log.error("exception",e );
}finally {
countDownLatch.countDown();
}
});
}
countDownLatch.await();
log.info("finish" );
executorService.shutdown();
}
public static void test(int threadNum) throws InterruptedException {
Thread.sleep(100);
log.info("{}",threadNum );
Thread.sleep(100);
}
}
输出值:
17:39:57.657 [pool-1-thread-1] INFO com.example.aqs.countdownlatch.CountDownLatchExample1 - 0
17:39:57.657 [pool-1-thread-7] INFO com.example.aqs.countdownlatch.CountDownLatchExample1 - 6
17:39:57.657 [pool-1-thread-9] INFO com.example.aqs.countdownlatch.CountDownLatchExample1 - 8
17:39:57.657 [pool-1-thread-6] INFO com.example.aqs.countdownlatch.CountDownLatchExample1 - 5
17:39:57.657 [pool-1-thread-10] INFO com.example.aqs.countdownlatch.CountDownLatchExample1 - 9
17:39:57.657 [pool-1-thread-2] INFO com.example.aqs.countdownlatch.CountDownLatchExample1 - 1
17:39:57.657 [pool-1-thread-5] INFO com.example.aqs.countdownlatch.CountDownLatchExample1 - 4
17:39:57.657 [pool-1-thread-3] INFO com.example.aqs.countdownlatch.CountDownLatchExample1 - 2
17:39:57.657 [pool-1-thread-4] INFO com.example.aqs.countdownlatch.CountDownLatchExample1 - 3
17:39:57.657 [pool-1-thread-8] INFO com.example.aqs.countdownlatch.CountDownLatchExample1 - 7
17:39:57.760 [main] INFO com.example.aqs.countdownlatch.CountDownLatchExample1 - finish
CountDownLatch也有叫门闩锁,类似于大家出去玩,约定某个地点集合,人到齐之后方可出发,比如约定是5人,每到一个人就会将数量减一,直到数量为0后方可出发;
二、AQS概念
Aqs的全称AbstractQueuedSynchronizer,也有称为队列同步器,是juc包的基础工具类,不少并发工具类都是基于它来实现,我们这次就以countdownlatch来举例进行剖析
AbstractQueuedSynchronizer的核心分为三部分:
-
state AbstractQueuedSynchronizer的成员变量
-
控制线程抢锁和配合的FIFO队列
-
期待工具类去实现的获取/释放的重要方法(tryAcquire/tryRelease/tryAcquireShared/tryReleaseShared)
很重要!!!如果上面的countDownLatch熟悉了,那我们继续往下:
我们主要分析一下await方法和countDown方法
前面有说是基于AbstractQueuedSynchronizer来实现CountDownLatch,我们在源码中可以发现Sync这个内部类
// 源码 private static final class Sync extends AbstractQueuedSynchronizer { private static final long serialVersionUID = 4982264981922014374L; // Sync(int count) { setState(count); } int getCount() { return getState(); } protected int tryAcquireShared(int acquires) { return (getState() == 0) ? 1 : -1; } protected boolean tryReleaseShared(int releases) { // Decrement count; signal when transition to zero for (;;) { int c = getState(); if (c == 0) return false; int nextc = c-1; if (compareAndSetState(c, nextc)) return nextc == 0; } } }
如果感兴趣的话,可以看一下AbstractQueuedSynchronizer在哪些工具类中被继承,这里就不做展开了,其实这就是aqs的核心,当然现在看来还比较抽象,继续往下:
从CountDownLatch来着手分析入门AQS
一、CountDownLatch的构造方法:
public CountDownLatch(int count) { if (count < 0) throw new IllegalArgumentException("count < 0"); this.sync = new Sync(count); }
我们可以发现,其实本质就是对内部类实例的初始化:
内部类的构造方法 Sync(int count) { setState(count); } 那我们继续往下看 protected final void setState(int newState) { state = newState; } 我们到aqs的源码中查看state这个成员变量,这个state便是我们所说的AQS的三大核心之一 /** * The synchronization state. */ private volatile int state;
总结一下,其实构造CountDownLatch的本质就是对AQS的state进行初始化赋值
二、CountDownLatch的await()方法
其实是在调用aqs的获取方法 public void await() throws InterruptedException { sync.acquireSharedInterruptibly(1); }
然后继续往下debug
public final void acquireSharedInterruptibly(int arg) throws InterruptedException { if (Thread.interrupted()) throw new InterruptedException(); if (tryAcquireShared(arg) < 0) // tryAcquireShared doAcquireSharedInterruptibly(arg); //加入等待队列,这个队列就是三大核心之一的FIFO 队列 }
我们往下看tryAcquireShared(arg) < 0这个方法源码
protected int tryAcquireShared(int arg) { throw new UnsupportedOperationException(); }
这里便是用户需要自定义实现的方法,我们可以看下这个方法在CountDownLatch中的实现
protected int tryAcquireShared(int acquires) { return (getState() == 0) ? 1 : -1; //如果state = 0,那么我们就不做任何操作,如果不等 于0,就会执行 doAcquireSharedInterruptibly(arg); //加入等待队列了 }
以上就可以看出state便是控制线程调度逻辑的核心
三、CountDownLatch的countDown()方法
其实是在调用aqs的方法
public void countDown() {
sync.releaseShared(1);
}
老规矩,继续往下
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) { //尝试释放.返回true之后唤醒所有等待对列中的线程
doReleaseShared();
return true;
}
return false;
}
那我们看下tryReleaseShared(arg)的源码
protected boolean tryReleaseShared(int arg) {
throw new UnsupportedOperationException();
}
这里便是用户需要自定义实现的方法,我们可以看下这个方法在CountDownLatch中的实现
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) { //无限循环
int c = getState();
if (c == 0) // 如果当前的state = 0,那么本次操作就为false
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc)) //这里就是调用AQS的cas方法,更新state值
return nextc == 0; //成功之后,如果state == 0 返回true去唤醒队列中的线程
}
}
我们可以扩展看一下CAS的实现方式
protected final boolean compareAndSetState(int expect, int update) {
// See below for intrinsics setup to support this
return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
}
stateOffset = unsafe.objectFieldOffset
(AbstractQueuedSynchronizer.class.getDeclaredField("state"));
目标地址值,期望值,更新值来实现一个无锁的线程安全操作,但是这个好像解决不了ABA的问题,这里先不做扩展了
可以看到这里线程会无限循环去尝试更新state值,除非state==0
总结一下:
CountDownLatch的大致过程就是:
- 构造CountDownLatch的时候,初始化AQS的state,
- 然后线程await()的时候会去查看state是否等于0,如果不为0,将当前线程加入到AQS的等待对列中;
- 线程执行countDown()的时候,采用CAS的方法去更新state值,一旦state = 0,唤醒队列中的所有线程