文章目录
同步锁的本质 —— 排队
同步的方式:独享锁 - 单个队列窗口,共享锁 - 多个队列窗口;
抢锁的方式:插队抢(非公平锁)、先来后到抢(公平锁);
没抢到锁的处理方式:快速重试多次(CAS自旋锁)、阻塞等待;
唤醒阻塞线程的方式(叫号器):全部通知、通知下一个;
资源占用流程
抽象队列同步器AbstractQueuedSynchronizer
手写AQS
import java.util.Iterator;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.LockSupport;
/**
 * A minimal, teaching-oriented abstract queued synchronizer.
 *
 * <p>It keeps three pieces of state: {@code owner}, {@code waiters} (threads that failed to
 * acquire and are parked, or about to park) and {@code state} (a counter whose meaning is defined
 * by the subclass). Subclasses implement the {@code try*} hooks (template-method pattern); this
 * class supplies the enqueue / park / retry loop around them.
 *
 * <p>Two modes are provided:
 * <ul>
 *   <li>exclusive: {@link #acquire()} / {@link #release()} driven by {@link #tryAcquire()} /
 *       {@link #tryRelease()};</li>
 *   <li>shared: {@link #acquireShared()} / {@link #releaseShared()} driven by
 *       {@link #tryAcquireShared()} (negative = failure) / {@link #tryReleaseShared()}.</li>
 * </ul>
 * A successful release wakes every queued waiter; each one re-runs its try hook and the losers
 * park again.
 */
public class MyAqs {
    /** Not used by this class itself; available for subclasses that track an exclusive owner. */
    public volatile AtomicReference<Thread> owner = new AtomicReference<>();
    /** Threads that failed to acquire and are (or are about to be) parked. */
    public volatile LinkedBlockingQueue<Thread> waiters = new LinkedBlockingQueue<>();
    /** Resource state; what the number means is up to the subclass. */
    public volatile AtomicInteger state = new AtomicInteger(0);

    /**
     * Acquires in exclusive mode, parking the calling thread until {@link #tryAcquire()} succeeds.
     * If the thread is interrupted while waiting, it keeps waiting and its interrupt status is
     * restored once the acquire succeeds.
     */
    public void acquire() {
        doAcquire(false);
    }

    /**
     * Attempts an exclusive acquire once, without blocking. Must be implemented by the user of
     * this class (template method).
     *
     * @return {@code true} if the resource was acquired
     */
    public boolean tryAcquire() {
        throw new UnsupportedOperationException();
    }

    /** Releases in exclusive mode; if {@link #tryRelease()} succeeds, wakes all waiters. */
    public void release() {
        if (tryRelease()) {
            signalWaiters();
        }
    }

    /**
     * Performs the actual exclusive release. Must be implemented by the user of this class.
     *
     * @return {@code true} if waiters should be woken to compete again
     */
    public boolean tryRelease() {
        throw new UnsupportedOperationException();
    }

    /**
     * Acquires in shared mode, parking the calling thread until {@link #tryAcquireShared()}
     * returns a non-negative value. Interrupts are handled as in {@link #acquire()}.
     */
    public void acquireShared() {
        doAcquire(true);
    }

    /**
     * Attempts a shared acquire once, without blocking. Must be implemented by the user of this
     * class.
     *
     * @return a negative value on failure, zero or positive on success
     */
    public int tryAcquireShared() {
        throw new UnsupportedOperationException();
    }

    /** Releases in shared mode; if {@link #tryReleaseShared()} succeeds, wakes all waiters. */
    public void releaseShared() {
        if (tryReleaseShared()) {
            signalWaiters();
        }
    }

    /**
     * Performs the actual shared release. Must be implemented by the user of this class.
     *
     * @return {@code true} if waiters should be woken to compete again
     */
    public boolean tryReleaseShared() {
        throw new UnsupportedOperationException();
    }

    public AtomicInteger getState() {
        return state;
    }

    public void setState(AtomicInteger state) {
        this.state = state;
    }

    /** Runs the subclass hook for the requested mode once. */
    private boolean attempt(boolean shared) {
        return shared ? tryAcquireShared() >= 0 : tryAcquire();
    }

    /** Common enqueue / park / retry loop shared by both acquire modes. */
    private void doAcquire(boolean shared) {
        final Thread current = Thread.currentThread();
        boolean queued = false;
        boolean interrupted = false;
        while (!attempt(shared)) {
            if (!queued) {
                // Enqueue, then retry once before parking: a release that ran between the failed
                // attempt and the offer did not see us in the queue, but the retry will see it.
                waiters.offer(current);
                queued = true;
            } else {
                // park() may also return spuriously; the loop simply retries.
                LockSupport.park(this);
                // park() returns immediately while the interrupt flag is set. Clear it so this
                // loop does not busy-spin, and restore it once the resource is ours.
                if (Thread.interrupted()) {
                    interrupted = true;
                }
            }
        }
        if (queued) {
            waiters.remove(current);
        }
        if (interrupted) {
            current.interrupt();
        }
    }

    /** Unparks every currently queued waiter so they can compete again. */
    private void signalWaiters() {
        for (Thread waiter : waiters) {
            LockSupport.unpark(waiter);
        }
    }
}
AQS源码解析
要理解AQS,需重点阅读 java.util.concurrent.locks包下的JDK源码;
也可点击阅读并发-AQS源码分析
信号量Semaphore
手写Semaphore信号量
import com.aqs.MyAqs;
/**
 * Hand-written counting semaphore built on MyAqs's shared mode.
 * The synchronizer state holds the number of permits still available.
 */
public class MySemaphore {
    MyAqs aqs = new MyAqs() {
        /**
         * Takes one permit if any is left.
         *
         * @return 1 when a permit was taken, -1 when none are available
         */
        @Override
        public int tryAcquireShared() {
            int available;
            do {
                available = getState().get();
                if (available <= 0) {
                    // nothing left to hand out: the caller has to queue and wait
                    return -1;
                }
            } while (!getState().compareAndSet(available, available - 1));
            return 1;
        }

        /**
         * Gives one permit back.
         *
         * @return true unless the incremented counter is negative
         */
        @Override
        public boolean tryReleaseShared() {
            int afterRelease = getState().incrementAndGet();
            return afterRelease > -1;
        }
    };

    /**
     * @param count initial number of permits
     */
    public MySemaphore(int count) {
        aqs.getState().set(count);
    }

    /** Takes a permit, blocking until one is available. */
    public void acquire() {
        aqs.acquireShared();
    }

    /** Returns a permit and wakes waiting threads. */
    public void release() {
        aqs.releaseShared();
    }
}
信号量用法代码示例:
import java.util.Random;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Semaphore demo: 9 guests share 5 wristbands (permits), so at most 5 of them are served at the
 * same time (rate limiting).
 */
public class SemaphoreDemo {
    /** Shared source of simulated service times; java.util.Random is safe to share across threads. */
    private final Random random = new Random();

    public static void main(String[] args) {
        SemaphoreDemo semaphoreTest = new SemaphoreDemo();
        int guestCount = 9; // number of guests
        Semaphore semaphore = new Semaphore(5); // number of wristbands: caps concurrent requests
        for (int i = 0; i < guestCount; i++) {
            String vipNo = "vip-00" + i;
            new Thread(() -> {
                try {
                    semaphore.acquire(); // take a permit
                } catch (InterruptedException e) {
                    // Never got a permit, so there is nothing to give back.
                    Thread.currentThread().interrupt();
                    return;
                }
                try {
                    semaphoreTest.service(vipNo);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    // Always return the permit, even if service() failed; otherwise it leaks.
                    semaphore.release();
                }
            }).start();
        }
    }

    /**
     * Simulates serving one guest for a random time below 3 seconds.
     * The caller limits this to 5 concurrent threads.
     *
     * @param vipNo guest identifier, used only for printing
     * @throws InterruptedException if interrupted while "serving"
     */
    public void service(String vipNo) throws InterruptedException {
        System.out.println("楼上出来迎接贵宾一位,贵宾编号" + vipNo + ",...");
        Thread.sleep(random.nextInt(3000));
        System.out.println("欢送贵宾出门,贵宾编号" + vipNo);
    }
}
Semaphore源码解析
import java.util.Collection;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
// Semaphore 源码解析版本
public class SemaphoreSource {
private final SemaphoreSource.Sync sync;
// 还是AQS的机制
abstract static class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 1192457210091910933L;
Sync(int permits) {
setState(permits);
}
final int getPermits() {
return getState();
}
final int nonfairTryAcquireShared(int acquires) {
for (; ; ) {
int available = getState();
int remaining = available - acquires;
if (remaining < 0 ||
compareAndSetState(available, remaining)) {
return remaining;
}
}
}
@Override
protected final boolean tryReleaseShared(int releases) {
for (; ; ) {
int current = getState();
int next = current + releases;
if (next < current){ // overflow
throw new Error("Maximum permit count exceeded");
}
if (compareAndSetState(current, next)) {
return true;
}
}
}
final void reducePermits(int reductions) {
for (; ; ) {
int current = getState();
int next = current - reductions;
if (next > current) { // overflow
throw new Error("Permit count underflow");
}
if (compareAndSetState(current, next)) {
return;
}
}
}
final int drainPermits() {
for (; ; ) {
int current = getState();
if (current == 0 || compareAndSetState(current, 0)) {
return current;
}
}
}
}
static final class NonfairSync extends SemaphoreSource.Sync {
private static final long serialVersionUID = -2694183684443567898L;
NonfairSync(int permits) {
super(permits);
}
@Override
protected int tryAcquireShared(int acquires) {
return nonfairTryAcquireShared(acquires);
}
}
static final class FairSync extends SemaphoreSource.Sync {
private static final long serialVersionUID = 2014338818796000944L;
FairSync(int permits) {
super(permits);
}
@Override
protected int tryAcquireShared(int acquires) {
for (; ; ) {
if (hasQueuedPredecessors()) {
return -1;
}
int available = getState();
int remaining = available - acquires;
if (remaining < 0 ||
compareAndSetState(available, remaining)) {
return remaining;
}
}
}
}
public SemaphoreSource(int permits) {
sync = new SemaphoreSource.NonfairSync(permits);
}
public SemaphoreSource(int permits, boolean fair) {
sync = fair ? new SemaphoreSource.FairSync(permits) : new SemaphoreSource.NonfairSync(permits);
}
public void acquire() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
public void acquireUninterruptibly() {
sync.acquireShared(1);
}
public boolean tryAcquire() {
return sync.nonfairTryAcquireShared(1) >= 0;
}
public boolean tryAcquire(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
public void release() {
sync.releaseShared(1);
}
public void acquire(int permits) throws InterruptedException {
if (permits < 0) {
throw new IllegalArgumentException();
}
sync.acquireSharedInterruptibly(permits);
}
public void acquireUninterruptibly(int permits) {
if (permits < 0) {
throw new IllegalArgumentException();
}
sync.acquireShared(permits);
}
public boolean tryAcquire(int permits) {
if (permits < 0) {
throw new IllegalArgumentException();
}
return sync.nonfairTryAcquireShared(permits) >= 0;
}
public boolean tryAcquire(int permits, long timeout, TimeUnit unit)
throws InterruptedException {
if (permits < 0) {
throw new IllegalArgumentException();
}
return sync.tryAcquireSharedNanos(permits, unit.toNanos(timeout));
}
public void release(int permits) {
if (permits < 0) {
throw new IllegalArgumentException();
}
sync.releaseShared(permits);
}
public int availablePermits() {
return sync.getPermits();
}
public int drainPermits() {
return sync.drainPermits();
}
protected void reducePermits(int reduction) {
if (reduction < 0) {
throw new IllegalArgumentException();
}
sync.reducePermits(reduction);
}
public boolean isFair() {
return sync instanceof SemaphoreSource.FairSync;
}
public final boolean hasQueuedThreads() {
return sync.hasQueuedThreads();
}
public final int getQueueLength() {
return sync.getQueueLength();
}
protected Collection<Thread> getQueuedThreads() {
return sync.getQueuedThreads();
}
@Override
public String toString() {
return super.toString() + "[Permits = " + sync.getPermits() + "]";
}
}
倒计数器CountDownLatch
类似比赛开始前,裁判等待运动员就绪后,吹哨开始比赛;
手写CountDownLatch倒计数器
import com.study.lock.aqs.MyAqs;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Hand-written CountDownLatch built on MyAqs's shared mode.
 * The state is the number of participants that have not counted down yet. {@link #await()} blocks
 * until it reaches zero; {@link #countDown()} decrements it and wakes waiters on the 1 -> 0
 * transition.
 */
public class CountDownLatchDemo {
    MyAqs aqs = new MyAqs() {
        /** Non-zero count: someone is not ready yet, so the caller has to wait. */
        @Override
        public int tryAcquireShared() {
            return this.getState().get() == 0 ? 1 : -1;
        }

        /**
         * Decrements the count, never below zero.
         *
         * @return true only for the call that brings the count to zero, which wakes the waiters
         */
        @Override
        public boolean tryReleaseShared() {
            AtomicInteger counter = this.getState();
            for (;;) {
                int current = counter.get();
                if (current == 0) {
                    // Already open: extra countDown() calls are no-ops. Decrementing further
                    // would make the count negative, and every later await() would block forever.
                    return false;
                }
                int next = current - 1;
                if (counter.compareAndSet(current, next)) {
                    return next == 0;
                }
            }
        }
    };

    /**
     * @param count number of countDown() calls needed before await() returns
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public CountDownLatchDemo(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count < 0");
        }
        aqs.setState(new AtomicInteger(count));
    }

    /** Blocks until the count reaches zero. */
    public void await() {
        aqs.acquireShared();
    }

    /** Decrements the count; the call that reaches zero releases all waiters. */
    public void countDown() {
        aqs.releaseShared();
    }

    public static void main(String[] args) throws InterruptedException {
        // One request needs data from several backend calls; wait for all of them.
        CountDownLatch cdLdemo = new CountDownLatch(10); // initial count
        for (int i = 0; i < 10; i++) { // 10 workers, each simulating a 2-second remote call
            int finalI = i;
            new Thread(() -> {
                try {
                    Thread.sleep(2000L);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                System.out.println("我是" + Thread.currentThread() + ".我执行接口-" + finalI +"调用了");
                cdLdemo.countDown(); // count down even if interrupted, so await() cannot hang
                // countDown() does not block; the worker could carry on with other work here
            }).start();
        }
        cdLdemo.await(); // wait until the count reaches 0
        System.out.println("全部执行完毕.我来召唤神龙");
    }
}
CountDownLatch源码解析
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
// CountDownLatch 源码解析版
public class CountDownLatchSource {
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
@Override
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
@Override
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0) {
return false;
}
int nextc = c-1;
if (compareAndSetState(c, nextc)) {
return nextc == 0;
}
}
}
}
private final CountDownLatchSource.Sync sync;
public CountDownLatchSource(int count) {
if (count < 0) {
throw new IllegalArgumentException("count < 0");
}
this.sync = new CountDownLatchSource.Sync(count);
}
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
public void countDown() {
sync.releaseShared(1);
}
public long getCount() {
return sync.getCount();
}
// public String toString();
}
栅栏CyclicBarrier
CyclicBarrier栅栏使用示例
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.LinkedBlockingQueue;
// Cyclic barrier demo: buffer rows and flush them in batches, like a batched database insert.
// Another analogy: a game lobby that starts a dungeon run once a full party of players has joined.
public class CyclicBarrierDemo {
    public static void main(String[] args) throws InterruptedException {
        // Rows buffered by the workers and drained by the barrier action.
        LinkedBlockingQueue<String> sqls = new LinkedBlockingQueue<>();
        // Idea: a task such as 1+2+...+1000 could be split into 100 subtasks (1..10, 11..20, ...)
        // handled by 100 threads.
        // Every time `parties` threads are blocked in await(), the barrier action runs.
        final int parties = 4;
        // Keep the worker count a multiple of parties. Otherwise the last, incomplete batch never
        // trips: its threads block in await() forever and the JVM never exits.
        final int workerCount = parties * 3;
        CyclicBarrier barrier = new CyclicBarrier(parties, new Runnable() {
            @Override
            public void run() {
                // Runs in the last thread to arrive, once per full batch.
                System.out.println("有4个线程执行了,开始批量插入: " + Thread.currentThread());
                for (int i = 0; i < parties; i++) {
                    System.out.println(sqls.poll());
                }
            }
        });
        Thread[] workers = new Thread[workerCount];
        for (int i = 0; i < workerCount; i++) {
            workers[i] = new Thread(() -> {
                try {
                    sqls.add("data - " + Thread.currentThread()); // buffer the row
                    Thread.sleep(1000L); // simulate the time a database operation takes
                    // Blocks until `parties` threads have reached this point.
                    barrier.await();
                    System.out.println(Thread.currentThread() + "插入完毕");
                } catch (Exception e) {
                    e.printStackTrace();
                }
            });
            workers[i].start();
        }
        // Wait for every worker instead of sleeping for a guessed duration.
        for (Thread worker : workers) {
            worker.join();
        }
    }
}
CyclicBarrier源码解析
// CyclicBarrier, annotated source walkthrough
/**
 * A reusable barrier for a fixed number of parties.
 * Each party calls {@link #await()}; the last one to arrive runs the optional barrier command,
 * then all waiters are released and the barrier starts a new "generation" so it can be used again.
 * Every field below is read and written while holding {@code lock}.
 *
 * <p>NOTE(review): this listing has no import statements. It needs ReentrantLock and Condition
 * (java.util.concurrent.locks) plus BrokenBarrierException, TimeoutException and TimeUnit
 * (java.util.concurrent) to compile.
 */
public class CyclicBarrier {
    /**
     * One round of the barrier. A new instance is created each time the barrier trips or is
     * reset; {@code broken} marks a round that ended abnormally.
     */
    private static class Generation {
        boolean broken = false;
    }
    /** Guards all mutable state of the barrier. */
    private final ReentrantLock lock = new ReentrantLock();
    /** Condition the parties wait on until the barrier trips or breaks. */
    private final Condition trip = lock.newCondition();
    /** Number of parties needed to trip the barrier. */
    private final int parties;
    /** Run by the last arriving thread before the others are released; may be null. */
    private final Runnable barrierCommand;
    /** The current round. */
    private Generation generation = new Generation();
    /** Parties still missing in the current round; counts down from {@code parties}. */
    private int count;
    /**
     * Ends the current round successfully and starts a new one.
     * Called with {@code lock} held.
     */
    private void nextGeneration() {
        trip.signalAll(); // wake every waiting thread
        count = parties; // reset the count for the next round
        generation = new Generation();
    }
    /**
     * Marks the current round as broken, resets the count and wakes all waiters. Woken waiters
     * see {@code g.broken} and throw BrokenBarrierException.
     */
    private void breakBarrier() {
        generation.broken = true;
        count = parties;
        trip.signalAll();
    }
    /**
     * Core barrier logic shared by both await variants.
     *
     * @param timed whether to apply the {@code nanos} timeout
     * @param nanos timeout in nanoseconds (only used when {@code timed})
     * @return the arrival index: {@code parties - 1} for the first thread to arrive, 0 for the last
     * @throws BrokenBarrierException if the round is or becomes broken
     * @throws InterruptedException if this thread is interrupted before or while waiting
     *         (this also breaks the barrier)
     * @throws TimeoutException if the timed wait elapses (this also breaks the barrier)
     */
    private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
        final ReentrantLock lock = this.lock;
        lock.lock(); // only one thread at a time updates the barrier state
        try {
            final Generation g = generation;
            // once the round is broken, every caller fails
            if (g.broken)
                throw new BrokenBarrierException();
            // an interrupted caller breaks the barrier so the threads already waiting are woken
            if (Thread.interrupted()) {
                breakBarrier();
                throw new InterruptedException();
            }
            int index = --count;
            // tripped: this thread is the last party to arrive
            if (index == 0) {
                boolean ranAction = false;
                try {
                    final Runnable command = barrierCommand;
                    // the barrier command runs in the last arriving thread
                    if (command != null)
                        command.run();
                    ranAction = true;
                    // wake the waiting threads and reset the count for the next round
                    nextGeneration();
                    return 0;
                } finally {
                    // the command threw: break the barrier so the waiters do not hang
                    if (!ranAction)
                        breakBarrier();
                }
            }
            // loop until tripped, broken, interrupted, or timed out
            for (;;) {
                try {
                    // untimed: wait until signalled; timed: wait with the remaining nanos
                    if (!timed)
                        trip.await();
                    else if (nanos > 0L)
                        nanos = trip.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    // interrupted while this round is still live: break it and propagate
                    if (g == generation && ! g.broken) {
                        breakBarrier();
                        throw ie;
                    } else {
                        // We're about to finish waiting even if we had not
                        // been interrupted, so this interrupt is deemed to
                        // "belong" to subsequent execution.
                        Thread.currentThread().interrupt();
                    }
                }
                if (g.broken)
                    throw new BrokenBarrierException();
                // a new round has started, so this one tripped normally
                if (g != generation)
                    return index;
                // timed out: break the barrier for everyone
                if (timed && nanos <= 0L) {
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            lock.unlock();
        }
    }
    /**
     * @param parties number of threads that must call await() before the barrier trips
     * @param barrierAction run by the last arriving thread; may be null
     * @throws IllegalArgumentException if {@code parties} is not positive
     */
    public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) {
            throw new IllegalArgumentException();
        }
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }
    /** Creates a barrier with no barrier action. */
    public CyclicBarrier(int parties) {
        this(parties, null);
    }
    /** Number of parties required to trip this barrier. */
    public int getParties() {
        return parties;
    }
    /**
     * Waits until all parties have arrived.
     *
     * @return the arrival index (0 for the last thread to arrive)
     */
    public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen: the untimed wait never times out
        }
    }
    /**
     * Like {@link #await()}, but gives up after {@code timeout}. On timeout the barrier is broken
     * and TimeoutException is thrown.
     */
    public int await(long timeout, TimeUnit unit)
        throws InterruptedException,
               BrokenBarrierException,
               TimeoutException {
        return dowait(true, unit.toNanos(timeout));
    }
    /** Whether the current round has been broken. */
    public boolean isBroken() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            return generation.broken;
        } finally {
            lock.unlock();
        }
    }
    /**
     * Breaks the current round, so its waiters get BrokenBarrierException, then starts a fresh
     * round.
     */
    public void reset() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            breakBarrier(); // break the current generation
            nextGeneration(); // start a new generation
        } finally {
            lock.unlock();
        }
    }
    /** Number of parties currently waiting in this round. */
    public int getNumberWaiting() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            return parties - count;
        } finally {
            lock.unlock();
        }
    }
}