CountDownLatch背后的原理

前言:

在日常工作中常用到多线程,如果使用多线程处理那么就要考虑同步问题,一般我们会考虑使用加锁来解决。但是还有一些场景,如下:
场景:小升初考试,考生做题,监考老师要等待所有考生交试卷后才可以离开,那么把考生比作多个线程,老师比作主线程。

使用join来实现上述场景:

public void testThread() throws InterruptedException {
    ConcurrentSkipListSet concurrentSkipListSet = new ConcurrentSkipListSet();

    for (int i=0; i< 10; i++) {
        final int t = i;
        Thread  thread = new Thread(new Runnable() {
            @Override
            public void run() {
                String name = Thread.currentThread().getName();
                System.out.println(name);
                ThreadUtils.doSleep(1000L * t);
                concurrentSkipListSet.add("cz" + t + "--" + name);
            }

        });
        thread.start();
        thread.join();
    }
    System.out.println(concurrentSkipListSet.size());

    for (Object s : concurrentSkipListSet) {
        System.out.println(s.toString());
    }

}

耗时情况:

image.png
join的意思让当前线程陷入等待,主线程启动了thread1后陷入等待,然后等待thread1执行完毕继续启动thread2,然后等待thread2执行完毕继续启动下一个,其实使用了join之后多线程已经成为了串行执行了。

使用CountDownLatch来实现上述场景

public static void testCountDownLatch() throws InterruptedException {
    TimeLag lag = new TimeLag();
    CountDownLatch countDownLatch = new CountDownLatch(10);
    ConcurrentSkipListSet concurrentSkipListSet = new ConcurrentSkipListSet();

    for (int i=0; i< 10; i++) {
        final int t = i;
        Thread  thread = new Thread(new Runnable() {
            @Override
            public void run() {
                String name = Thread.currentThread().getName();
                System.out.println(name);
                ThreadUtils.doSleep(1000L * t);
                concurrentSkipListSet.add("cz" + t + "--" + name);
            }

        });
        thread.start();
    }
    countDownLatch.await();

    System.out.println(concurrentSkipListSet.size());

    for (Object s : concurrentSkipListSet) {
        System.out.println(s.toString());
    }

    System.out.println(lag.cost());
}

执行效果如下:

image.png

这次很明显是多线程并行执行的。效率上远远高于使用join,而且也达到了想要的效果。

下面来看下CountDownLatch是什么原理。

image.png

通过api文档可以了解基础用法,要洞悉原理还需要仔细阅读源码和相关文档。
源码中有一个重要的概念就是AQS(AbstractQueuedLongSynchronizer)AQS是java中一个同步器的实现方案,java中除去synchronized关键字之外,其他锁的实现基本上都是基于AQS。

AQS实现原理可以简单的理解为是:
程序实现了一链表,该链表对所有线程可见(使用volatile修饰),线程去竞争锁,竞争到了修改锁状态,竞争失败的要依次排队,最终形成链表。加锁时往链表尾部添加节点,解锁时将头节点删除。这样就形成了一个加锁解锁机制。AQS还提供了共享锁和独占锁的实现

CountDownLatch就是基于AQS来实现的。内部维护一个Sync内部类来继承AQS实现了共享锁的加锁解锁。

AQS

AQS中调用了一个重要的类Unsafe,该类中的方法是一些native修饰的方法,大家都知道用native修饰的方法是调用底层c++实现的,可以和底层硬件交互,Unsafe提供了一个系类重要的方法就是compareAndSwapXXX(XXX包括Object,Int,Long),这个方法可以保证多个线程修改同一个值而不出现并发问题。AQS通过调用Unsafe中的compareAndSwapxxx方法保证了锁的状态不被篡改。

下面模仿AQS原理实现一个自己的锁:

package com.cz.lock;

import pers.cz.utils.LogUtils;
import sun.misc.Unsafe;

import java.lang.reflect.Field;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
import java.util.concurrent.locks.LockSupport;
import java.util.stream.Collectors;

/**
 * @program: Reids
 * @description:
 * @author: Cheng Zhi
 * @create: 2023-02-22 11:06
 **/
public class JefSimpleLock {

    private static JefLogLevel logLevel = JefLogLevel.ERROR;

    private static boolean useUnsafe = true;

    /**
     * 内部类通过反射获取Unsafe
     */
    public static class JefUnsafe {
        public static Unsafe getJefUnsage() {
            try {
                Field field = Unsafe.class.getDeclaredField("theUnsafe");
                field.setAccessible(true);
                return (Unsafe) field.get(null);
            } catch (Exception e) {
                LogUtils.error(e.getMessage(), e);
            }
            return null;
        }
    }
    /**
     * 使用Unsafe来实现CAS,因为普通类Unsafe使用时必须使用主类加载器加载,否则会抛出异常java.lang.SecurityException: Unsafe
     *  所以private static final Unsafe unsafe = Unsafe.getUnsafe();这种写法是不对的,要通过反射获取
     */
    private static final Unsafe unsafe = JefUnsafe.getJefUnsage();

    /**
     * 维护一个队列用来存储需要等待的线程
     */
    private static Queue<Thread> threadQueue = new ConcurrentLinkedQueue<Thread>();

    /**
     * 记录当前持有锁并且在运行的线程
     */
    private static Thread currentRunThread;

    /**
     * 定义一个变量做为锁的标识, state: 0表示当前锁未被持有,1表示当前锁已被持有
     */
    private static int state = 0;

    private static final long stateOffset;

    static {
        try {
            stateOffset = unsafe.objectFieldOffset // 获取非静态属性Field在对象实例中的偏移量,读写对象的非静态属性时会用到这个偏移量(对象中的地址)
                    (AbstractQueuedSynchronizer.class.getDeclaredField("state"));
        } catch (NoSuchFieldException e) {
            throw new Error(e);
        }
    }
    /**
     * 尝试获取锁,如果队列中存在线程,说明现在有线程在等待锁,
     * @return
     */
    private boolean tryGetLock() {

        int lockState = state;
        if (threadQueue.size() > 0) {
            if (logLevel.equals(JefLogLevel.DEBUG)) {
                LogUtils.debug("当前等待线程:" + threadQueue.stream().map(p -> p.getName()).collect(Collectors.toList()));
            }
            // 如果当前线程和队列中的队头线程
            Thread thread = Thread.currentThread();
            if (lockState == 0 && thread == threadQueue.peek()) {
                // 如果当前锁不被人持有,并且当前线程是队头,那么这个线程可以运行,
                if (useUnsafe) {
                    compareAndSwap(0, 1) ;
                } else {
                    state = 1;
                }
                // 如果获取到锁,则记录该线程
                currentRunThread = thread;
                if (logLevel.equals(JefLogLevel.DEBUG)) {
                    LogUtils.debug("获取锁成功:" + currentRunThread.getName());
                }
                return true;
            }
        }

        return false;
    }

    /**
     * 加锁,这里加锁的原理就是为了让其他线程等待,只有一个线程运行
     */
    public void lock() {

        // 如果获取到锁,则直接运行。
        if (tryGetLock()) {
            return;
        }
        Thread thread = Thread.currentThread();
        if (logLevel.equals(JefLogLevel.DEBUG)) {
            LogUtils.debug(thread.getName() + "当前线程状态为:" + thread.getState());
        }
        // 如果没有获取到锁,则把这个线程放到队列里去排队里去,然后循环等待
        threadQueue.add(thread);

        while(true) {
            if (tryGetLock()) {
                threadQueue.poll();
                return;
            }
            // 如果获取不到锁则阻塞,等待循环去获取锁。
            LockSupport.park(thread);
        }
    }
    /**
     * 解锁
     */
    public void unlock() {
        Thread currentThread = Thread.currentThread();
        if (currentThread == currentRunThread) {
            // 如果当前线程为持有锁的线程,则释放锁,修改状态
            if (useUnsafe) {
                compareAndSwap(1, 0) ;
            } else {
                state = 0;
            }
            currentRunThread = null;

            Thread peek = threadQueue.peek();
            if (peek != null) {
                // 唤醒队头线程让它去争抢锁
                LockSupport.unpark(peek);
            }
        }
    }

    /**
     * 使用CAS来做状态变更,保证state值不被篡改
     * @param current
     * @param update
     */
    private void compareAndSwap(int current, int update) {
        // 四个参数:
        unsafe.compareAndSwapInt(this, stateOffset, current, update);
    }

    /**
     * 设置日志级别
     * @param logLevel
     */
    public void setLogLevel(JefLogLevel logLevel) {
        this.logLevel = logLevel;
    }

    /**
     * 设置是否使用CAS机制
     * @param isUseUnsafe
     */
    public void isUseUnsafe(boolean isUseUnsafe) {
        useUnsafe = isUseUnsafe;
    }
}

enum JefLogLevel {
    DEBUG,
    INFO,
    WARE,
    ERROR;
}

实现思路:
锁对象内部维护一个队列并且维护一个锁标识。
加锁:如果获取锁成功则修改标识为1,如果获取失败则将线程添加到队列中,然后循环去获取锁(这里用到了一个LockSupport类,该类提供了阻塞线程的功能)。
解锁:解锁时,如果判断当前持有锁的线程和要解锁的线程是同一个,则将锁标识修改为0,然后将队列的第一个线程唤醒。

开启掘金成长之旅!这是我参与「掘金日新计划 · 2 月更文挑战」的第19天,点击查看活动详情

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值