使用示例
/**
 * Demo entry point: starts {@code parties} worker threads that all meet at a single
 * {@link CyclicBarrier}, then waits for every worker to finish before printing "shutdown".
 *
 * <p>Joining the workers matters: without it, "shutdown" may be printed while the workers
 * are still blocked on the barrier, which misrepresents when the demo actually ends.
 */
public static void main(String[] args) {
    // One constant drives both the barrier size and the number of workers, so they cannot
    // drift apart (a mismatch would leave workers blocked on the barrier forever).
    final int parties = 3;
    CyclicBarrier cyclicBarrier = new CyclicBarrier(parties);
    System.out.println("hahaha");

    Worker[] workers = new Worker[parties];
    for (int i = 0; i < parties; i++) {
        workers[i] = new Worker(cyclicBarrier);
        workers[i].start();
    }

    try {
        for (Worker worker : workers) {
            worker.join();
        }
    } catch (InterruptedException e) {
        // Stop waiting, but restore the interrupt flag so it is not lost.
        Thread.currentThread().interrupt();
    }
    System.out.println("shutdown");
}
/**
 * Demo thread: announces that it started, waits at the shared barrier, and announces
 * completion once the barrier has released it.
 */
static class Worker extends Thread {
    /** Barrier shared by all workers of the demo. */
    private final CyclicBarrier cyclicBarrier;

    /**
     * @param cyclicBarrier the barrier this worker waits on
     */
    public Worker(CyclicBarrier cyclicBarrier) {
        this.cyclicBarrier = cyclicBarrier;
    }

    @Override
    public void run() {
        try {
            System.out.println("线程 " + Thread.currentThread().getName() + "开始执行");
            // Blocks until every party has called await() on the barrier.
            cyclicBarrier.await();
            System.out.println("线程 " + Thread.currentThread().getName() + "执行结束");
        } catch (InterruptedException e) {
            // Restore the interrupt status instead of silently clearing it.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (BrokenBarrierException e) {
            // The barrier was broken, e.g. another party was interrupted while waiting.
            e.printStackTrace();
        }
    }
}
CyclicBarrier 会让调用 await() 的线程都阻塞在 await 处:每有一个线程到达,计数减 1;当计数减为 0 时,所有等待的线程被一次性释放。随后计数会重置为初始值,屏障可以循环使用(这也是 "Cyclic" 的含义)。
源码分析(以下为 JDK 源码节选,省略了部分字段和方法)
/**
 * Excerpt of {@code java.util.concurrent.CyclicBarrier}: a reusable barrier that blocks
 * callers of {@link #await()} until {@code parties} threads have arrived, then releases
 * them all and resets itself for the next round ("generation").
 */
public class CyclicBarrier {
    /**
     * One use of the barrier. A fresh instance is installed every time the barrier trips
     * ({@link #nextGeneration()}); {@code broken} is set when the round is aborted
     * ({@link #breakBarrier()}).
     */
    private static class Generation {
        boolean broken = false;
    }

    /** Guards all mutable barrier state below. */
    private final ReentrantLock lock = new ReentrantLock();
    /** Condition that waiting parties block on until the barrier trips or breaks. */
    private final Condition trip = lock.newCondition();
    /** Number of parties required to trip the barrier. */
    private final int parties;
    /** Optional command run by the last arriving thread before the others are released. */
    private final Runnable barrierCommand;
    /** The current generation. */
    private Generation generation = new Generation();
    /**
     * Number of parties still expected in the current generation. Counts down from
     * {@code parties} to 0 and is reset to {@code parties} on every new generation or
     * when the barrier is broken.
     */
    private int count;

    /**
     * Creates a barrier that trips once {@code parties} threads are waiting, with no
     * barrier action.
     *
     * @param parties number of threads that must call {@link #await()} before release
     * @throws IllegalArgumentException if {@code parties} is less than 1
     */
    public CyclicBarrier(int parties) {
        this(parties, null);
    }

    /**
     * Creates a barrier that trips once {@code parties} threads are waiting.
     *
     * @param parties number of threads that must call {@link #await()} before release
     * @param barrierAction run by the last thread to arrive, before the waiting threads
     *     are released; may be {@code null}
     * @throws IllegalArgumentException if {@code parties} is less than 1
     */
    public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) {
            throw new IllegalArgumentException();
        }
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }

    /**
     * Waits until all parties have called {@code await} on this barrier.
     *
     * @return the arrival index of the current thread: {@code parties - 1} for the first
     *     thread to arrive, 0 for the last
     * @throws InterruptedException if the current thread was interrupted while waiting
     * @throws BrokenBarrierException if the barrier was broken before or while waiting
     */
    public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen: the wait is untimed
        }
    }

    /**
     * Core barrier logic shared by timed and untimed waits.
     *
     * @param timed whether {@code nanos} bounds the wait
     * @param nanos remaining wait time in nanoseconds, used only when {@code timed}
     * @return the arrival index of the current thread
     * @throws InterruptedException if interrupted while the current generation is intact
     * @throws BrokenBarrierException if the generation is or becomes broken
     * @throws TimeoutException if {@code timed} and the time elapsed before tripping
     */
    private int dowait(boolean timed, long nanos)
            throws InterruptedException, BrokenBarrierException, TimeoutException {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            final Generation g = generation;

            if (g.broken) {
                throw new BrokenBarrierException();
            }

            if (Thread.interrupted()) {
                // Break the barrier (mark broken, reset count, wake all waiters) so the
                // other parties do not wait forever for this thread.
                breakBarrier();
                throw new InterruptedException();
            }

            // One more party has arrived.
            int index = --count;
            if (index == 0) { // tripped: this thread is the last to arrive
                boolean ranAction = false;
                try {
                    // Run the barrier action supplied at construction, if any.
                    final Runnable command = barrierCommand;
                    if (command != null) {
                        command.run();
                    }
                    ranAction = true;
                    // Wake every waiter and start a fresh generation.
                    nextGeneration();
                    return 0;
                } finally {
                    // If the barrier action threw, break the barrier so the waiters get
                    // BrokenBarrierException instead of hanging.
                    if (!ranAction) {
                        breakBarrier();
                    }
                }
            }

            // Loop until tripped, broken, interrupted, or timed out.
            for (;;) {
                try {
                    // Release the lock and wait on the condition.
                    if (!timed) {
                        trip.await();
                    } else if (nanos > 0L) {
                        nanos = trip.awaitNanos(nanos);
                    }
                } catch (InterruptedException ie) {
                    if (g == generation && !g.broken) {
                        // Interrupted during the current, still-intact round: break it.
                        breakBarrier();
                        throw ie;
                    } else {
                        // The round already tripped or was broken; just preserve the
                        // interrupt status for the caller.
                        Thread.currentThread().interrupt();
                    }
                }

                if (g.broken) {
                    throw new BrokenBarrierException();
                }

                // A new generation means the barrier tripped while this thread waited.
                if (g != generation) {
                    return index;
                }

                if (timed && nanos <= 0L) {
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            lock.unlock();
        }
    }

    /**
     * Marks the current generation broken, resets the count, and wakes all threads
     * waiting on {@code trip}. Must be called while holding {@code lock}.
     */
    private void breakBarrier() {
        generation.broken = true;
        count = parties;
        trip.signalAll();
    }

    /**
     * Completes the current generation: wakes all threads waiting on {@code trip}, resets
     * the count, and installs a new (unbroken) generation. Must be called while holding
     * {@code lock}.
     */
    private void nextGeneration() {
        // signal completion of last generation
        trip.signalAll();
        // set up next generation
        count = parties;
        generation = new Generation();
    }
}
/**
 * Excerpt of java.util.concurrent.locks.AbstractQueuedSynchronizer (AQS). Only some
 * members are shown; names such as Node, head, tail, firstWaiter, lastWaiter,
 * compareAndSetWaitStatus, compareAndSetHead, compareAndSetTail, tryAcquire, setHead,
 * shouldParkAfterFailedAcquire, parkAndCheckInterrupt, cancelAcquire and
 * checkInterruptWhileWaiting are declared elsewhere and not visible here.
 */
public abstract class AbstractQueuedSynchronizer
extends AbstractOwnableSynchronizer
implements java.io.Serializable {
/**
 * Condition implementation. Waiting threads are kept on a singly linked "condition
 * queue" (firstWaiter .. lastWaiter, linked through nextWaiter). Signalling moves nodes
 * onto the AQS sync queue, where they compete to reacquire the lock.
 */
public class ConditionObject implements Condition, java.io.Serializable {
// interruptMode: re-assert the interrupt flag when await() returns
private static final int REINTERRUPT = 1;
// interruptMode: throw InterruptedException when await() returns
private static final int THROW_IE = -1;
/**
 * Moves every thread waiting on this condition to the sync queue.
 * Throws IllegalMonitorStateException if isHeldExclusively() is false, i.e. the caller
 * does not hold the lock.
 */
public final void signalAll() {
if (!isHeldExclusively())
throw new IllegalMonitorStateException();
Node first = firstWaiter;
if (first != null)
doSignalAll(first);
}
// Detaches the whole condition queue, then transfers its nodes, in order, to the sync queue.
private void doSignalAll(Node first) {
lastWaiter = firstWaiter = null;
do {
Node next = first.nextWaiter;
first.nextWaiter = null;
// Move the node from the condition queue onto the AQS sync queue; transferForSignal
// resets its status from CONDITION to 0 and tries to mark its new predecessor SIGNAL.
transferForSignal(first);
first = next;
} while (first != null);
}
/**
 * Releases the lock, blocks until the node has been moved to the sync queue (by a
 * signal) or an interrupt is detected, then reacquires the lock with the saved state.
 * Throws InterruptedException if the thread is interrupted on entry, or afterwards when
 * interruptMode ends up as THROW_IE.
 */
public final void await() throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
// Append a CONDITION node for this thread to the condition queue.
Node node = addConditionWaiter();
// Release all holds of the lock, saving the state so it can be restored later; this
// lets the next thread in the sync queue acquire the lock.
int savedState = fullyRelease(node);
int interruptMode = 0;
// Park while the node is not yet on the sync queue; re-checked after every wakeup.
// checkInterruptWhileWaiting (not shown) decides the interrupt mode.
while (!isOnSyncQueue(node)) {
LockSupport.park(this);
if ((interruptMode = checkInterruptWhileWaiting(node)) != 0)
break;
}
// Reacquire the lock with the saved state (parking in the sync queue if needed). An
// interrupt seen there becomes REINTERRUPT unless we are already going to throw.
if (acquireQueued(node, savedState) && interruptMode != THROW_IE)
interruptMode = REINTERRUPT;
if (node.nextWaiter != null) // clean up if cancelled: purge non-CONDITION nodes
unlinkCancelledWaiters();
if (interruptMode != 0)
reportInterruptAfterWait(interruptMode);
}
/**
 * Appends a new CONDITION node for the current thread to the condition queue.
 * Returns the new node.
 */
private Node addConditionWaiter() {
Node t = lastWaiter;
// If the tail is no longer in CONDITION state (e.g. cancelled), purge such nodes first.
if (t != null && t.waitStatus != Node.CONDITION) {
unlinkCancelledWaiters();
t = lastWaiter;
}
// Wrap the current thread in a node whose waitStatus is CONDITION.
Node node = new Node(Thread.currentThread(), Node.CONDITION);
// Link the node at the end of the condition queue.
if (t == null)
firstWaiter = node;
else
t.nextWaiter = node;
lastWaiter = node;
return node;
}
// Removes every node whose waitStatus is not CONDITION from the condition queue,
// keeping firstWaiter/lastWaiter consistent. trail is the last node kept so far.
private void unlinkCancelledWaiters() {
Node t = firstWaiter;
Node trail = null;
while (t != null) {
Node next = t.nextWaiter;
if (t.waitStatus != Node.CONDITION) {
t.nextWaiter = null;
if (trail == null)
firstWaiter = next;
else
trail.nextWaiter = next;
if (next == null)
lastWaiter = trail;
}
else
trail = t;
t = next;
}
}
// Returns true if node is currently on the AQS sync queue.
final boolean isOnSyncQueue(Node node) {
// Condition-queue nodes are linked only through nextWaiter and have no prev, so a
// CONDITION status or a null prev means the node is not (yet) on the sync queue.
if (node.waitStatus == Node.CONDITION || node.prev == null)
return false;
// Having a successor in the sync queue proves the node is on it.
if (node.next != null)
return true;
// enq sets prev before its tail CAS succeeds, so a non-null prev is not proof;
// confirm by scanning backwards from tail.
return findNodeFromTail(node);
}
// Walks the sync queue backwards from tail; returns true if node is found.
private boolean findNodeFromTail(Node node) {
Node t = tail;
for (;;) {
if (t == node)
return true;
if (t == null)
return false;
t = t.prev;
}
}
// THROW_IE -> throw InterruptedException; REINTERRUPT -> re-set the interrupt flag.
private void reportInterruptAfterWait(int interruptMode)
throws InterruptedException {
if (interruptMode == THROW_IE)
throw new InterruptedException();
else if (interruptMode == REINTERRUPT)
selfInterrupt();
}
}
/**
 * Moves node from a condition queue to the sync queue.
 * Returns true if transferred, false if its status was no longer CONDITION (e.g. the
 * node was cancelled).
 */
final boolean transferForSignal(Node node) {
// CAS the status from CONDITION to 0; failure means the node is no longer waiting on
// the condition (e.g. it was cancelled).
if (!compareAndSetWaitStatus(node, Node.CONDITION, 0))
return false;
// Enqueue on the sync queue; enq returns the node's predecessor.
Node p = enq(node);
int ws = p.waitStatus;
// If the predecessor is cancelled (ws > 0) or cannot be marked SIGNAL, unpark the
// node's thread right away so it does not depend on that predecessor to wake it.
if (ws > 0 || !compareAndSetWaitStatus(p, ws, Node.SIGNAL))
LockSupport.unpark(node.thread);
return true;
}
/**
 * Appends node to the AQS sync queue, first creating a thread-less head node if the
 * queue is empty. Returns the node's predecessor (the previous tail).
 */
private Node enq(final Node node) {
for (;;) {
Node t = tail;
if (t == null) { // Must initialize
if (compareAndSetHead(new Node()))
tail = head;
} else {
// prev is set before the tail CAS, so it can be visible before node is enqueued.
node.prev = t;
if (compareAndSetTail(t, node)) {
t.next = node;
return t;
}
}
}
}
/**
 * Releases the full current state (all holds) on behalf of a condition wait.
 * Returns the saved state, used later to reacquire. Throws IllegalMonitorStateException
 * if release returns false; on any failure (including an exception from tryRelease)
 * the node is marked CANCELLED.
 */
final int fullyRelease(Node node) {
boolean failed = true;
try {
int savedState = getState();
if (release(savedState)) {
failed = false;
return savedState;
} else {
throw new IllegalMonitorStateException();
}
} finally {
if (failed)
node.waitStatus = Node.CANCELLED;
}
}
/**
 * Exclusive-mode release. If tryRelease reports the synchronizer is now free, wakes the
 * successor of head when head's waitStatus is non-zero. Returns tryRelease's result.
 */
public final boolean release(int arg) {
// Released completely.
if (tryRelease(arg)) {
Node h = head;
if (h != null && h.waitStatus != 0)
unparkSuccessor(h);
return true;
}
return false;
}
/**
 * Wakes the first non-cancelled successor of node (called here with head) so it can
 * try to acquire.
 */
private void unparkSuccessor(Node node) {
// Clear a negative (e.g. SIGNAL) status on node; a failed CAS is tolerated.
int ws = node.waitStatus;
if (ws < 0)
compareAndSetWaitStatus(node, ws, 0);
// Normally the successor is node.next; if that is missing or cancelled (waitStatus > 0),
// scan backwards from tail for the frontmost node with waitStatus <= 0.
Node s = node.next;
if (s == null || s.waitStatus > 0) {
s = null;
for (Node t = tail; t != null && t != node; t = t.prev)
if (t.waitStatus <= 0)
s = t;
}
if (s != null)
LockSupport.unpark(s.thread);
}
/**
 * Repeatedly tries to acquire for a node that is already on the sync queue, parking
 * between attempts. Returns true if an interrupt was observed while parked.
 */
final boolean acquireQueued(final Node node, int arg) {
boolean failed = true;
try {
boolean interrupted = false;
for (;;) {
// Only the node directly behind head may attempt to acquire.
final Node p = node.predecessor();
if (p == head && tryAcquire(arg)) {
// Acquired: this node becomes the new head.
setHead(node);
p.next = null; // help GC
failed = false;
return interrupted;
}
// Park when shouldParkAfterFailedAcquire allows it (not shown); a true result from
// parkAndCheckInterrupt is recorded as an interrupt.
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
interrupted = true;
}
} finally {
// Leaving without acquiring (an exception was thrown): cancel the node.
if (failed)
cancelAcquire(node);
}
}
// Re-asserts the current thread's interrupt flag.
static void selfInterrupt() {
Thread.currentThread().interrupt();
}
}
/**
 * Excerpt of {@code java.util.concurrent.locks.ReentrantLock}, showing only the release
 * path of its synchronizer.
 */
public class ReentrantLock {
    /** Synchronizer base; the AQS state holds the lock's hold count. */
    abstract static class Sync extends AbstractQueuedSynchronizer {
        /**
         * Drops {@code releases} holds of the lock.
         *
         * @param releases number of holds to release
         * @return {@code true} if no holds remain, in which case the owner is cleared
         * @throws IllegalMonitorStateException if the calling thread is not the owner
         */
        protected final boolean tryRelease(int releases) {
            final Thread caller = Thread.currentThread();
            if (getExclusiveOwnerThread() != caller) {
                throw new IllegalMonitorStateException();
            }
            final int remainingHolds = getState() - releases;
            final boolean nowFree = remainingHolds == 0;
            if (nowFree) {
                // Nobody owns the lock any more.
                setExclusiveOwnerThread(null);
            }
            setState(remainingHolds);
            return nowFree;
        }
    }
}