并发编程笔记十:基于AQS的同步工具

十、基于AQS的同步工具

10.1、CountDownLatch倒数器

10.1.1、使用场景

当我们使用多线程处理一个大任务时,我们要把大任务转换为若干个小任务去执行。当这些小任务都执行完成时,我们需要汇总这些小任务的执行结果。所以汇总线程需要等到所有小任务都执行完成之后才能继续执行,这时候我们就可以使用CountDownLatch工具了。

10.1.2、源码分析

CountDownLatch的源码删除注释不到50行。

public class CountDownLatch {
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;
        Sync(int count) {
            setState(count);
        }
        int getCount() {
            return getState();
        }
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }
        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }
    private final Sync sync;
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }
    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }
    public void countDown() {
        sync.releaseShared(1);
    }
    public long getCount() {
        return sync.getCount();
    }
    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }
}

 

(一)方法介绍:

(1)构造方法:public CountDownLatch(int count)

构造一个CountDownLatch对象,指定倒数器的数值。

(2)public void await() throws InterruptedException

当前线程调用该方法会进入阻塞状态,直到倒数器的数值倒数到0或者被中断。

(3)public boolean await(long timeout, TimeUnit unit)

当前线程调用该方法会进入阻塞状态,直到倒数器的数值倒数到0或者超时。

(4)public void countDown()

倒数器数值减少1

(5)public long getCount()

获取倒数器的当前数值

(二)源码分析:

CountDownLatch的源码非常简洁易懂,他是使用了AQS的共享同步模式,在创建CountDownLatch对象时,指定的计数器值会作为AQS的同步状态state的值。调用await方法相当于获取AQS的同步状态,tryAcquireShared尝试获取共享模式同步状态的方法中判断state的值是否是0,如果是0则获取成功,继续执行代码,否则获取失败,进入阻塞状态,当前线程被封装成NODE加入同步队列中。

当有线程调用countDown方法相当于释放一个同步状态,tryReleaseShared方法中会判断state的值是否为0,如果为0,方法直接返回false,也不需要唤醒同步队列中阻塞的线程。如果不为0,则使用cas的方式将state的值减少1,再返回state的值是否等于0,如果等于0返回的就是true,这时代表释放同步状态成功,就会去唤醒同步队列中等待的线程。

10.1.3、代码使用示例

import java.util.concurrent.CountDownLatch;
public class CountDownLatchDemo {
	public static void main(String[] args) throws Exception{
		int taskName = 10;
		CountDownLatch countDownLatch = new CountDownLatch(taskName);
		for(int i = 0; i < taskName; i++) {
			new Thread(() -> {
				System.out.println(Thread.currentThread().getName() + "正在执行任务");
				try {
					Thread.sleep(1000L);
				} catch (InterruptedException e) {
					e.printStackTrace();
				}
				System.out.println(Thread.currentThread().getName() + "任务执行完毕");
				countDownLatch.countDown();
			}).start();
		}
		System.out.println("统计线程开始等待任务线程");
		countDownLatch.await();
		System.out.println("统计线程开始统计数据");
	}
}

 

执行结果:

统计线程开始等待任务线程
Thread-2正在执行任务
Thread-0正在执行任务
Thread-1正在执行任务
Thread-2任务执行完毕
Thread-0任务执行完毕
Thread-1任务执行完毕
统计线程开始统计数据

 

10.1.4、CountDownLatch和join的比较

CountDownLatch和线程的方法join很类似,但是CountDownLatch更加强大。join方法必须执行线程的线程运行完毕才能停止阻塞,而CountDownLatch的countDown方法在任何时候都可以使用。

当我们使用线程池的时候,就必须使用CountDownLatch了,就不能使用join了。

10.2、CyclicBarrier可重用栅栏

10.2.1、使用场景

当我们需要一组线程互相等待阻塞,直到指定数量的线程到达同一个地点,代码才能继续往下执行时,可以使用CyclicBarrier。

CyclicBarrier在将所有线程释放之后是可以重用的,并且CyclicBarrier可以接受一个Runnable任务,在释放所有等待线程前以同步的方式调用该Runnable。

CyclicBarrier相当于一个大巴车,当车满时才能开车。

10.2.2、源码分析

public class CyclicBarrier {
    private static class Generation {
        boolean broken = false;
    }
    private final ReentrantLock lock = new ReentrantLock();
    private final Condition trip = lock.newCondition();
    private final int parties;
    private final Runnable barrierCommand;
    private Generation generation = new Generation();
    private int count;
    private void nextGeneration() {
        // signal completion of last generation
        trip.signalAll();
        // set up next generation
        count = parties;
        generation = new Generation();
    }
    private void breakBarrier() {
        generation.broken = true;
        count = parties;
        trip.signalAll();
    }
    private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            final Generation g = generation;
            if (g.broken)
                throw new BrokenBarrierException();

            if (Thread.interrupted()) {
                breakBarrier();
                throw new InterruptedException();
            }
            int index = --count;
            if (index == 0) {  // tripped
                boolean ranAction = false;
                try {
                    final Runnable command = barrierCommand;
                    if (command != null)
                        command.run();
                    ranAction = true;
                    nextGeneration();
                    return 0;
                } finally {
                    if (!ranAction)
                        breakBarrier();
                }
            }
            // loop until tripped, broken, interrupted, or timed out
            for (;;) {
                try {
                    if (!timed)
                        trip.await();
                    else if (nanos > 0L)
                        nanos = trip.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    if (g == generation && ! g.broken) {
                        breakBarrier();
                        throw ie;
                    } else {
                        // We're about to finish waiting even if we had not
                        // been interrupted, so this interrupt is deemed to
                        // "belong" to subsequent execution.
                        Thread.currentThread().interrupt();
                    }
                }
                if (g.broken)
                    throw new BrokenBarrierException();

                if (g != generation)
                    return index;

                if (timed && nanos <= 0L) {
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            lock.unlock();
        }
    }
    public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }
    public CyclicBarrier(int parties) {
        this(parties, null);
    }
    public int getParties() {
        return parties;
    }
    public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen
        }
    }
    public int await(long timeout, TimeUnit unit)
        throws InterruptedException,
               BrokenBarrierException,
               TimeoutException {
        return dowait(true, unit.toNanos(timeout));
    }
    public boolean isBroken() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            return generation.broken;
        } finally {
            lock.unlock();
        }
    }
    public void reset() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            breakBarrier();   // break the current generation
            nextGeneration(); // start a new generation
        } finally {
            lock.unlock();
        }
    }
    public int getNumberWaiting() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            return parties - count;
        } finally {
            lock.unlock();
        }
    }
}

 

(一)方法介绍:

(1)构造方法:public CyclicBarrier(int parties)

创建一个可重用栅栏,指定等待线程的数量

(2)构造方法:public CyclicBarrier(int parties, Runnable barrierAction)

创建一个可重用栅栏,指定等待线程的数量和所有线程都到达栅栏处后先同步执行的任务。

(1)public int getParties()

获取需要指定的等待线程的数量值。

(2)public int await() throws InterruptedException, BrokenBarrierException

调用本方法的线程将等待阻塞在此,直到指定数量的线程到达此处或者线程被中断或者栅栏被破坏。返回一个int值,表示当前线程到达栅栏处时,后续还剩几个线程。也就是,到达的越早,这个线程的await()方法返回值越大,第一个到达的线程的返回值为getParties() – 1

(3)public int await(long timeout, TimeUnit unit) throws InterruptedException, BrokenBarrierException, TimeoutException

同上一个方法,可以指定超时时间。

(4)public boolean isBroken()

判断栅栏是否被破坏。在等待的线程中有出现异常或者被中断的线程,那么整个栅栏会被破坏。

(5)public void reset()

调用该方法会重置栅栏。重置栅栏 = 破坏栅栏 + 建立新栅栏。

(6)public int getNumberWaiting()

获取当前栅栏处等待线程的数量值。

(二)源码分析:

CyclicBarrier同步工具利用的是ReentrantLock锁的等待队列实现的功能。创建CyclicBarrier对象时,parties变量记录需要等待线程的数量,使用count记录还剩余多少等待线程,当线程调用await()方法时,会将count值减少1,如果count等于0的话,那么就先同步执行构造方法的Runnable任务,然后唤醒所有(signalAll)等待的线程,然后重置栅栏,方便下次使用。

10.2.3、代码使用示例

import java.util.concurrent.CyclicBarrier;
public class CyclicBarrierDemo {
	public static void main(String[] args) {
		CyclicBarrier cyclicBarrier = new CyclicBarrier(10, () -> {
			System.out.println("人到齐了,司机启动汽车");
		})  ;
		for(int i = 0; i < 100; i++) {
			int index = i + 1;
			new Thread(() -> {
				System.out.println("第" + index + "个人上车");
				try {
					cyclicBarrier.await();
				} catch (Exception e) {
					e.printStackTrace();
				}
				System.out.println("第" + index + "个人出发了");
			}).start();
			try {
				Thread.sleep(100L);
			} catch (InterruptedException e) {
				e.printStackTrace();
			}
		}
	}
}

 

执行结果:

第1个人上车
第2个人上车
……
第9个人上车
第10个人上车
人到齐了,司机启动汽车
第10个人出发了
第1个人出发了
……
第9个人出发了
第7个人出发了
第11个人上车
第12个人上车
……
第19个人上车
第20个人上车
人到齐了,司机启动汽车
第20个人出发了
第11个人出发了
……
第19个人出发了
第16个人出发了
……
……

 

10.3、Semaphore信号量

10.3.1、使用场景

Semaphore一般用来限制线程同时执行的数量,如果执行的线程超过了这个数量值,那么后来的线程会被阻塞住,直到有线程被释放。

每个线程在执行前都需要获取一个许可证,只有获取成功了才能继续往下执行,否则会被阻塞。

10.3.2、源码分析

package atomic;
public class Semaphore implements java.io.Serializable {
    private static final long serialVersionUID = -3222578661600680210L;
    private final Sync sync;
    abstract static class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 1192457210091910933L;
        Sync(int permits) {
            setState(permits);
        }
        final int getPermits() {
            return getState();
        }
        final int nonfairTryAcquireShared(int acquires) {
            for (;;) {
                int available = getState();
                int remaining = available - acquires;
                if (remaining < 0 ||
                    compareAndSetState(available, remaining))
                    return remaining;
            }
        }
        protected final boolean tryReleaseShared(int releases) {
            for (;;) {
                int current = getState();
                int next = current + releases;
                if (next < current) // overflow
                    throw new Error("Maximum permit count exceeded");
                if (compareAndSetState(current, next))
                    return true;
            }
        }
        final void reducePermits(int reductions) {
            for (;;) {
                int current = getState();
                int next = current - reductions;
                if (next > current) // underflow
                    throw new Error("Permit count underflow");
                if (compareAndSetState(current, next))
                    return;
            }
        }

        final int drainPermits() {
            for (;;) {
                int current = getState();
                if (current == 0 || compareAndSetState(current, 0))
                    return current;
            }
        }
    }
    static final class NonfairSync extends Sync {
        private static final long serialVersionUID = -2694183684443567898L;

        NonfairSync(int permits) {
            super(permits);
        }

        protected int tryAcquireShared(int acquires) {
            return nonfairTryAcquireShared(acquires);
        }
    }
    static final class FairSync extends Sync {
        private static final long serialVersionUID = 2014338818796000944L;

        FairSync(int permits) {
            super(permits);
        }

        protected int tryAcquireShared(int acquires) {
            for (;;) {
                if (hasQueuedPredecessors())
                    return -1;
                int available = getState();
                int remaining = available - acquires;
                if (remaining < 0 ||
                    compareAndSetState(available, remaining))
                    return remaining;
            }
        }
    }
    public Semaphore(int permits) {
        sync = new NonfairSync(permits);
    }
    public Semaphore(int permits, boolean fair) {
        sync = fair ? new FairSync(permits) : new NonfairSync(permits);
    }
    public void acquire() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }
    public void acquireUninterruptibly() {
        sync.acquireShared(1);
    }
    public boolean tryAcquire() {
        return sync.nonfairTryAcquireShared(1) >= 0;
    }
    public boolean tryAcquire(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }
    public void release() {
        sync.releaseShared(1);
    }
    public void acquire(int permits) throws InterruptedException {
        if (permits < 0) throw new IllegalArgumentException();
        sync.acquireSharedInterruptibly(permits);
    }
    public void acquireUninterruptibly(int permits) {
        if (permits < 0) throw new IllegalArgumentException();
        sync.acquireShared(permits);
    }
    public boolean tryAcquire(int permits) {
        if (permits < 0) throw new IllegalArgumentException();
        return sync.nonfairTryAcquireShared(permits) >= 0;
    }
    public boolean tryAcquire(int permits, long timeout, TimeUnit unit)
        throws InterruptedException {
        if (permits < 0) throw new IllegalArgumentException();
        return sync.tryAcquireSharedNanos(permits, unit.toNanos(timeout));
    }
    public void release(int permits) {
        if (permits < 0) throw new IllegalArgumentException();
        sync.releaseShared(permits);
    }
    public int availablePermits() {
        return sync.getPermits();
    }
    public int drainPermits() {
        return sync.drainPermits();
    }
    protected void reducePermits(int reduction) {
        if (reduction < 0) throw new IllegalArgumentException();
        sync.reducePermits(reduction);
    }
    public boolean isFair() {
        return sync instanceof FairSync;
    }
    public final boolean hasQueuedThreads() {
        return sync.hasQueuedThreads();
    }
    public final int getQueueLength() {
        return sync.getQueueLength();
    }
    protected Collection<Thread> getQueuedThreads() {
        return sync.getQueuedThreads();
    }
    public String toString() {
        return super.toString() + "[Permits = " + sync.getPermits() + "]";
    }
}

 

(一)方法介绍

(1)构造方法:public Semaphore(int permits)

创建一个Semaphore对象,指定许可证的数量。

(2)构造方法:public Semaphore(int permits, boolean fair)

创建一个Semaphore对象,指定许可证的数量和锁的公平性。

(3)public void acquire() throws InterruptedException

获取一个许可证,一直阻塞直到获取成功或者线程被中断。

(4)public void acquireUninterruptibly()

获取一个许可证,一直阻塞直到获取成功。

(5)public boolean tryAcquire()

获取一个许可证,立即返回,获取成功返回true,获取失败返回false。

(6)public void release()

释放一个许可证。

(7)public void acquire(int permits) throws InterruptedException

获取指定数量的许可证,一直阻塞直到获取成功或者线程被中断。

(8)public void acquireUninterruptibly(int permits)

获取指定数量的许可证,一直阻塞直到获取成功。

(9)public boolean tryAcquire(int permits)

获取指定数量的许可证,立即返回,获取成功返回true,获取失败返回false。

(10)public boolean tryAcquire(int permits, long timeout, TimeUnit unit)

获取指定数量的许可证,一直阻塞直到获取成功或者线程被中断或者线程超时。

(11)public void release(int permits)

释放指定数量的许可证。

(12)public int availablePermits()

获取剩余的许可证数量。

(13)public int drainPermits()

获取剩下所有的许可证。

(14)protected void reducePermits(int reduction)

减少许可证的数量。

(15)public boolean isFair()

Semaphore是否是公平的。

(二)、源码分析

Semaphore的源码比较简单,和9.4.2的SharedAQSDemo代码类似,利用了AQS共享模式的同步工具来实现的。

10.3.3、代码使用示例

类似于9.4.2的测试代码:

package aqstool;

import java.util.concurrent.Semaphore;

public class SemaphoreTest {
	public static void main(String[] args) {
		Semaphore semaphore = new Semaphore(5);
		for(int i = 0; i < 20; i++) {
			new Thread(() -> {
				try {
					semaphore.acquire();
					System.out.println(Thread.currentThread().getName() + "正在执行!");
					Thread.sleep(3000L);
				} catch (Exception e) {
					e.printStackTrace();
				}finally {
					semaphore.release();
				}
			}, String.valueOf(i)).start();;
		}
	}
}

 

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值