CyclicBarrier和CountDownLatch
CyclicBarrier
Cyclic是循环的意思,Barrier是屏障的意思,所以从其中文意思上可以理解为一个循环屏障,也就是说当满足一个条件之前,都是等待状态,只有当所有线程都满足了,在进行下一步操作。
举个例子就是:我们去游乐场坐过山车,当座位都满了的时候,工作人员才会发车,要不然所有人都在等待发车(不考虑游客少的时候,人不满也会发车)。
使用方法
首先我们先来看看它的使用方法:
package com.example.demo.basis.thread.lock;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
/**
* @author : sun
* create at: 2020/5/14 21:03
* @description: CyclicBarrier
* 使用方法:CyclicBarrier有两个构造,两个构造参数是:前面的参数表示线程目标数,当await的线程到达这个数的时候,执行后面Runnable
* 里面的run方法,同时执行每个线程接下来的代码;一个参数就是:指定目标线程数,达到这个数后没有任何操作.
* 注意事项:await的时当前线程;await会造成阻塞
*/
public class TestCyclicBarrier {
static CyclicBarrier cyclicBarrier = new CyclicBarrier(10,
new Runnable() {
@Override
public void run() {
System.out.println("有10个人了,发车");
}
});
public static void main(String[] args) throws BrokenBarrierException, InterruptedException {
for (int i = 0; i < 100; i++) {
int a = i;
new Thread(()->{
try {
cyclicBarrier.await();
System.out.println("执行到 : " + a);
} catch (InterruptedException e) {
e.printStackTrace();
} catch (BrokenBarrierException e) {
e.printStackTrace();
}
}).start();
}
/**
* 下面的代码是不起作用的,因为它只有主线程await了,所以使用的时候需要注意
*/
// for (int i = 0; i < 100; i++) {
// cyclicBarrier.await();
// }
}
}
执行结果我这边就不放了,有兴趣的可以copy出来执行一下。
await()方法会阻塞当前线程;
当等待的线程数满足阈值后就会执行相应的方法,且继续执行被阻塞线程里的代码;
这是一个循环,如果将for里面的100替换成101,你会发现程序会一直处于运行状态;
源码分析
package java.util.concurrent;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
/**
* 首先从源码可以看出,CyclicBarrier不是基于AQS
**/
public class CyclicBarrier {
/**
* 年代类——对循环进行分代,同一代共用一个标识位;在CyclicBarrier中,同一批线程属于同一代
*/
private static class Generation {
boolean broken = false;
}
//锁
private final ReentrantLock lock = new ReentrantLock();
/** Condition to wait on until tripped */
private final Condition trip = lock.newCondition();
/** The number of parties */
private final int parties;
/* The command to run when tripped */
private final Runnable barrierCommand;
/** The current generation */
private Generation generation = new Generation();
/**
* 计算线程数
*/
private int count;
/**
* 重置,开启下一代
*/
private void nextGeneration() {
// signal completion of last generation
trip.signalAll();
// set up next generation
count = parties;
generation = new Generation();
}
/**
* 如果有线程中断,则将这一代标识修改为true,同时唤醒condition中的线程
*/
private void breakBarrier() {
generation.broken = true;
count = parties;
trip.signalAll();
}
/**
* 这个是它最主要的方法
*/
private int dowait(boolean timed, long nanos)
throws InterruptedException, BrokenBarrierException,
TimeoutException {
final ReentrantLock lock = this.lock;
//锁上
lock.lock();
try {
//获得年代
final Generation g = generation;
//判断当前generation的状态,如果这个年代的中有些线程已经将CyclicBarrie“打断”,则抛出异常
if (g.broken)
throw new BrokenBarrierException();
//这里就是校验线程状态,如果线程状态处于中断,则调用breakBarrier方法,将CyclicBarrie的年代标识设置为true
//对应上面的校验
if (Thread.interrupted()) {
breakBarrier();
throw new InterruptedException();
}
//count--
int index = --count;
//当count=0的时候,说明可以打开阻塞触发barrierCommand任务了
if (index == 0) { // tripped
boolean ranAction = false;
try {
final Runnable command = barrierCommand;
if (command != null)
command.run();
ranAction = true;
//这个方法里会唤醒trip中所有等待的线程,同时重置generation——下一代
nextGeneration();
return 0;
} finally {
if (!ranAction)
breakBarrier();
}
}
// 如果count不为0,自旋
for (;;) {
try {
//这里涉及Condition知识,在我其他文章中有介绍
//如果没有设置等待时间,则调用Condition.await()方法
if (!timed)
trip.await();
else if (nanos > 0L)//如果设置了超时时间,调用Condition.awaitNanos()方法
nanos = trip.awaitNanos(nanos);
} catch (InterruptedException ie) {
if (g == generation && ! g.broken) {
breakBarrier();
throw ie;
} else {
Thread.currentThread().interrupt();
}
}
if (g.broken)
throw new BrokenBarrierException();
//generation已经更新,返回index
if (g != generation)
return index;
if (timed && nanos <= 0L) {
breakBarrier();
throw new TimeoutException();
}
}
} finally {
lock.unlock();//释放锁
}
}
/**
* 构造方法,初始化参数
*/
public CyclicBarrier(int parties, Runnable barrierAction) {
if (parties <= 0) throw new IllegalArgumentException();
this.parties = parties;
this.count = parties;
this.barrierCommand = barrierAction;
}
/**
* 构造方法,调用上面的构造
*/
public CyclicBarrier(int parties) {
this(parties, null);
}
/**
* 获得设置的阈值——栅栏数
*/
public int getParties() {
return parties;
}
/**
* 给外部调用的await方法,没有超时时间
*/
public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L);
} catch (TimeoutException toe) {
throw new Error(toe); // cannot happen
}
}
/**
* 给外部调用的await方法,有超时时间
*/
public int await(long timeout, TimeUnit unit)
throws InterruptedException,
BrokenBarrierException,
TimeoutException {
return dowait(true, unit.toNanos(timeout));
}
/**
* 查看当前代是否被中断
*/
public boolean isBroken() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
return generation.broken;
} finally {
lock.unlock();
}
}
/**
* 重置barrier到初始化状态
*/
public void reset() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
breakBarrier(); // break the current generation
nextGeneration(); // start a new generation
} finally {
lock.unlock();
}
}
/**
* 获得等待的线程数
*/
public int getNumberWaiting() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
return parties - count;
} finally {
lock.unlock();
}
}
}
CountDownLatch
CountDownLatch是JUC包里提供的一个并发编程的工具类,它的实现也是基于AQS,其内部通过state来计数,当这个数变成0(查看它的构造函数,初始值不可以小于0,否则抛出异常)等待的线程就可以继续执行。
使用方法
首先我们先来看看它的使用方法:
package com.example.demo.basis.thread.lock;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
/**
* @author : sun
* create at: 2020/5/14 19:32
* @description: CountDownLatch
* 使用方法:首先初始化是指定大小,可以理解为指定座位数,一个线程执行完业务步骤后调用countDown方法,当前线程就类似于做到座位上了,
* 在主线程里调用await方法,意思就是等座位上的所有线程执行完毕后,执行下面的代码,这里会造成阻塞。
* 在使用的时候需要注意:在指定CountDownLatch初始参数的时候一定要预判好有几个位置,因为只有位置坐满了,并且每个位置的线程都执行
* 结束,才会执行await后面的代码。
*/
public class TestCountDownLatch {
static CountDownLatch countDownLatch = new CountDownLatch(3);
public static void main(String[] args) {
long startTime = System.currentTimeMillis();
new Thread(()->{
try {
TimeUnit.SECONDS.sleep(5);
countDownLatch.countDown();
System.out.println("1");
} catch (InterruptedException e) {
e.printStackTrace();
}
},"线程1").start();
new Thread(()->{
try {
TimeUnit.SECONDS.sleep(8);
countDownLatch.countDown();
System.out.println("2");
} catch (InterruptedException e) {
e.printStackTrace();
}
},"线程2").start();
new Thread(()->{
try {
TimeUnit.SECONDS.sleep(10);
countDownLatch.countDown();
System.out.println("3");
} catch (InterruptedException e) {
e.printStackTrace();
}
},"线程3").start();
new Thread(()->{
try {
TimeUnit.SECONDS.sleep(20);
countDownLatch.countDown();
System.out.println("4");
} catch (InterruptedException e) {
e.printStackTrace();
}
},"线程4").start();
new Thread(()->{
try {
//这里试验了CountDownLatch是可以唤醒多个等待的线程的
countDownLatch.await();
System.out.println(Thread.currentThread().getName() + " 执行到了 !");
} catch (InterruptedException e) {
e.printStackTrace();
}
},"线程5").start();
try {
countDownLatch.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
long endTime = System.currentTimeMillis();
System.out.println("一共耗时 : " + (endTime-startTime));
}
}
结果:
1
2
3
一共耗时 : 10029
线程5 执行到了 !
4
我们从结果来分析:
CountDownLatch.countDown()方法不会阻塞当前线程,只会阻塞调用await()方法的地方。
CountDownLatch是不可重入的,只会生效一次
CountDownLatch会唤醒多个await()方法的地方
源码分析
package java.util.concurrent;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
public class CountDownLatch {
/**
* 首先这个内部类实现了AbstractQueuedSynchronizer,说明CountDownLatch也是基于AQS实现的
* Synchronization control For CountDownLatch.
* Uses AQS state to represent count.
*/
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
//构造方法,调用的setState方法是父类的final类型的方法,主要是设置state值
Sync(int count) {
setState(count);
}
//获得state值,getState也是父类final类型的方法
int getCount() {
return getState();
}
//wait方法最终会调用到这个方法,来获得state;如果为0返回1,其他返回-1。
//注意:入参没有用到
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
//核心方法——自旋,验证state
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
//CountDownLatch的属性
private final Sync sync;
/**
* 这个CountDownLatch的唯一一个构造函数,参数为int类型
* 必须指定一个初始count,也就是目标计数,当达到这个目标值后,才会放开阻塞
*/
public CountDownLatch(int count) {
//如果count小于0,抛出异常
if (count < 0) throw new IllegalArgumentException("count < 0");
//初始化Sync,跟踪代码这里最终设置了父类的state这个值,state是volatile修饰的,
this.sync = new Sync(count);
}
/**
* 首先这个方法会调用AbstractQueuedSynchronizer的acquireSharedInterruptibly,入参为1,仔细阅读源码发现,这个1其实没有用到
* 下面是acquireSharedInterruptibly 方法:
* public final void acquireSharedInterruptibly(int arg)
* throws InterruptedException {
* 判断当前线程状态,这个线程是调用await方法的线程!
* if (Thread.interrupted())
* throw new InterruptedException();
* 这里调用的是tryAcquireShared方法,回到子类sync的tryAcquireShared方法(参考上面Sync内的tryAcquireShared)
* tryAcquireShared(arg),其实就是判断state这个值,如果为0返回1,其他返回-1。
* 返回1,简单来说就是不需要阻塞了,直接执行await后面的代码;
* 返回-1,则回去调用doAcquireSharedInterruptibly方法,这个方法简单的来说就是把当前线程封装成一个node,放到一个线程节点队列里,然后自旋
* 来调用tryAcquireShared方法,这里是从头部节点一个一个来判断的,当头节点通过,才会轮到下一个节点,直到所有线程节点都执行完成
* if (tryAcquireShared(arg) < 0)
* doAcquireSharedInterruptibly(arg);
* }
*/
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
/**
* 这个方法和await方法类似,只是它加了一个超时时间限制,第一个参数是时间,第二个参数是时间单位
* 如果超过这个时间就不会继续等待了,其实源码和上述的doAcquireSharedInterruptibly中,自旋里面多了个是时间判断
* 超过这个时间就会跳出自旋,不会让当前线程继续等待了。
*/
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
/**
* 调用AbstractQueuedSynchronizer的releaseShared方法:
* public final boolean releaseShared(int arg) {
* if (tryReleaseShared(arg)) {
//对于CountDownLatch只有当state最终为0的时候才会执行doReleaseShared方法;
//只用当state为0,才需要去唤醒等待的线程,所以doReleaseShared概述的来讲就是去唤醒等待的线程队列中的线程
* doReleaseShared();
* return true;
* }
* return false;
* }
* 由上述方法可以看到会调用子类的tryReleaseShared方法,Sync的tryReleaseShared参考上面说明.
*
* 需要注意的是countDown方法没有加锁!
*/
public void countDown() {
sync.releaseShared(1);
}
/**
* 获得state值.
*/
public long getCount() {
return sync.getCount();
}
/**
* toString方法
*/
public String toString() {
return super.toString() + "[Count = " + sync.getCount() + "]";
}
}
对于CountDownLatch的使用场景我这里有个例子:我们项目中对于大文件的下载就会通过CountDownLatch来操作。