Java - 多线程之 -- CountDownLatch

开发中最常见的场景,在主线程中开启多线程并执行任务,主线程需要等待所有子线程执行完毕后再进行处理的场景。

在CountDownLatch出现之前,一般都是使用线程的join()方法来实现,但是join不够灵活,不能够满足不同场景的需要,所以JDK后来提供了CountDownLatch,用于同步。

一、CountDownLatch使用实例

public class JoinCountDownLatch
{
    // 创建一个CountDownLatch实例
    private static volatile CountDownLatch countDownLatch = new CountDownLatch(2);
    
    public static void main(String[] args) throws InterruptedException
    {
        Thread threadOne = new Thread(new Runnable()
        {
            @Override
            public void run()
            {
                try
                {
                    Thread.sleep(1000);
                    System.out.println("child threadOne over!");
                }
                catch (InterruptedException e)
                {
                    e.printStackTrace();
                }
                finally
                {
                    countDownLatch.countDown();
                }
            }
        });
        
        Thread threadTwo = new Thread(new Runnable()
        {
            @Override
            public void run()
            {
                try
                {
                    Thread.sleep(1000);
                    System.out.println("child threadTwo over!");
                    
                }
                catch (InterruptedException e)
                {
                    e.printStackTrace();
                }
                finally
                {
                    countDownLatch.countDown();
                }
                
            }
        });
        
        // 启动子线程
        threadOne.start();
        threadTwo.start();
        
        System.out.println("wait all child thread over!");
        
        // 等待子线程执行完毕,返回
        countDownLatch.await();
        
        System.out.println("all child thread over!");
    }
}

运行输出结果:

wait all child thread over!
child threadTwo over!
child threadOne over!
all child thread over!

创建了CountDownLatch实例,因为有两个子线程,所以构造函数参数传递为2,主线程调用countDownLatch.await()方法后会被阻塞。

子线程执行完毕后调用countDownLatch.countDown()方法让countDownLatch内部的计数器减一,等所有子线程执行完毕调用countDown()后计数器会变为0,这时候主线程的await()才会返回。

二、CountDownLatch实现原理

CountDownLatch内部有个计数器,并且这个计数器是递减的。

public class CountDownLatch
{
    /**
     * AQS
     */
    private static final class Sync extends AbstractQueuedSynchronizer
    {
        private static final long serialVersionUID = 4982264981922014374L;
        
        Sync(int count)
        {
            setState(count);
        }
        
        int getCount()
        {
            return getState();
        }
        
        protected int tryAcquireShared(int acquires)
        {
            return (getState() == 0) ? 1 : -1;
        }
        
        protected boolean tryReleaseShared(int releases)
        {
            // Decrement count; signal when transition to zero
            for (; ; )
            {
                int c = getState();
                if (c == 0)
                {
                    return false;
                }
                int nextc = c - 1;
                if (compareAndSetState(c, nextc))
                {
                    return nextc == 0;
                }
            }
        }
    }
    
    //基于AQS的同步器
    private final Sync sync;
    
    //等待计数为0
    public void await() throws InterruptedException
    {
        sync.acquireSharedInterruptibly(1);
    }
    
    //计数递减1
    public void countDown()
    {
        sync.releaseShared(1);
    }
}

A、构造方法

public CountDownLatch(int count)
{
    if (count < 0)
    {
        throw new IllegalArgumentException("count < 0");
    }
    this.sync = new Sync(count);
}

Sync(int count)
{
    setState(count);
}

CountDownLatch内部是使用AQS实现的,通过构造函数初始化了计数器的值,可知实际上是把计数器的值赋值给了AQS的状态值state,也即用AQS的状态值来表示计数器值。

B、void await()方法

当线程调用了CountDownLatch对象的await方法后,当前线程会被阻塞,直到下面的情况之一发生才会返回:

  1. 当所有线程都调用了CountDownLatch对象的countDown方法后,也即计时器值为0时;
  2. 其它线程调用了当前线程的interrupt()方法中断了当前线程,当前线程会抛出InterruptedException异常后返回。

await()方法内部:

// AQS的获取共享资源时候可被中断的方法
public final void acquireSharedInterruptibly(int arg) throws InterruptedException
{
    // 如果线程被中断则抛异常
    if (Thread.interrupted())
    {
        throw new InterruptedException();
    }
    // 尝试看当前是否计数值为0,为0则直接返回,否者进入AQS的队列等待
    if (tryAcquireShared(arg) < 0)
    {
        doAcquireSharedInterruptibly(arg);
    }
}

// sync类实现的AQS的接口
protected int tryAcquireShared(int acquires)
{
    return (getState() == 0) ? 1 : -1;
}

// CountDownLatch的await()方法
public void await() throws InterruptedException
{
    sync.acquireSharedInterruptibly(1);
}

await()方法委托sync调用了AQS的acquireSharedInterruptibly方法。

acquireSharedInterruptibly方法的特点是线程获取资源的时候可以被中断,并且获取的资源是共享资源。方法内部首先判断,如果当前线程被中断了则抛出异常,否则调用sync实现的tryAcquireShared方法看当前状态值(计数器值)是否为0,是则当前线程的await()方法直接返回,否则调用AQS的doAcquireSharedInterruptibly让当前线程阻塞。

另外可知,tryAcquireShared传递的arg参数是没有用到的,调用tryAcquireShared的方法仅仅是检查当前状态值是不是0,并没有调用CAS让当前状态值减去1。

C、boolean await(long timeout, TimeUnit unit)方法

当线程调用了CountDownLatch对象的该方法后,当前线程会被阻塞,直到下面的情况之一发生才会返回:

  1. 当所有线程都调用了CountDownLatch对象的countDown方法后,也就是计时器值为0的时候,返回true;
  2. 设置的timeout时间到了,因为超时而返回false;
  3. 其他线程调用了当前线程的interrupt()方法中断了当前线程,当前线程会抛出InterruptedException异常后返回。
public boolean await(long timeout, TimeUnit unit) throws InterruptedException
{
    return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}

D、void countDown()方法

当线程调用了该方法后,会递减计数器的值,递减后,如果计数器为0则会唤醒所有调用await方法而被阻塞的线程,否则什么都不做。

// syn的方法
protected boolean tryReleaseShared(int releases)
{
    // 循环进行cas,直到当前线程成功完成cas使计数值(状态值state)减一并更新到state
    for (; ; )
    {
        int c = getState();
        
        // 如果当前状态值为0则直接返回A
        if (c == 0)
        {
            return false;
        }
        
        // CAS设置计数值减一B
        int nextc = c - 1;
        if (compareAndSetState(c, nextc))
        {
            return nextc == 0;
        }
    }
}

// AQS的方法
public final boolean releaseShared(int arg)
{
    // 调用sync实现的tryReleaseShared
    if (tryReleaseShared(arg))
    {
        // AQS的释放资源方法
        doReleaseShared();
        return true;
    }
    return false;
}

// CountDownLatch的countDown()方法
public void countDown()
{
    // 委托sync调用AQS的方法
    sync.releaseShared(1);
}

CountDownLatch的countDown()方法是委托sync调用了AQS的releaseShared方法。releaseShared内部首先调用了sync实现的AQS的tryReleaseShared。

该方法获取当前状态值(计数器值),执行代码A,如果当前状态值为0则直接返回false,countDown()方法直接返回;否则执行代码B使用CAS设置计数器减一,CAS失败则循环重试,否则如果当前计数器为0则返回true,返回true后说明当前线程是最后一个线程调用的countdown方法,那么该线程除了让计数器值减1外,还需要唤醒调用CountDownLatch的await方法而被阻塞的线程,也就是AQS的doReleaseShared()方法。

添加代码A,是为了防止当计数器值为0后,其他线程又调用了countDown方法,如果没有代码A,状态值就会变成负数了。

E、long getCount()方法

获取当前计数器的值,也就是AQS的state的值,一般在debug测试的时候使用。

public long getCount()
{
    return sync.getCount();
}

int getCount()
{
    return getState();
}

内部调用AQS的getState方法来获取state的值。

三、使用CountDownLatch注意事项

使用CountDownLatch,要注意countDown()方法要在finally块内执行,避免抛异常后得不到执行。
private static volatile CountDownLatch countDownLatch = new CountDownLatch(2);

public static void main(String[] args) throws InterruptedException
{
    Thread threadOne = new Thread(new Runnable()
    {
        @Override
        public void run()
        {
            
            try
            {
                Thread.sleep(1000);
                System.out.println("child threadOne over!");
                throw new RuntimeException("error");
            }
            catch (InterruptedException e)
            {
                e.printStackTrace();
            }
            
            countDownLatch.countDown();
        }
    });
    
    Thread threadTwo = new Thread(new Runnable()
    {
        @Override
        public void run()
        {
            
            try
            {
                Thread.sleep(1000);
                System.out.println("child threadTwo over!");
            }
            catch (InterruptedException e)
            {
                e.printStackTrace();
            }
            countDownLatch.countDown();
        }
    });
    
    // 启动子线程
    threadOne.start();
    threadTwo.start();
    
    System.out.println("wait all child thread over!");
    
    // 等待子线程执行完毕,返回
    countDownLatch.await();
    
    System.out.println("all child thread over!");
}

代码执行结果:

wait all child thread over!
Exception in thread "Thread-0" child threadOne over!
child threadTwo over!
java.lang.RuntimeException: error
	at top.cfish.java.ext.CountDownLatchDemo$1.run(CountDownLatchDemo.java:26)
	at java.lang.Thread.run(Thread.java:748)

以上代码创建countDownLatch,计数器初始化为2,然后开启了两个线程,main线程调用了countDownLatch.await()会挂起,直到计数器为0时才返回。

其中threadTwo线程休眠1s后调用了countDownLatch.countDown()让计数器变为了1,而threadOne在执行countDownLatch.countDown()前抛出异常后退出了,没有让计数器变为0。所以main线程会一直阻塞到countDownLatch.await()不会返回。

所以规范规定调用countDownLatch.countDown()要在finally块内执行。

private static volatile CountDownLatch countDownLatch = new CountDownLatch(2);

public static void main(String[] args) throws InterruptedException
{
    Thread threadOne = new Thread(new Runnable()
    {
        @Override
        public void run()
        {
            
            try
            {
                Thread.sleep(1000);
                System.out.println("child threadOne over!");
                throw new RuntimeException("error");
            }
            catch (InterruptedException e)
            {
                e.printStackTrace();
            }
            finally
            {
                countDownLatch.countDown();
            }
        }
    });
    
    Thread threadTwo = new Thread(new Runnable()
    {
        @Override
        public void run()
        {
            
            try
            {
                Thread.sleep(1000);
                System.out.println("child threadTwo over!");
            }
            catch (InterruptedException e)
            {
                e.printStackTrace();
            }
            finally
            {
                countDownLatch.countDown();
            }
        }
    });
    
    // 启动子线程
    threadOne.start();
    threadTwo.start();
    
    System.out.println("wait all child thread over!");
    
    // 等待子线程执行完毕,返回
    countDownLatch.await();
    
    System.out.println("all child thread over!");
}

代码执行结果:

wait all child thread over!
Exception in thread "Thread-0" java.lang.RuntimeException: error
	at top.cfish.java.ext.Main2$1.run(Main2.java:26)
	at java.lang.Thread.run(Thread.java:748)
child threadTwo over!
child threadOne over!
all child thread over!

countDownLatch.countDown()被放到了finally块内,即使threadOne抛出了异常,由于finally块内代码总是会执行,所以运行上面的代码,main线程会正常退出。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值