JUC 工具类之 CountDownLatch 详解

本文详细介绍了Java并发工具CountDownLatch的使用,包括通过join方法和CountDownLatch实现线程同步,并深入剖析了CountDownLatch的原理,如await和countDown方法的工作流程。此外,还给出了CountDownLatch在多线程任务协调中的应用场景,例如等待所有子任务完成后再执行主任务。
摘要由CSDN通过智能技术生成

一 提出一个问题

如何实现让主线程等所有子线程执行完了后,主要线程再继续执行?即如何实现一个线程等其他线程执行完了后再继续执行?

1.1 join 解决方案

在前面的文章中我们介绍了 Thread 类的 join 方法,join 的工作原理是,不停检查 thread 是否存活,如果存活则让当前线程永远 wait,直到 thread 线程终止,线程的 notifyAll 就会被调用。

下面我们就使用 join 来实现上面的问题。

import java.util.Random;
import java.util.concurrent.CountDownLatch;

public class CountDownLatchDemo {

    public static void main(String[] args) {
        System.out.println("主要线程开始等待其他子线程执行");
        try {
            test();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

    public static void test() throws InterruptedException {
       Thread thread1 = new Thread(() -> {
            System.out.println(Thread.currentThread().getName() + 
            " 线程开始");
            Random random = new Random();
            try {
                Thread.sleep(random.nextInt(10000) + 1000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println( Thread.currentThread().getName() + 
            " 线程执行完毕");

        },"线程1");
       thread1.start();
        Thread thread2 = new Thread(() -> {
            System.out.println(Thread.currentThread().getName() + 
            " 线程开始");
            Random random = new Random();
            try {
                Thread.sleep(random.nextInt(10000) + 1000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println(Thread.currentThread().getName() + 
            " 线程执行完毕");

        },"线程2");
        thread2.start();
        Thread thread3 = new Thread(() -> {
            System.out.println(Thread.currentThread().getName() + 
            " 线程开始");
            Random random = new Random();
            try {
                Thread.sleep(random.nextInt(10000) + 1000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println( Thread.currentThread().getName() + 
            " 线程执行完毕");

        },"线程3");
        thread3.start();
        Thread thread4 = new Thread(() -> {
            System.out.println(Thread.currentThread().getName() + 
            " 线程开始");
            Random random = new Random();
            try {
                Thread.sleep(random.nextInt(10000) + 1000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println(Thread.currentThread().getName() + 
            " 线程执行完毕");

        },"线程4");
        thread4.start();
        //启动了四个线程,然后让四个线程一直检测自己是否已经结束
        thread1.join();
        thread2.join();
        thread3.join();
        thread4.join();
        System.out.println("主线程继续执行");
        //todo 业务代码
    }
}

运行结果:

主要线程开始等待其他子线程执行
线程1 线程开始
线程2 线程开始
线程3 线程开始
线程4 线程开始
线程3 线程执行完毕
线程2 线程执行完毕
线程1 线程执行完毕
线程4 线程执行完毕
主线程继续执行

主线程继续干活是要等前面四个线程全部执行完毕后再继续的。但是这么搞有点麻烦,那就是每个线程都得调用 join 方法,有没有更好玩的的呢?

答案是有的,它来了。

它就是 JUC 下面的一个很牛逼的并发工具类 CountDownLatch。是 JDK1.5 的时候有的,言外之意就是在 JDK1.5 之前就只能用 join 方法了。

1.2 CountDownLatch 解决方案

CountDownLatch 中我们主要用到两个方法一个是 await() 方法,调用这个方法的线程会被阻塞,另外一个是 countDown() 方法,调用这个方法会使计数器减一,当计数器的值为 0 时,因调用 await() 方法被阻塞的线程会被唤醒,继续执行。请看代码:

import java.util.Random;
import java.util.concurrent.CountDownLatch;

public class CountDownLatchDemo {

    public static void main(String[] args) {
        System.out.println("主要线程开始等待其他子线程执行");
        test();
    }

    public static void test() {
        int threadCount = 5;
        CountDownLatch countDownLatch = 
                new CountDownLatch(threadCount);
        for (int i = 0; i < threadCount; i++) {
            final int finalI = i + 1;
            new Thread(() -> {
                System.out.println("第 " + finalI + " 线程开始");
                Random random = new Random();
                try {
                    Thread.sleep(random.nextInt(10000) + 1000);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println("第 " + finalI + " 线程执行完毕");

                countDownLatch.countDown();
            }).start();
        }

        try {
            countDownLatch.await();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println(threadCount + " 个线程全部执行完毕");
        System.out.println("主线程继续执行");
        //todo业务代码
    }
}

输出:

主要线程开始等待其他子线程执行
第 1 线程开始
第 2 线程开始
第 3 线程开始
第 4 线程开始
第 5 线程开始
第 1 线程执行完毕
第 2 线程执行完毕
第 5 线程执行完毕
第 4 线程执行完毕
第 3 线程执行完毕
5 个线程全部执行完毕
主线程继续执行

二 CountDownLatch 原理

2.1 CountDownLatch 概念

CountDownLatch 是一个计数器闭锁,通过它可以完成类似于阻塞当前线程的功能,即:一个线程或多个线程一直等待,直到其他线程执行的操作完成。

CountDownLatch 定义了一个计数器,和一个阻塞队列, 当计数器的值递减为0之前,阻塞队列里面的线程处于挂起状态,当计数器递减到0时会唤醒阻塞队列所有线程,这里的计数器是一个标志,可以表示一个任务一个线程,也可以表示一个倒计时器,CountDownLatch 可以解决那些一个或者多个线程在执行之前必须依赖于某些必要的前提业务先执行的场景。

在这里插入图片描述

2.2 常用方法

2.2.1 构造方法

我们在上面的案例中:

int threadCount = 5;
 CountDownLatch countDownLatch = new CountDownLatch(threadCount);

有用到 new CountDownLatch(threadCount);来创建一个 CountDownLatch 实例对象。我们看看这个构造方法:

private final Sync sync;
public CountDownLatch(int count) { 
    //记者count值不能小于0
    if (count < 0) throw new IllegalArgumentException("count < 0");
    //创建一个Sync实例对象入参就是count
    this.sync = new Sync(count);
}

然后这里有个内部类 Sync,这个 Sync 内部类也没几行代码,Sync 继承了 AbstractQueuedSynchronizer 抽象队列同步器(以下简称 AQS)。

private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 
              4982264981922014374L;
        //入参count
        Sync(int count) {
            //这个setState方法还记得否?就是上篇文章AQS中的setState()方法
            //就是给AQS中的state赋值,state=count
            setState(count);
        }
        //获取AQS中state的值
        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }
        //死循环
        protected boolean tryReleaseShared(int releases) {
            for (;;) {
                //获取AQS中的state
                int c = getState();
                //如果AQS中的state==0,就返回false
                if (c == 0)  return false;
                int nextc = c-1;
                //nextc=state-1
                //
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
 }

2.2.2 countDown 方法

public void countDown() {
    // 调用的就是 AQS 中的方法
    sync.releaseShared(1);
}

AQS 中 releaseShared 方法

public final boolean releaseShared(int arg) {
    // arg 为固定值 1
    // 如果计数器state 为0 返回true,前提是调用 countDown() 之前不能已经为0
    //tryReleaseShared在AQS是空方法
    if (tryReleaseShared(arg)) {
      // 唤醒等待队列的线程
        doReleaseShared();
         return true;
    }
    return false;
}
protected boolean tryReleaseShared(int arg) {
   throw new UnsupportedOperationException();
}

这个方法 tryReleaseShared() 是在 CountDownLatch 中内部类 Sync 中实现的:

//其实也没什么新招
//还是死循环+CAS配合 实现计数器state减1
protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0)  return false;
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
           return nextc == 0;
     }
}

方法 doReleaseShared 却是 AQS 实现的(因为 CountDownLatch 和其内部类都没有实现,只能是 AQS 实现了)。

//实现思路就是从头到尾的遍历列队中所有的节点为shared状态的
private void doReleaseShared() {
        //死循环
        for (;;) {
            //获取当前列队的头节点
            Node h = head;
            //列队中可能为空列队,也有可能只有一个node节点
            if (h != null && h != tail) {
                //获取头节点的状态
                int ws = h.waitStatus;
                //如果头节点为SIGNAL状态, 说明后继节点需要唤醒
                if (ws == Node.SIGNAL) {
                    //将头结点的waitstatue设置为0,
                    // 以后就不会再次唤醒后继节点了。
                    //这一步是为了解决并发问题,
                    // 保证只unpark一次!!不成功就继续
                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    //(释放)唤醒头节点的后继节点
                    unparkSuccessor(h);
                }// 状态为0并且不成功,继续
                else if (ws == 0 && !compareAndSetWaitStatus(
                    h, 0, Node.PROPAGATE))
                    continue;// loop on failed CAS
            }
            // 若头结点改变,继续循环  
            if (h == head) // loop if head changed
                break;
        }
}

整个调用逻辑大致为:
在这里插入图片描述

2.2.3 await 方法

public void await() throws InterruptedException {
   sync.acquireSharedInterruptibly(1);
}

调用 AQS 中的:

public final void acquireSharedInterruptibly(int arg) throws 
        InterruptedException {
      //判断是否被中断过
      if (Thread.interrupted()) throw new InterruptedException();
      //如果state不等于0的时候
      if (tryAcquireShared(arg) < 0){
            doAcquireSharedInterruptibly(arg);
      }
}

其中方法 tryAcquireShared(arg) 是 CountDownLatch 的内部类 Sync 的 tryAcquireShared 方法。

protected int tryAcquireShared(int acquires) {
  //判断AQS中的state是否已经等于0了,等于翻译1否则返回-1
  return (getState() == 0) ? 1 : -1;
}

再调用 AQS 中的 doAcquireSharedInterruptibly 方法

//这个方法就是将当前线程封装成node节点加入到列队中,
// 并判断是否需要阻塞当前线程,这个节点都会被设置成shared状态
 // 这样做的目的时当state值为0时会唤醒所有shared的节点
private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        //这个方法应该很熟悉了吧,前面的文章都介绍过,
        // 将当前线程封装成节点加入到列队中
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            //(又是死循环)一直执行,直到获取锁,返回
            for (;;) {
                //获取前驱节点
                final Node p = node.predecessor();
                //前驱节点为头结点
                if (p == head) {
                    //所以再次尝试获取信号量,这就是上面分析的那个获取方法
                    int r = tryAcquireShared(arg);
                    //如果r大于0证明获取信号量获取成功了证明
                    // 下一个可以获取信号量的线程是当前线程
                    if (r >= 0) {
                        //将当前节点变成列队的head节点然后返回
                        setHeadAndPropagate(node, r);
                        //方便GC
                        p.next = null; 
                        failed = false;
                        return;
                    }
                }
          //判断是否要进入阻塞状态.如果shouldParkAfterFailedAcquire方法
               //返回true,表示需要进入阻塞 调用parkAndCheckInterrupt
               // 否则表示还可以再次尝试获取锁,继续进行for循环
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
            //失败就放弃
            if (failed){
                cancelAcquire(node);
            }
        }
}

方法 shouldParkAfterFailedAcquire 是 AQS 的

//p是前驱结点,node是当前结点
private static boolean shouldParkAfterFailedAcquire(
        Node pred, Node node) {
    int ws = pred.waitStatus; //获取前驱节点的状态
    if (ws == Node.SIGNAL) //表明前驱节点可以运行
        return true;
    if (ws > 0) { //如果前驱节点状态大于0表明已经中断,
        do {
            node.prev = pred = pred.prev; 
        } while (pred.waitStatus > 0);
        pred.next = node;
    } else {
        //等于0进入这里
        compareAndSetWaitStatus(pred, ws, Node.SIGNAL); 
    }
    //只有前节点状态为NodeSIGNAL才返回真
    return false; 
}

我们对 shouldParkAfterFailedAcquire 来进行一个整体的概述,首先应该明白节点的状态,节点的状态是为了表明当前线程的良好度,如果当前线程被打断了,在唤醒的过程中是不是应该忽略该线程

 static final class Node {
        static final int CANCELLED =  1;
        static final int SIGNAL    = -1;
        static final int CONDITION = -2;
        static final int PROPAGATE = -3;
       //....

目前你只需知道大于 0 时表明该线程已近被取消,已近是无效节点,不应该被唤醒,注意:初始化链头节点时头节点状态值为 0。

shouldParkAfterFailedAcquire 是位于无限 for 循环内的,这一点需要注意一般每个节点都会经历两次循环后然后被阻塞。

在 AQS 的 doAcquireSharedInterruptibly 中可能会再次调用 CountDownLatch 的内部类 Sync 的 tryAcquireShared 方法和 AQS 的 setHeadAndPropagate 方法。

setHeadAndPropagate 方法源码如下。

private void setHeadAndPropagate(Node node, int propagate) {
        // 获取头结点
        Node h = head; 
        // 设置头结点
        setHead(node);
        // 进行判断
        if (propagate > 0 || h == null || h.waitStatus < 0 ||
            (h = head) == null || h.waitStatus < 0) {
            // 获取节点的后继
            Node s = node.next;
            if (s == null || s.isShared()) // 后继为空或者为共享模式
                // 以共享模式进行释放
                doReleaseShared();
        }
    }

该方法设置头结点并且释放头结点后面的满足条件的结点,该方法中可能会调用到 AQS 的 doReleaseShared 方法,其源码如下。

private void doReleaseShared() {
        // 无限循环
        for (;;) {
            // 保存头结点
            Node h = head;
            if (h != null && h != tail) { 
            // 头结点不为空并且头结点不为尾结点
                // 获取头结点的等待状态
                int ws = h.waitStatus; 
                if (ws == Node.SIGNAL) { // 状态为SIGNAL
                // 不成功就继续
                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) 
                        continue;            // loop to recheck cases
                    // 释放后继结点
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !compareAndSetWaitStatus(
                         h, 0, Node.PROPAGATE)) 
                         // 状态为0并且不成功,继续
                    continue;                // loop on failed CAS
            }
            if (h == head) // 若头结点改变,继续循环  
                break;
        }
    }

CountDownLatch 的 await 调用大致会有如下的调用链:
在这里插入图片描述

三 使用场景

CountDownLatch 的一个非常典型的应用场景是:有一个任务想要往下执行,但必须要等到其他的任务执行完毕后才可以继续往下执行。假如我们这个想要继续往下执行的任务调用一个 CountDownLatch 对象的 await() 方法,其他的任务执行完自己的任务后调用同一个 CountDownLatch 对象上的 countDown() 方法,这个调用 await() 方法的任务将一直阻塞等待,直到这个 CountDownLatch 对象的计数值减到 0 为止。

案例1

举个例子,有三个工人在为老板干活,这个老板有一个习惯,就是当三个工人把一天的活都干完了的时候,他就来检查所有工人所干的活。记住这个条件:三个工人先全部干完活,老板才检查。

案例2

比如读取 excel 表格,需要把 execl 表格中的多个 sheet 进行数据汇总,为了提高汇总的效率我们一般会开启多个线程,每个线程去读取一个 sheet,可是线程执行是异步的,我们不知道什么时候数据处理结束了。那么这个时候我们就可以运用 CountDownLatch,有几个 sheet 就把 state 初始化几。每个线程执行完就调用countDown() 方法,在汇总的地方加上 await() 方法,当所有线程执行完了,就可以进行数据的汇总了。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值