CountDownLatch
源码
package java.util.concurrent;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
public class CountDownLatch {
// 同步锁,继承AQS抽象同步器
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
// 使用AQS的state成员来表示count
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// count--
for (;;) {
int c = getState();
// 如果count已经为0了,则不做操作
if (c == 0)
return false;
// 基于CAS更新count
int nextc = c - 1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
private final java.util.concurrent.CountDownLatch.Sync sync;
// 构造
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new java.util.concurrent.CountDownLatch.Sync(count);
}
// 调用await()方法阻塞线程,默认没有实现限制
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
// 可以指定阻塞的时间
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
// 调用countDown方法,count--
public void countDown() {
sync.releaseShared(1);
}
// 获取count
public long getCount() {
return sync.getCount();
}
//tostring
public String toString() {
return super.toString() + "[Count = " + sync.getCount() + "]";
}
}
1. CountDownLatch可以看作时一个计数器,每执行一次countDown()方法,count减一,当count=0,唤醒所有await()阻塞的线程继续执行;
2. count的线程安全由AQS保证。
模拟场景:10个线程任务,保证前5个线程先执行完再执行后5个线程
public class TestCount {
public static void main(String[] args) {
CountDownLatch lockcount = new CountDownLatch(5); //计数执行完的前半部分线程
CountDownLatch lockend = new CountDownLatch(5); //计数执行完的后半部分线程
for(int i=0;i<10;i++){
int finalI = i;
new Thread(()->{
try {
if(finalI<5) {
System.out.println("需要先执行的线程任务"+finalI);
lockcount.countDown(); //CountDownLatch没执行一次countDown()方法,count减一
}else{
lockcount.await(); //阻塞当前线程,当count=0时被唤醒
System.out.println("后执行的线程任务"+finalI);
lockend.countDown();
}
} catch (InterruptedException e) {
e.printStackTrace();
}
}).start();
}
try {
lockend.await();
System.out.println("---------所有线程任务完成----------");
} catch (InterruptedException e) {
e.printStackTrace();
}
}
CyclicBarrier
源码
package java.util.concurrent;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
public class CyclicBarrier {
// 代,屏障前或者屏障后被定义为两代
private static class Generation {
Generation() {} // prevent access constructor creation
boolean broken; // initially false
}
//同步锁
private final ReentrantLock lock = new ReentrantLock();
//条件队列,存储屏障阻塞的线程
private final Condition trip = lock.newCondition();
//屏障个数
private final int parties;
//所有屏障打破换代时先执行的线程
private final Runnable barrierCommand;
//当前代
private Generation generation = new Generation();
//计数器,初始等于parties,调用await()方法减一,等于0时打破所有屏障,唤醒所有线程
private int count;
//进入下一代,count又重新等于parties,相当于开始新的一轮,实现CyclicBarrier的可重用
private void nextGeneration() {
// signal completion of last generation
trip.signalAll();
// set up next generation
count = parties;
generation = new Generation();
}
//打破所有屏障,初始化count并唤醒所有阻塞的线程
private void breakBarrier() {
generation.broken = true;
count = parties;
trip.signalAll();
}
// 构造器, 初始化parties,并且可以指定换代时先执行的线程barrierCommand,可以barrierCommand可以用来做通知线程
public CyclicBarrier(int parties, Runnable barrierAction) {
if (parties <= 0) throw new IllegalArgumentException();
this.parties = parties;
this.count = parties;
this.barrierCommand = barrierAction;
}
// 不指定则默认barrierCommand为空
public CyclicBarrier(int parties) {
this(parties, null);
}
// 获取parties
public int getParties() {
return parties;
}
// 调用dowait()方法
public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L);
} catch (TimeoutException toe) {
throw new Error(toe); // cannot happen
}
}
// 可以指定阻塞时间
public int await(long timeout, TimeUnit unit)
throws InterruptedException,
BrokenBarrierException,
TimeoutException {
return dowait(true, unit.toNanos(timeout));
}
//
private int dowait(boolean timed, long nanos)
throws InterruptedException, BrokenBarrierException,
TimeoutException {
// 拿到同步锁
final ReentrantLock lock = this.lock;
lock.lock();
try {
// 获取当前代
final Generation g = generation;
// 判断当前代是否可以穿越屏障,如果可以,屏障失效,抛出异常
if (g.broken)
throw new BrokenBarrierException();
// 判断当前线程是否中断,如果中断,打破屏障,唤醒所有阻塞线程,并抛出中断异常
if (Thread.interrupted()) {
breakBarrier();
throw new InterruptedException();
}
// count-1
int index = --count;
// 当count减到0,
if (index == 0) { // tripped
boolean ranAction = false;
try {
// 如果定义了换代线程,则执行,
final Runnable command = barrierCommand;
if (command != null)
command.run();
ranAction = true;
// 并唤醒所有阻塞线程,初始化count=parties,开始新的代
nextGeneration();
return 0;
} finally {
if (!ranAction)
breakBarrier();
}
}
// 当count不为0
for (;;) {
try {
// 判断是否定时阻塞
if (!timed)
trip.await();
else if (nanos > 0L)
nanos = trip.awaitNanos(nanos);
} catch (InterruptedException ie) {
if (g == generation && ! g.broken) {
breakBarrier();
throw ie;
} else {
// We're about to finish waiting even if we had not
// been interrupted, so this interrupt is deemed to
// "belong" to subsequent execution.
Thread.currentThread().interrupt();
}
}
if (g.broken)
throw new BrokenBarrierException();
if (g != generation)
return index;
if (timed && nanos <= 0L) {
breakBarrier();
throw new TimeoutException();
}
}
} finally {
lock.unlock();
}
}
// 判断当前代是否可以穿越屏障
public boolean isBroken() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
return generation.broken;
} finally {
lock.unlock();
}
}
// 重置
public void reset() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
breakBarrier(); // break the current generation
nextGeneration(); // start a new generation
} finally {
lock.unlock();
}
}
// 获取当前阻塞线程个数
public int getNumberWaiting() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
return parties - count;
} finally {
lock.unlock();
}
}
}
1. CyclicBarrier可以理解为设置parties个barrier屏障,初始化count=parties,在任务分界处调用await()方法插入屏障,同时count减一,当所有任务都执行到了分界处,count=0,此时唤醒所有阻塞的线程继续执行
2. CyclicBarrier可重用
3. 使用ReentrantLock保证count等的线程安全
4. 使用condition条件队列存储阻塞线程
模拟场景:10个线程任务,每个任务被分为前后部分,保证所有任务的前部分都执行完再开始执行所有任务的后部分,并使用一个通知线程通知前部分任务全部执行完成
public class TestCount {
public static void main(String[] args) {
CyclicBarrier lockparties = new CyclicBarrier(10, () -> {
System.out.println("所有需要先执行的部分任务全部执行了");
}); //初始count=partiers
CountDownLatch lockend2 = new CountDownLatch(10);
for(int i=0;i<10;i++){
int finalI = i;
new Thread(()->{
try {
synchronized (TestCount.class) {
System.out.println("需要先执行的前半部分任务" + finalI + " " + lockparties.getParties() + " " + lockparties.getNumberWaiting());
}
lockparties.await(); //阻塞当前线程,并且count减一
synchronized (TestCount.class) {
System.out.println("后执行的后半部分任务" + finalI + " " + lockparties.getParties() + " " + lockparties.getNumberWaiting());
}
lockend2.countDown();
} catch (BrokenBarrierException | InterruptedException e) {
e.printStackTrace();
}
}).start();
}
try {
lockend2.await();
System.out.println("---------所有线程任务完成----------");
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}