JAVA多线程(五)—— CyclicBarrier和CountDownLatch

CyclicBarrier

Cyclic是循环的意思,Barrier是屏障的意思,所以从其中文意思上可以理解为一个循环屏障,也就是说当满足一个条件之前,都是等待状态,只有当所有线程都满足了,在进行下一步操作。

举个例子就是:我们去游乐场坐过山车,当座位都满了的时候,工作人员才会发车,要不然所有人都在等待发车(不考虑游客少的时候,人不满也会发车)。

使用方法

首先我们先来看看它的使用方法:

package com.example.demo.basis.thread.lock;

import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

/**
 * @author : sun
 * create at:  2020/5/14  21:03
 * @description: CyclicBarrier
 * 使用方法:CyclicBarrier有两个构造,两个构造参数是:前面的参数表示线程目标数,当await的线程到达这个数的时候,执行后面Runnable
 * 里面的run方法,同时执行每个线程接下来的代码;一个参数就是:指定目标线程数,达到这个数后没有任何操作.
 * 注意事项:await的时当前线程;await会造成阻塞
 */
public class TestCyclicBarrier {
    static CyclicBarrier cyclicBarrier = new CyclicBarrier(10,
            new Runnable() {
                @Override
                public void run() {
                    System.out.println("有10个人了,发车");
                }
            });

    public static void main(String[] args) throws BrokenBarrierException, InterruptedException {
        for (int i = 0; i < 100; i++) {
            int a = i;
            new Thread(()->{
                try {
                    cyclicBarrier.await();
                    System.out.println("执行到 : " + a);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                } catch (BrokenBarrierException e) {
                    e.printStackTrace();
                }
            }).start();
        }

        /**
         * 下面的代码是不起作用的,因为它只有主线程await了,所以使用的时候需要注意
         */
//        for (int i = 0; i < 100; i++) {
//            cyclicBarrier.await();
//        }

    }
}

执行结果我这边就不放了,有兴趣的可以copy出来执行一下。

await()方法会阻塞当前线程;
当等待的线程数满足阈值后就会执行相应的方法,且继续执行被阻塞线程里的代码;
这是一个循环,如果将for里面的100替换成101,你会发现程序会一直处于运行状态;

源码分析

package java.util.concurrent;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

/**
 * 首先从源码可以看出,CyclicBarrier不是基于AQS
 **/
public class CyclicBarrier {
    /**
	 * 年代类——对循环进行分代,同一代共用一个标识位;在CyclicBarrier中,同一批线程属于同一代
     */
    private static class Generation {
        boolean broken = false;
    }

    //锁
    private final ReentrantLock lock = new ReentrantLock();
    /** Condition to wait on until tripped */
    private final Condition trip = lock.newCondition();
    /** The number of parties */
    private final int parties;
    /* The command to run when tripped */
    private final Runnable barrierCommand;
    /** The current generation */
    private Generation generation = new Generation();

    /**
     * 计算线程数
     */
    private int count;

    /**
	 * 重置,开启下一代
     */
    private void nextGeneration() {
        // signal completion of last generation
        trip.signalAll();
        // set up next generation
        count = parties;
        generation = new Generation();
    }

    /**
	 * 如果有线程中断,则将这一代标识修改为true,同时唤醒condition中的线程
     */
    private void breakBarrier() {
        generation.broken = true;
        count = parties;
        trip.signalAll();
    }

    /**
	 * 这个是它最主要的方法
     */
    private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
        final ReentrantLock lock = this.lock;
        //锁上
        lock.lock();
        try {
            //获得年代
            final Generation g = generation;
            
            //判断当前generation的状态,如果这个年代的中有些线程已经将CyclicBarrie“打断”,则抛出异常
            if (g.broken)
                throw new BrokenBarrierException();

            //这里就是校验线程状态,如果线程状态处于中断,则调用breakBarrier方法,将CyclicBarrie的年代标识设置为true
            //对应上面的校验
            if (Thread.interrupted()) {
                breakBarrier();
                throw new InterruptedException();
            }
			//count--
            int index = --count;
            //当count=0的时候,说明可以打开阻塞触发barrierCommand任务了
            if (index == 0) {  // tripped
                boolean ranAction = false;
                try {
                    final Runnable command = barrierCommand;
                    if (command != null)
                        command.run();
                    ranAction = true;
                    //这个方法里会唤醒trip中所有等待的线程,同时重置generation——下一代
                    nextGeneration();
                    return 0;
                } finally {
                    if (!ranAction)
                        breakBarrier();
                }
            }

            // 如果count不为0,自旋
            for (;;) {
                try {
                    //这里涉及Condition知识,在我其他文章中有介绍
                    //如果没有设置等待时间,则调用Condition.await()方法
                    if (!timed)
                        trip.await();
                    else if (nanos > 0L)//如果设置了超时时间,调用Condition.awaitNanos()方法
                        nanos = trip.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    if (g == generation && ! g.broken) {
                        breakBarrier();
                        throw ie;
                    } else {
                        Thread.currentThread().interrupt();
                    }
                }

                if (g.broken)
                    throw new BrokenBarrierException();
				//generation已经更新,返回index
                if (g != generation)
                    return index;

                if (timed && nanos <= 0L) {
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            lock.unlock();//释放锁
        }
    }

    /**
	 * 构造方法,初始化参数
     */
    public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }

    /**
	 * 构造方法,调用上面的构造
     */
    public CyclicBarrier(int parties) {
        this(parties, null);
    }

    /**
	 * 获得设置的阈值——栅栏数
     */
    public int getParties() {
        return parties;
    }

    /**
     * 给外部调用的await方法,没有超时时间
     */
    public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen
        }
    }

    /**
     * 给外部调用的await方法,有超时时间
     */
    public int await(long timeout, TimeUnit unit)
        throws InterruptedException,
               BrokenBarrierException,
               TimeoutException {
        return dowait(true, unit.toNanos(timeout));
    }

    /**
     * 查看当前代是否被中断
     */
    public boolean isBroken() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            return generation.broken;
        } finally {
            lock.unlock();
        }
    }

    /**
     * 重置barrier到初始化状态
     */
    public void reset() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            breakBarrier();   // break the current generation
            nextGeneration(); // start a new generation
        } finally {
            lock.unlock();
        }
    }

    /**
     * 获得等待的线程数
     */
    public int getNumberWaiting() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            return parties - count;
        } finally {
            lock.unlock();
        }
    }
}

CountDownLatch

CountDownLatch是JUC包里提供的一个并发编程的工具类,它的实现也是基于AQS,其内部通过state来计数,当这个数变成0(查看它的构造函数,初始值不可以小于0,否则抛出异常)等待的线程就可以继续执行。

使用方法

首先我们先来看看它的使用方法:

package com.example.demo.basis.thread.lock;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/**
 * @author : sun
 * create at:  2020/5/14  19:32
 * @description: CountDownLatch
 * 使用方法:首先初始化是指定大小,可以理解为指定座位数,一个线程执行完业务步骤后调用countDown方法,当前线程就类似于做到座位上了,
 * 在主线程里调用await方法,意思就是等座位上的所有线程执行完毕后,执行下面的代码,这里会造成阻塞。
 * 在使用的时候需要注意:在指定CountDownLatch初始参数的时候一定要预判好有几个位置,因为只有位置坐满了,并且每个位置的线程都执行
 * 结束,才会执行await后面的代码。
 */
public class TestCountDownLatch {
    static CountDownLatch countDownLatch = new CountDownLatch(3);

    public static void main(String[] args) {
        long startTime = System.currentTimeMillis();
        new Thread(()->{
            try {
                TimeUnit.SECONDS.sleep(5);
                countDownLatch.countDown();
                System.out.println("1");
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        },"线程1").start();

        new Thread(()->{
            try {
                TimeUnit.SECONDS.sleep(8);
                countDownLatch.countDown();
                System.out.println("2");
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        },"线程2").start();

        new Thread(()->{
            try {
                TimeUnit.SECONDS.sleep(10);
                countDownLatch.countDown();
                System.out.println("3");
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        },"线程3").start();


        new Thread(()->{
            try {
                TimeUnit.SECONDS.sleep(20);
                countDownLatch.countDown();
                System.out.println("4");
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        },"线程4").start();

        new Thread(()->{
            try {

                //这里试验了CountDownLatch是可以唤醒多个等待的线程的
                countDownLatch.await();
                System.out.println(Thread.currentThread().getName() + " 执行到了 !");
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        },"线程5").start();

        try {
            countDownLatch.await();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        long endTime = System.currentTimeMillis();

        System.out.println("一共耗时 : " + (endTime-startTime));

    }
}

结果:

1
2
3
一共耗时 : 10029
线程5 执行到了 !
4

我们从结果来分析:

CountDownLatch.countDown()方法不会阻塞当前线程,只会阻塞调用await()方法的地方。
CountDownLatch是不可重入的,只会生效一次
CountDownLatch会唤醒多个await()方法的地方

源码分析

package java.util.concurrent;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

public class CountDownLatch {
    /**
     * 首先这个内部类实现了AbstractQueuedSynchronizer,说明CountDownLatch也是基于AQS实现的
     * Synchronization control For CountDownLatch.
     * Uses AQS state to represent count.
     */
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        //构造方法,调用的setState方法是父类的final类型的方法,主要是设置state值
        Sync(int count) {
            setState(count);
        }

        //获得state值,getState也是父类final类型的方法
        int getCount() {
            return getState();
        }

        //wait方法最终会调用到这个方法,来获得state;如果为0返回1,其他返回-1。
        //注意:入参没有用到
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        //核心方法——自旋,验证state
        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

    //CountDownLatch的属性
    private final Sync sync;

    /**
     * 这个CountDownLatch的唯一一个构造函数,参数为int类型
     * 必须指定一个初始count,也就是目标计数,当达到这个目标值后,才会放开阻塞
     */
    public CountDownLatch(int count) {
        //如果count小于0,抛出异常
        if (count < 0) throw new IllegalArgumentException("count < 0");
        //初始化Sync,跟踪代码这里最终设置了父类的state这个值,state是volatile修饰的,
        this.sync = new Sync(count);
    }

    /**
	 * 首先这个方法会调用AbstractQueuedSynchronizer的acquireSharedInterruptibly,入参为1,仔细阅读源码发现,这个1其实没有用到
	 * 下面是acquireSharedInterruptibly 方法:
	 * public final void acquireSharedInterruptibly(int arg)
     *       throws InterruptedException {
     *	判断当前线程状态,这个线程是调用await方法的线程!
     *  if (Thread.interrupted())
     *      throw new InterruptedException();
     *	这里调用的是tryAcquireShared方法,回到子类sync的tryAcquireShared方法(参考上面Sync内的tryAcquireShared)
     *	tryAcquireShared(arg),其实就是判断state这个值,如果为0返回1,其他返回-1。
     *	返回1,简单来说就是不需要阻塞了,直接执行await后面的代码;
     *	返回-1,则回去调用doAcquireSharedInterruptibly方法,这个方法简单的来说就是把当前线程封装成一个node,放到一个线程节点队列里,然后自旋
     *	来调用tryAcquireShared方法,这里是从头部节点一个一个来判断的,当头节点通过,才会轮到下一个节点,直到所有线程节点都执行完成
     *  if (tryAcquireShared(arg) < 0)
     *      doAcquireSharedInterruptibly(arg);
     *	}
     */
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    /**
	 * 这个方法和await方法类似,只是它加了一个超时时间限制,第一个参数是时间,第二个参数是时间单位
	 * 如果超过这个时间就不会继续等待了,其实源码和上述的doAcquireSharedInterruptibly中,自旋里面多了个是时间判断
	 * 超过这个时间就会跳出自旋,不会让当前线程继续等待了。
     */
    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    /**
	 * 调用AbstractQueuedSynchronizer的releaseShared方法:
	 * public final boolean releaseShared(int arg) {
     *      if (tryReleaseShared(arg)) {
     			//对于CountDownLatch只有当state最终为0的时候才会执行doReleaseShared方法;
     			//只用当state为0,才需要去唤醒等待的线程,所以doReleaseShared概述的来讲就是去唤醒等待的线程队列中的线程
     *          doReleaseShared();
     *          return true;
     *   	}
     *   	return false;
     *	}
     * 由上述方法可以看到会调用子类的tryReleaseShared方法,Sync的tryReleaseShared参考上面说明.
     *
     * 需要注意的是countDown方法没有加锁!
     */
    public void countDown() {
        sync.releaseShared(1);
    }

    /**
     * 获得state值.
     */
    public long getCount() {
        return sync.getCount();
    }

    /**
     * toString方法
     */
    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }
}

对于CountDownLatch的使用场景我这里有个例子:我们项目中对于大文件的下载就会通过CountDownLatch来操作。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值