CountDownLatch
java.util.concurrent.CountDownLatch
发令枪,允许一个或多个线程等待其他线程完成操作
主线程需要等待所有的子线程执行完后进行汇总,join
方法可以实现这一点,但是不够灵活。
使用
public class CountDownLatchTest {
// 计数器
private static CountDownLatch countDownLatch = new CountDownLatch(2);
public static void main(String[] args) throws InterruptedException {
ExecutorService executorService = Executors.newFixedThreadPool(2);
executorService.submit(()->{
try {
TimeUnit.SECONDS.sleep(3);
} catch (InterruptedException e) {
e.printStackTrace();
}finally {
// 计数器减一
countDownLatch.countDown();
}
System.out.println("Thread A Over");
});
executorService.submit(()->{
try {
TimeUnit.SECONDS.sleep(2);
} catch (InterruptedException e) {
e.printStackTrace();
}finally {
countDownLatch.countDown();
}
System.out.println("Thread B Over");
});
// 等待子线池执行完
countDownLatch.await();
System.out.println("All Child thread Over!");
executorService.shutdown();
}
}
源码探究
public class CountDownLatch {
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
// 循环进行CAS,直到当前线程成功完成CAS使计数器值(state)减1,并更新到state
for (;;) {
int c = getState();
// 如果当前状态值为0则直接返回
if (c == 0)
return false;
// 使用CAS让计数器值减一
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
// 内部使用AQS实现
private final Sync sync;
// 构造函数 传入计数器
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
// 把计数器器的值赋予AQS的状态变量state
this.sync = new Sync(count);
}
// 线程调用await方法,当前线程会被阻塞,直到下面情况发生才返回
// 1. 所有线程调用countDown方法,计数器的值为0
// 2. 其他线程调用用当前线程的interrupte()方法中断了当前线程,当前线程抛出InterruptedException返回
public void await() throws InterruptedException {
// 委托sysnc调用AOS的acquireSharedInterruptibly方法
sync.acquireSharedInterruptibly(1);
}
// 带有超时时间阻塞等待方法
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
// 计算器的值递减,递减为0则唤醒所有调用await方法而被阻塞的线程
public void countDown() {
sync.releaseShared(1);
}
// 返回当前计算器值
public long getCount() {
return sync.getCount();
}
}
总结: 相比使用join方法实现线程间同步,CountDownLatch更具有灵活性和方便性,CountDownLatch使用AQS实现的,使用AQS的状态变量来存放计数器的值,构造函数初始化状态值(计数器值),当线程调用await()方法后当前线程会被放入AQS阻塞队列等待计数器为0在返回,多个线程调用countDown()时原子(cas)递减AQS的状态值,计数器值减1,当计数器值变为0是,当前线程调用AQS的doReleaseShared方法激活由调用await()而被阻塞的线程
CyclicBarrier
java.util.concurrent.CyclicBarrier
可循环(Cyclic)的屏障(Barrier)。回环屏障,让一组线程到达屏障(屏障点/同步点)时,屏障才会打开,所有被屏障拦截的线程才会继续工作
使用
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class CyclicBarrierTest {
private static CyclicBarrier cyclicBarrier = new CyclicBarrier(2,()->{
System.out.println(Thread.currentThread()+"=====达到屏障点执行=");
// todo something
});
public static void main(String[] args) {
ExecutorService executorService = Executors.newFixedThreadPool(2);
// 线程A
executorService.submit(()->{
try {
System.out.println(Thread.currentThread()+"step1");
cyclicBarrier.await();
System.out.println(Thread.currentThread()+"step2");
cyclicBarrier.await();
System.out.println(Thread.currentThread()+"step3");
} catch (Exception e) {
e.printStackTrace();
}
});
// 线程B
executorService.submit(()->{
try {
System.out.println(Thread.currentThread()+"step1");
cyclicBarrier.await();
System.out.println(Thread.currentThread()+"step2");
cyclicBarrier.await();
System.out.println(Thread.currentThread()+"step3");
} catch (Exception e) {
e.printStackTrace();
}
});
executorService.shutdown();
}
}
注意:
- 对于指定计数器值parties,若由于某种原因,没有足够的线程调用CyclicBarrier的await,则所有调用await的线程都会被阻塞
- CyclicBarrier也可以调用await(timeout,unit)设置超时时间,在设定时间内,如果没有足够线程到达,则解除阻塞状态,继续工作
- 通过reset重置计数,会使得进入await的线程出现
BrokenBarrierException
4)如果采用是CyclicBarrier(int parties,Runnable barrierAction)
构造方法,执行barrierAction
操作的是最后一个到达的线程
源码探究
CyclicBarrier
基于ReetrantLock
和Condition
实现,CyclicBarrier 可以有不止一个栅栏,因为它的栅栏(Barrier)可以重复使用(Cyclic)
public class CyclicBarrier {
private static class Generation {
//记录当前屏障是否被打破 由于在锁内使用,所以不需要申明volatile
boolean broken = false;
}
/** The lock for guarding barrier entry */
private final ReentrantLock lock = new ReentrantLock();
/** Condition to wait on until tripped */
// lock的条件变量支持线程间使用await和sigal操作进行同步
private final Condition trip = lock.newCondition();
/** The number of parties
* 记录线程的个数,表示多少个线程调用await后,所有线程才会冲破屏障继续往下运行
* */
private final int parties;
/* The command to run when tripped */
// 当所有的线程到达了屏障点,最后一个线程执行
private final Runnable barrierCommand;
/** The current generation */
private Generation generation = new Generation();
/** count一开始等于parties,每当线程调用await方法就递减1,当count为0就表示所有线程到了屏障点*/
private int count;
/**
* Updates state on barrier trip and wakes up everyone.
* Called only while holding lock.
*/
private void nextGeneration() {
// signal completion of last generation 唤醒条件队列里面阻塞线程
trip.signalAll();
// set up next generation 重置CyclicBarrier
count = parties;
generation = new Generation();
}
private void breakBarrier() {
generation.broken = true;
count = parties;
trip.signalAll();
}
// dowait实现了CyclicBarrier的核心功能
// timed 是否设置超时时间
private int dowait(boolean timed, long nanos)
throws InterruptedException, BrokenBarrierException,
TimeoutException {
final ReentrantLock lock = this.lock;
lock.lock();
try {
final Generation g = generation;
if (g.broken)
throw new BrokenBarrierException();
if (Thread.interrupted()) {
breakBarrier();
throw new InterruptedException();
}
int index = --count;
// (1) index==0则说明所有线程都到了屏障点,此时执行初始化时传递的任务时传递任务
if (index == 0) { // tripped
boolean ranAction = false;
try {
final Runnable command = barrierCommand;
// (2) 执行任务
if (command != null)
command.run();
ranAction = true;
// (3) 激活其他因调用await方法而被阻塞的线程,并重置CyclicBarrier
nextGeneration();
return 0;
} finally {
if (!ranAction)
breakBarrier();
}
}
// loop until tripped, broken, interrupted, or timed out
// (4) 如果index != 0
for (;;) {
try {
// (5) 没有设置超时时间
if (!timed)
trip.await();
// (6) 设置了超时时间
else if (nanos > 0L)
nanos = trip.awaitNanos(nanos);
} catch (InterruptedException ie) {
if (g == generation && ! g.broken) {
breakBarrier();
throw ie;
} else {
// We're about to finish waiting even if we had not
// been interrupted, so this interrupt is deemed to
// "belong" to subsequent execution.
Thread.currentThread().interrupt();
}
}
if (g.broken)
throw new BrokenBarrierException();
if (g != generation)
return index;
if (timed && nanos <= 0L) {
breakBarrier();
throw new TimeoutException();
}
}
} finally {
lock.unlock();
}
}
public CyclicBarrier(int parties, Runnable barrierAction) {
if (parties <= 0) throw new IllegalArgumentException();
this.parties = parties;
this.count = parties;
this.barrierCommand = barrierAction;
}
public CyclicBarrier(int parties) {
this(parties, null);
}
public int getParties() {
return parties;
}
// 阻塞方法
public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L);
} catch (TimeoutException toe) {
throw new Error(toe); // cannot happen
}
}
// 待超时时间的阻塞方法
// parties个线程都调用了await()方法,也就是线程都到了屏障点,这时候返回true;
// 设置的超时时间到了后返回false;
// 其他线程调用当前线程的interrupt()方法中断了当前线程,则当前线程会抛出InterruptedException异常然后返回;
// 与当前屏障点关联的Generation对象的broken标志被设置为true时,会抛出BrokenBarrierException异常,然后返回
public int await(long timeout, TimeUnit unit)
throws InterruptedException,
BrokenBarrierException,
TimeoutException {
return dowait(true, unit.toNanos(timeout));
}
public boolean isBroken() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
return generation.broken;
} finally {
lock.unlock();
}
}
// 重置Barrier
public void reset() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
breakBarrier(); // break the current generation
nextGeneration(); // start a new generation
} finally {
lock.unlock();
}
}
public int getNumberWaiting() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
return parties - count;
} finally {
lock.unlock();
}
}
}
Semaphore
Semaphore(信号量)是用来控制同时访问特定资源的线程数量,它通过协调各个线程,以保证合理的使用公共资源
使用
场景: Semaphore可以用于做流量控制,特别是公用资源有限的应用场景,比如数据库连接。
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
public class SemaphoreTest {
static class Car extends Thread {
private int num;
private Semaphore semaphore;
public Car(int num, Semaphore semaphore) {
this.num = num;
this.semaphore = semaphore;
}
@Override
public void run() {
try {
// 获得一个令牌,如果拿不到令牌,就会阻塞
semaphore.acquire();
System.out.println("第"+num+" 抢占一个车位");
TimeUnit.SECONDS.sleep(2);
System.out.println("第"+num+" 开走");
semaphore.release();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
public static void main(String[] args) {
Semaphore semaphore = new Semaphore(5);
// ExecutorService executorService = Executors.newFixedThreadPool(5);
for (int i = 0; i < 10; i++) {
new Car(i,semaphore).start();
// executorService.submit(new Car(i,semaphore));
}
// executorService.shutdown();
}
}
源码探究
public class Semaphore implements java.io.Serializable {
private static final long serialVersionUID = -3222578661600680210L;
private final Sync sync;
abstract static class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 1192457210091910933L;
Sync(int permits) {
setState(permits);
}
final int getPermits() {
return getState();
}
final int nonfairTryAcquireShared(int acquires) {
for (;;) {
// 获取当前信号量值
int available = getState();
// 计算当前剩余值
int remaining = available - acquires;
// 如果剩余值小于0说明当前信号量个数不满足需求
// 大于0且CAS操作成功,返回剩余值
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
protected final boolean tryReleaseShared(int releases) {
for (;;) {
int current = getState();
int next = current + releases;
if (next < current) // overflow
throw new Error("Maximum permit count exceeded");
if (compareAndSetState(current, next))
return true;
}
}
final void reducePermits(int reductions) {
for (;;) {
int current = getState();
// 不允许缩减值为负数
int next = current - reductions;
if (next > current) // underflow
throw new Error("Permit count underflow");
if (compareAndSetState(current, next)) // CAS 设置缩减后的许可证数量
return;
}
}
final int drainPermits() {
for (;;) {
int current = getState();
// 直接把剩余的许可证数量设置为0
if (current == 0 || compareAndSetState(current, 0))
return current;
}
}
}
static final class NonfairSync extends Sync {
private static final long serialVersionUID = -2694183684443567898L;
NonfairSync(int permits) {
super(permits);
}
protected int tryAcquireShared(int acquires) {
return nonfairTryAcquireShared(acquires);
}
}
// 公平策略
static final class FairSync extends Sync {
private static final long serialVersionUID = 2014338818796000944L;
FairSync(int permits) {
super(permits);
}
protected int tryAcquireShared(int acquires) {
for (;;) {
// 如果当前线程不位于对头,则阻塞
// hasQueuedPredecessors 来保证公平性
if (hasQueuedPredecessors())
return -1;
int available = getState();
int remaining = available - acquires;
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
}
public Semaphore(int permits) {
sync = new NonfairSync(permits);
}
public Semaphore(int permits, boolean fair) {
sync = fair ? new FairSync(permits) : new NonfairSync(permits);
}
// 获取一个信号量值,未获取到会阻塞
public void acquire() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
public void acquireUninterruptibly() {
sync.acquireShared(1);
}
public boolean tryAcquire() {
return sync.nonfairTryAcquireShared(1) >= 0;
}
public boolean tryAcquire(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
public void release() {
sync.releaseShared(1);
}
public void acquire(int permits) throws InterruptedException {
if (permits < 0) throw new IllegalArgumentException();
sync.acquireSharedInterruptibly(permits);
}
public void acquireUninterruptibly(int permits) {
if (permits < 0) throw new IllegalArgumentException();
sync.acquireShared(permits);
}
public boolean tryAcquire(int permits) {
if (permits < 0) throw new IllegalArgumentException();
return sync.nonfairTryAcquireShared(permits) >= 0;
}
public boolean tryAcquire(int permits, long timeout, TimeUnit unit)
throws InterruptedException {
if (permits < 0) throw new IllegalArgumentException();
return sync.tryAcquireSharedNanos(permits, unit.toNanos(timeout));
}
// 释放信号量
public void release(int permits) {
if (permits < 0) throw new IllegalArgumentException();
sync.releaseShared(permits);
}
// 返回信号量中对当前可用的许可证数
public int availablePermits() {
return sync.getPermits();
}
// 获取立即可用的所有许可证个数,并将可用许可证置0
public int drainPermits() {
return sync.drainPermits();
}
protected void reducePermits(int reduction) {
if (reduction < 0) throw new IllegalArgumentException();
sync.reducePermits(reduction);
}
// 是否是公平策略
public boolean isFair() {
return sync instanceof FairSync;
}
// 是否有线程正在等待获取许可证
public final boolean hasQueuedThreads() {
return sync.hasQueuedThreads();
}
// 返回正则等待许可证的线程数
public final int getQueueLength() {
return sync.getQueueLength();
}
// 返回所有等待许可证的线程集合
protected Collection<Thread> getQueuedThreads() {
return sync.getQueuedThreads();
}
public String toString() {
return super.toString() + "[Permits = " + sync.getPermits() + "]";
}
}
Exchanger
Exchanger(交换者)是一个用于线程间协作的工具类。Exchanger用于进行线程间的数据交换。它提供一个同步点,在这个同步点,两个线程可以交换彼此的数据。这两个线程通过exchange方法交换数据,如果第一个线程先执行exchange()方法,它会一直等待第二个线程也执行exchange方法,当两个线程都到达同步点时,这两个线程就可以交换数据,将本线程生产出来的数据传递给对方
import java.util.concurrent.Exchanger;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class ExchangerTest {
private static final Exchanger<String> exgr = new Exchanger<>();
public static void main(String[] args) {
ExecutorService service = Executors.newFixedThreadPool(2);
service.submit(()->{
try {
String A = "银行流水A";
System.out.println(Thread.currentThread()+"交换前:"+A);
// 同步点:等待另一个线程到达此交换点(除非当前线程被中断),然后将给定的对象传送给该线程,并接收该线程的对象
String data = exgr.exchange(A);
System.out.println(Thread.currentThread()+"交换后:"+data);
} catch (InterruptedException e) {
e.printStackTrace();
}
});
service.submit(()->{
try {
String B = "银行流水B";
System.out.println(Thread.currentThread()+"交换前:"+B);
String data = exgr.exchange(B);
System.out.println(Thread.currentThread()+"交换前:"+data);
} catch (InterruptedException e) {
e.printStackTrace();
}
});
service.shutdown();
}
}