Semaphore介绍
Semaphore(信号量)是JAVA多线程中的一个工具类,它可以通过指定参数来控制执行线程数量,一般用于限流访问某个资源时使用。
Semaphore使用示例
需求场景:用一个核心线程数为6,最大线程数为20的线程池执行任务,但是要求最多只能同时运行3个线程
代码:
public class demo {
//创建线程池,核心线程数:6;最大线程数:20;时间:5;时间单位:秒;阻塞队列:ArrayBlockingQueue,最大容量为10;线程工厂:默认;拒绝策略:默认
static ThreadPoolExecutor poolExecutor = new ThreadPoolExecutor(6, 20, 5, TimeUnit.SECONDS, new ArrayBlockingQueue<>(10));
public static void main(String[] args) throws InterruptedException {
Semaphore semaphore = new Semaphore(3);//指定线程数量
for (int i = 0; i < 10; i++) {
poolExecutor.execute(new Runnable() {
@Override
public void run() {
try {
semaphore.acquire();
System.out.println(Thread.currentThread().getName() + " start...");
Thread.sleep(2000);
//用来表明当前线程结束
System.out.println(Thread.currentThread().getName() + " end...");
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
semaphore.release();
}
}
});
}
poolExecutor.shutdown();
}
}
(结果分析)从输出结果可以看出:最多只能同时开启3个线程(创建Semaphore时指定线程数量),只有等最先开启的三个线程中的某个结束了才会开启新的线程,但同时运行的总量始终保持在3个以内!
pool-1-thread-1 start...
pool-1-thread-2 start...
pool-1-thread-3 start...
pool-1-thread-2 end...
pool-1-thread-3 end...
pool-1-thread-4 start...
pool-1-thread-1 end...
pool-1-thread-5 start...
pool-1-thread-6 start...
pool-1-thread-5 end...
pool-1-thread-6 end...
pool-1-thread-4 end...
pool-1-thread-6 start...
pool-1-thread-3 start...
pool-1-thread-2 start...
pool-1-thread-6 end...
pool-1-thread-2 end...
pool-1-thread-3 end...
pool-1-thread-1 start...
pool-1-thread-1 end...
Process finished with exit code 0
Semaphore实现原理
源码:
public class Semaphore implements java.io.Serializable {
private static final long serialVersionUID = -3222578661600680210L;
//继承AQS的内部类
private final Sync sync;
abstract static class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 1192457210091910933L;
//构造函数,传入的参数为信号量permits
Sync(int permits) {
setState(permits);
}
//获取信号量
final int getPermits() {
return getState();
}
//以非公平的方式尝试获取信号量
final int nonfairTryAcquireShared(int acquires) {
//自旋
for (; ; ) {
//当前信号量
int available = getState();
//获取acquires个信号量后的剩余信号量
int remaining = available - acquires;
//如果剩余信号量小于0(获取失败),或者成功把剩余信号量更新为当前信号量(获取成功)都会退出自旋并返回剩余信号量
if (remaining < 0 || compareAndSetState(available, remaining))
return remaining;
}
}
//尝试释放信号量
protected final boolean tryReleaseShared(int releases) {
for (; ; ) {
//当前信号量
int current = getState();
//下个信号量,即当前信号量+释放的信号量(线程运行结束将信号量还给Semaphore,所以相加)
int next = current + releases;
//如果下个信号量小于当前信号量则有越界的情况,报错!
if (next < current) // overflow
throw new Error("Maximum permit count exceeded");
//如果没问题就用CAS更新当前信号量,并结束自旋
if (compareAndSetState(current, next))
return true;
}
}
//减少信号量
final void reducePermits(int reductions) {
for (; ; ) {
int current = getState();
//下个信号量为当前信号量-减少信号量
int next = current - reductions;
//如果没有那么多可减少的信号则抛出异常
if (next > current) // underflow
throw new Error("Permit count underflow");
//如果没问题就更新信号量并结束自旋
if (compareAndSetState(current, next))
return;
}
}
//清空信号量
final int drainPermits() {
for (; ; ) {
int current = getState();
if (current == 0 || compareAndSetState(current, 0))
return current;
}
}
}
//非公平
static final class NonfairSync extends Sync {
private static final long serialVersionUID = -2694183684443567898L;
NonfairSync(int permits) {
super(permits);
}
protected int tryAcquireShared(int acquires) {
return nonfairTryAcquireShared(acquires);
}
}
//公平
static final class FairSync extends Sync {
private static final long serialVersionUID = 2014338818796000944L;
FairSync(int permits) {
super(permits);
}
//公平的方式尝试获取信号量
protected int tryAcquireShared(int acquires) {
for (; ; ) {
//如果队列当前节点已经有任务则结束自旋
if (hasQueuedPredecessors())
return -1;
//当前信号量
int available = getState();
//剩余信号量=当先信号量-获取信号量
int remaining = available - acquires;
//如果剩余信号量小于0(获取失败),或者成功把剩余信号量更新为当前信号量(获取成功)都会退出自旋并返回剩余信号量
if (remaining < 0 || compareAndSetState(available, remaining))
return remaining;
}
}
}
//构造方法,默认以非公平的方式设置信号量
public Semaphore(int permits) {
sync = new NonfairSync(permits);
}
//构造方法,自定义以公平还是非公平的方式设置信号量
public Semaphore(int permits, boolean fair) {
sync = fair ? new FairSync(permits) : new NonfairSync(permits);
}
//获取信号量,如果当前线程已中止(interrupted)就抛出异常
public void acquire() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
//获取信号量,无论当前线程是否中止都尝试获取
public void acquireUninterruptibly() {
sync.acquireShared(1);
}
//尝试获取信号量
public boolean tryAcquire() {
return sync.nonfairTryAcquireShared(1) >= 0;
}
//尝试获取信号量并设置超时时间
public boolean tryAcquire(long timeout, TimeUnit unit) throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
//释放信号量
public void release() {
sync.releaseShared(1);
}
//获取指定数量的信号量
public void acquire(int permits) throws InterruptedException {
if (permits < 0) throw new IllegalArgumentException();
sync.acquireSharedInterruptibly(permits);
}
//获取指定数量的信号量
public void acquireUninterruptibly(int permits) {
if (permits < 0) throw new IllegalArgumentException();
sync.acquireShared(permits);
}
//尝试获取指定数量的信号量
public boolean tryAcquire(int permits) {
if (permits < 0) throw new IllegalArgumentException();
return sync.nonfairTryAcquireShared(permits) >= 0;
}
//尝试获取指定数量的信号量,并设置超时时间
public boolean tryAcquire(int permits, long timeout, TimeUnit unit)
throws InterruptedException {
if (permits < 0) throw new IllegalArgumentException();
return sync.tryAcquireSharedNanos(permits, unit.toNanos(timeout));
}
//释放指定数量的信号量
public void release(int permits) {
if (permits < 0) throw new IllegalArgumentException();
sync.releaseShared(permits);
}
//获取当前信号数量
public int availablePermits() {
return sync.getPermits();
}
//清空信号量
public int drainPermits() {
return sync.drainPermits();
}
//减少指定数量的信号量
protected void reducePermits(int reduction) {
if (reduction < 0) throw new IllegalArgumentException();
sync.reducePermits(reduction);
}
//判断是否为公平
public boolean isFair() {
return sync instanceof FairSync;
}
//判断是否有队列(有阻塞线程时才会产生队列,即判断是否有阻塞线程)
public final boolean hasQueuedThreads() {
return sync.hasQueuedThreads();
}
//获取阻塞线程数量
public final int getQueueLength() {
return sync.getQueueLength();
}
//获取阻塞线程并封装成集合返回
protected Collection<Thread> getQueuedThreads() {
return sync.getQueuedThreads();
}
public String toString() {
return super.toString() + "[Permits = " + sync.getPermits() + "]";
}
}
核心方法:
1、acquire()
//获取信号量,如果当前线程已中止(interrupted)就抛出异常
public void acquire() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
//如果当前线程已经终止则抛出异常
if (Thread.interrupted())
throw new InterruptedException();
//尝试获取信号量(调用内部类Sync的方法)
if (tryAcquireShared(arg) < 0)
//获取信号量失败时会将当前线程封装成node加入到阻塞队列中
doAcquireSharedInterruptibly(arg);
}
2、release()
//释放信号量
public void release() {
sync.releaseShared(1);
}
public final boolean releaseShared(int arg) {
//调用内部类Sync尝试释放信号量
if (tryReleaseShared(arg)) {
//释放成功后唤醒阻塞队列的next节点
doReleaseShared();
return true;
}
return false;
}