CountDownLatch使用、源码解析及与Thread.join()对比

JDK从1.5版本开始提供CountDownLatch工具类,它能使一个县城等待其他线程各自完成工作后再执行,CountDownLatch内部是通过一个计数器实现的,计数器的初始值是批量任务初始线程的数量,每当一个线程完成任务后,计数器的值就会减1,当计数器的值为0时,唤醒所有被阻塞的线程。

一、CountDownLatch使用

首先,看下CountDownLatch的几个主要方法:

public class CountDownLatch {
...
	// 阻塞当前线程,直到count值为0或者线程被中断
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }
	
	// 阻塞当前线程,直到count值为0或者线程被中断或者超出等待时间
    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

	// count值减1,如果count为0则唤醒所有被阻塞的线程
    public void countDown() {
        sync.releaseShared(1);
    }

	// 返回当前的count值
    public long getCount() {
        return sync.getCount();
    }
...
}

CountDownLatch使用范例如下:

public class CountDownLatchTest {

    public static void main(String[] args) throws InterruptedException {
        int jobSize = 5;
        CountDownLatch startLatch = new CountDownLatch(1);
        CountDownLatch endLatch = new CountDownLatch(jobSize);
        ExecutorService exec = Executors.newCachedThreadPool();
        for(int i=0; i < jobSize; i++) {
            exec.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        startLatch.await();
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                    try {
                        Thread.sleep(1000);
                        System.out.println(Thread.currentThread().getName());
                    } catch (Exception e){
                        e.printStackTrace();
                    } finally {
                        endLatch.countDown();
                    }
                }
            });
        }
        long startTime = System.currentTimeMillis();
        startLatch.countDown();
        endLatch.await(2, TimeUnit.SECONDS);
        long endTime = System.currentTimeMillis();
        System.out.println("cost time : " + (endTime - startTime));
        exec.shutdown();
    }
}
---
pool-1-thread-5
pool-1-thread-4
pool-1-thread-3
pool-1-thread-1
pool-1-thread-2
cost time : 1004
二、CountDownLatch源码解析

CountDownLatch的成员变量、构造器和内部类实现如下,其结构非常简单,其唯一成员变量是一个同步器,继承自AbstractQueuedSynchronizer。

public class CountDownLatch {

    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

    private final Sync sync;

    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }
    
    public void countDown() {
        sync.releaseShared(1);
    }

    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }
	...    
}

CountDownLatch调用countDown()方法时会调用的releaseShared方法将count值减1;调用await()方法时会调用acquireSharedInterruptibly,其核心逻辑如下:当CountDownLatch同步器count不等于0时,会调用doAcquireSharedInterruptibly方法,将当前线程封装为Node并阻塞,直到同步器count为0或被中断。

	public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        if (tryAcquireShared(arg) < 0) // count!=0
            doAcquireSharedInterruptibly(arg);
    }
    
    private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return;
                    }
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }

调用await(long timeout, TimeUnit unit)方法时会调用tryAcquireSharedNanos方法,在CountDownLatch的count不等于0时调用doAcquireSharedNanos方法,该方法和doAcquireSharedInterruptibly的逻辑基本一致,只是多了截止时间,进入方法后首先计算等待截止时间的判断,当前时间超过截止时间时直接返回false,值得注意的是这一句代码: nanosTimeout > spinForTimeoutThreshold,其含义是当前时间距离截止时间小于spinForTimeoutThreshold时不阻塞线程,让线程在程序中自旋,自旋时间spinForTimeoutThreshold被默认是指为1000ns。

    public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        return tryAcquireShared(arg) >= 0 ||
            doAcquireSharedNanos(arg, nanosTimeout);
    }

    private boolean doAcquireSharedNanos(int arg, long nanosTimeout)
            throws InterruptedException {
        if (nanosTimeout <= 0L)
            return false;
        final long deadline = System.nanoTime() + nanosTimeout;
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return true;
                    }
                }
                nanosTimeout = deadline - System.nanoTime();
                if (nanosTimeout <= 0L)
                    return false;
                if (shouldParkAfterFailedAcquire(p, node) &&
                    nanosTimeout > spinForTimeoutThreshold)
                    LockSupport.parkNanos(this, nanosTimeout);
                if (Thread.interrupted())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }
    
    static final long spinForTimeoutThreshold = 1000L;

具体AbstractQueuedSynchronizer共享式获取或释放解析见: AbstractQueuedSynchronizer共享式获取或释放

三、CountDownLatch与Thread.join区别

在线程中调用Thread的join()方法,也可以实现阻塞当前线程的效果,Thread的join()方法调用后不断检查线程是否存活,如果存活则继续阻塞。join方法核心代码如下:

   public final synchronized void join(long millis)
    throws InterruptedException {
        long base = System.currentTimeMillis();
        long now = 0;

        if (millis < 0) {
            throw new IllegalArgumentException("timeout value is negative");
        }

        if (millis == 0) {
            while (isAlive()) {
                wait(0);
            }
        } else {
            while (isAlive()) {
                long delay = millis - now;
                if (delay <= 0) {
                    break;
                }
                wait(delay);
                now = System.currentTimeMillis() - base;
            }
        }
    }

用join()方法实现类似CountDownLatch测试范例的逻辑,示例如下:

public class ThreadJoinTest {

    public static void main(String[] args) {
        long startTime = System.currentTimeMillis();
        List<MyTask> taskList = new ArrayList<>();
        for(int i=0; i < 5; i++) {
            MyTask task = new MyTask();
            taskList.add(task);
            task.start();
        }
        for(MyTask task : taskList) {
            try {
                task.join();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
        long endTime = System.currentTimeMillis();
        System.out.println("cost time : " + (endTime - startTime));
    }

    public static class MyTask extends Thread {
        @Override
        public void run() {
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println(Thread.currentThread().getName());
        }
    }
}
---
Thread-0
Thread-2
Thread-1
Thread-3
Thread-4
cost time : 1005

可见,Thread的join()方法也可以实现类似逻辑,那CountDownLatch与Thread.join()方法的区别是什么?

区别:Thread.join()方法依赖于线程的存活情况,等所线程执行完毕时才能往下执行,而CountDownLatch提供计数器的功能,更加灵活,只需监测计数器count值为0就可继续往下执行,与线程执行情况可以解耦。

参考
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值