5、ThreadLocal源码

ThreadLocal大家都知道是java利用空间换取时间的一个典型的类,他到底解决了什么?在什么场景下使用?需要注意什么?现在带着这些问题进一步了解吧。

一、ThreadLocal解决什么问题

解决线程安全的两种思路,第一种:线程同步(锁)的方式,让竞争的线程串行执行。 第二种:线程隔离,让参与竞争的线程独享一份副本资源。从其原理上来看ThreadLocal比线程同步的方式拥有更大的并发性,但是这种以空间换时间的方式,也是以牺牲内存资源为前提的,所以在大量的使用ThreadLocal的场景,一定要做内存释放操作防止内存泄漏。

1、弱引用(WeakReference)

WeakReference的意思是弱引用, 当一个对象仅仅被weak reference(弱引用)指向, 而没有任何其他strong reference(强引用)指向的时候, 如果这时GC运行, 那么这个对象就会被回收,不论当前的内存空间是否足够,这个对象都会被回收。 详细参考git源码 WeakReferenceTest.java

2、弱引用与软引用对比

弱引用的对象拥有更短暂的生命周期,被垃圾回收器回收的时机不一样,在垃圾回收器线程扫描它所管辖的内存区域的过程中,一旦发现了只具有弱引用的对象,不管当前内存空间足够与否,都会回收它的内存。而被软引用关联的对象只有在内存不足时才会被回收。弱引用不会影响GC,而软引用会一定程度上对GC造成影响。 (强引用>软引用>弱引用>虚引用)

3、为什么ThreadLocal使用弱引用

其实无论ThreadLocalMap中的key使用哪种类型引用都无法完全避免内存泄漏,跟使用弱引用没有关系。

​要避免内存泄漏有两种方式:

1、使用完ThreadLocal,调用其remove方法删除对应的Entry。

2、使用完ThreadLocal,当前Thread也随之运行结束。

 相对第一种方式,第二种方式显然更不好控制,特别是使用线程池的时候,线程结束是不会销毁的。也就是说,只要记得在使用完ThreadLocal及时的调用remove,无论key是强引用还是弱引用都不会有问题。那么为什么key要用弱引用呢?事实上,在ThreadLocalMap中的set/getEntry方法中,会对key为null(也即是ThreadLocal为null)进行判断,如果为null的话,那么是会对value置为null的。

这就意味着使用完ThreadLocal,CurrentThread依然运行的前提下,就算忘记调用remove方法,弱引用比强引用可以多一层保障:弱引用的ThreadLocal会被回收,对应的value在下一次ThreadLocalMap调用set,get,remove中的任一方法的时候会被清除,从而避免内存泄漏

4、InheritableThreadLocal

InheritableThreadLocal的意识是可继承的ThreadLocal,解决跨线程无法传参问题,例如子线程无法获取父线程变量,实现原理在Thread.init方法复制父线程值给子线程。 详细参考git源码 InheritableThreadLocalTest.java

5、TransmittableThreadLocal

TransmittableThreadLocal是阿里巴巴开源的专门解决InheritableThreadLocal的局限性,实现线程本地变量在线程池的执行过程中,能正常的访问父线程设置的线程变量。需要下载阿里jar 详细参考git源码 [备注:jdk1.8下 InheritableThreadLocal也可以跨池访问和预期不符合]

二、线程安全

如果一个类可以安全地被多个线程使用,它就是线程安全的。你无法对此论述提出任何争议,但也无法从中得到更多有意义的帮助。那么我们如何辨别线程安全与非线程安全的类?我们甚至又该如何理解“安全”呢?任何一个合理的“线程安全性”定义,其关键在于“正确性”的概念。在<<JAVA并发编程实践>>书中作者是这样定义的:一个类是是线程安全的,是指在被多个线程访问时,类可以持续进行正确的行为。或当多个线程访问一个类时,如果不用考虑这些线程在运行时环境下的调度和交替执行,并且不需要额外的同步及在调用方代码不必作其他的协调,这个类的行为仍然是正确的,那么称这个类是线程安全的。

三、TheadLocalMap数据结构

四、ThreadLocal类重点方法列举

属性名称

描述

get

获取当前线程在ThreadLocal中的缓存

set

将值保存当前线程中

withInitial

创建线程安全变量 比如:private ThreadLocal<Integer> balance = ThreadLocal.withInitial(() -> 1000);

五、ThreadLocal源码解读

public class ThreadLocal<T> {
    /**
     * 获取threadLocal hashcode 
     */
    private final int threadLocalHashCode = nextHashCode();

    /**
     * nextHashCode为原子变量.
     */
    private static AtomicInteger nextHashCode =
        new AtomicInteger();

    /**
     * 1640531527.
     * 计算规则:详细见散列计算规则文档 我没有看懂TODO
     */
    private static final int HASH_INCREMENT = 0x61c88647;

    /**
     * 返回 hashcode.
     */
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

    /**
     * 初始化null
     */
    protected T initialValue() {
        return null;
    }

    /**
     * 创建线程局部安全变量 demo参考:ThreadLocalWithInitialTest.java
     */
    public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
        return new SuppliedThreadLocal<>(supplier);
    }

    /**
     * 空构造方法
     */
    public ThreadLocal() {
    }

    /**
     * 获取当前线程值
     */
    public T get() {
        Thread t = Thread.currentThread();
        //从当前线程拿到ThreadLocalMap变量
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            //从map对象拿到Entry实体
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        //返回初始化值
        return setInitialValue();
    }

    /**
     * ThreadLocal初始化
     */
    private T setInitialValue() {
        //value初始化
        T value = initialValue();
        //获取当前线程
        Thread t = Thread.currentThread();
        //获取当前线程ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            //初始化ThreadLocalMap
            createMap(t, value);
        return value;
    }

    /**
     * ThreadLocal保存值
     */
    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

    /**
     * 删除当前线程ThreadLocalMap
     */
     public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
     }

    /**
     * 获取当前线程ThreadLocalMap
     */
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

    /**
     * 当线程无ThreadLocalMap对象初始化ThreadLocalMap对象
     */
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

    /**
     * 初始化ThreadLocalMap对象
     */
    static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
        return new ThreadLocalMap(parentMap);
    }

    /**
     * TODO 为什么抛出异常
     */
    T childValue(T parentValue) {
        throw new UnsupportedOperationException();
    }

    /**
     * 构建SuppliedThreadLocal实体
     */
    static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

        private final Supplier<? extends T> supplier;

        SuppliedThreadLocal(Supplier<? extends T> supplier) {
            this.supplier = Objects.requireNonNull(supplier);
        }

        @Override
        protected T initialValue() {
            return supplier.get();
        }
    }

    /**
     * ThreadLocalMap结构定义
     */
    static class ThreadLocalMap {

        /**
         * Entry继承WeakReference 当GC会回收ThreadLocal对象 如果有人在使用怎么办?TODO
         */
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        /**
         * 初始化扩容容量 大小2的N次方
         */
        private static final int INITIAL_CAPACITY = 16;

        /**
         * Entry实体 大小2的N次方
         */
        private Entry[] table;

        /**
         * Entry实体对象数量.
         */
        private int size = 0;

        /**
         * 负载系数
         */
        private int threshold; // Default to 0

        /**
         *  threshold 最坏的情况下 2/3*len 负载系数
         */
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }

        /**
         * 获取下条记录脚码
         */
        private static int nextIndex(int i, int len) {
            //i+1 小于 len
            return ((i + 1 < len) ? i + 1 : 0);
        }

        /**
         * 获取上条记录脚码.
         */
        private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
        }

        /**
         * 保存值入ThreadLocalMap
         */
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            //初始化Entry table容量大小
            table = new Entry[INITIAL_CAPACITY];
            //数组保存位置 与计算
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            //设置容量
            setThreshold(INITIAL_CAPACITY);
        }

        /**
         * ThreadLocalMap对象保存
         */
        private ThreadLocalMap(ThreadLocalMap parentMap) {
            //ThreadLocalMap获取Entry对象
            Entry[] parentTable = parentMap.table;
            // 获取Entry长度
            int len = parentTable.length;
            //设置Threshold长度
            setThreshold(len);
            //构建Entry对象
            table = new Entry[len];

            //循环遍历parentMap
            for (int j = 0; j < len; j++) {
                Entry e = parentTable[j];
                if (e != null) {
                    //获取ThreadLocal对象
                    ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                    if (key != null) {
                        /**
                         * TODO why 此处会抛出异常 
                         */ 
                        Object value = key.childValue(e.value);
                        Entry c = new Entry(key, value);
                        // 定位到key的坐标位置
                        int h = key.threadLocalHashCode & (len - 1);
                        while (table[h] != null)
                            // 获取key的 next坐标
                            h = nextIndex(h, len);
                        table[h] = c;
                        size++;
                    }
                }
            }
        }

        /**
         * 在table中查找Entry对象
         */
        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                //如果没有从table中命中目标,从table遍历获取目标
                return getEntryAfterMiss(key, i, e);
        }

        /**
         * 如果没有从table中命中目标,从table遍历获取目标
         */
        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                //遍历到目标后返回
                if (k == key)
                    return e;
                //遍历中如果key为null 清理i槽位值
                if (k == null)
                    expungeStaleEntry(i);
                else
                    //获取下条记录的脚码
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }

        /**
         * 保存ThreadLocal对象key及value
         */
        private void set(ThreadLocal<?> key, Object value) {
            Entry[] tab = table;
            int len = tab.length;
            //获取key在Entry对象集合脚码位置
            int i = key.threadLocalHashCode & (len-1);

            //遍历tab
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
                //如果插槽k等于key,将结果插入插槽。预分配内存
                if (k == key) {
                    e.value = value;
                    return;
                }
                //如果在插槽找不到k脚码 替换掉过期的key
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            //初始化Entry保存到tab中
            tab[i] = new Entry(key, value);
            //Entry实体对象数量自增.
            int sz = ++size;
            //清理无效槽位 对table缩表或者扩容
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

        /**
         * 删除key对应实体.
         */
        private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            //查找key卡槽id
            int i = key.threadLocalHashCode & (len-1);
            // 遍历tab Entry集合
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {    
                if (e.get() == key) {
                    // 清除Entry集合
                    e.clear();
                    //清理i槽位值
                    expungeStaleEntry(i);
                    return;
                }
            }
        }

        /**
         * 替换掉过期的值将新值覆盖过期值
         * @staleSlot 目标key的插槽脚码
         * @len tab长度
         */
        private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;
            
            //初始化即将删除的slotToExpunge插槽脚码
            int slotToExpunge = staleSlot;
            //i:staleSlot下个节点,prevIndex(i, len)是下下个节点
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            //此处代码同上 为什么不一次性遍历 TODO
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();

                //如果table key等于传入的key
                if (k == key) {
                    //对table value完成覆盖
                    e.value = value;
                    //目标脚码对象覆盖table已存入对象
                    tab[i] = tab[staleSlot];
                    //目标脚码对象赋值给e
                    tab[staleSlot] = e;

                    // 即将删除的插槽 等于 目标插槽
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                    //清理过时及key为null的对象    
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                // key等于null && 即将删除的插槽等于目标插槽
                if (k == null && slotToExpunge == staleSlot)
                    //即将删除的插槽等于当前循环脚码
                    slotToExpunge = i;
            }

            // 覆盖目标插槽value为null
            tab[staleSlot].value = null;
            //初始化目标插槽
            tab[staleSlot] = new Entry(key, value);

            // 如果即将删除插槽不等于目标插槽
            if (slotToExpunge != staleSlot)
                //清理过时及key为null的对象 
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }

        /**
         * 这个函数是ThreadLocal中核心清理函数,它做的事情很简单:
         * 就是从staleSlot开始遍历,将无效(弱引用指向对象被回收)清理,
         * 即对应entry中的value置为null,将指向这个entry的table[i]置为null,直到扫到空entry。
         * 另外,在过程中还会对非空的entry作rehash。
         * 可以说这个函数的作用就是从staleSlot开始清理连续段中的slot(断开强引用,rehash slot等)
         */
        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            // 删除staleSlot脚码元素
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            //Entry集合数减减
            size--;

            Entry e;
            int i;
            // 从staleSlot下一位开始到len结束遍历整个table, 
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                //k等于null 清理Entry并缩短其长度
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    //获取k的槽位编码
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;
                        // 递归遍历tab
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        //此处设计精妙 双指针遍历    
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

        /**
         * 扫描从i到n的table
         * 清理key为null对象及过时的对象.
         */
        private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                //获取i next脚码
                i = nextIndex(i, len);
                Entry e = tab[i];
                if (e != null && e.get() == null) {
                    //删除e.get()为null的线程,降低tab体积
                    n = len;
                    removed = true;
                    //删除过时的对象
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }

        /**
         * table回收或者扩容
         */
        private void rehash() {
            expungeStaleEntries();

            // 当容量达到总容量的3/4,进行扩容
            if (size >= threshold - threshold / 4)
                resize();
        }

        /**
         * 扩容实现
         */
        private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            //扩容在老容量基础上扩大2倍
            int newLen = oldLen * 2;
            //初始化Entry容量
            Entry[] newTab = new Entry[newLen];
            int count = 0;

            for (int j = 0; j < oldLen; ++j) {
                //遍历老Entry对象
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    //获取Entry k元素
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        //根据key hashcode定位到元素在新数组的位置
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        //对新集合元素赋值    
                        newTab[h] = e;
                        count++;
                    }
                }
            }
            //设置负载系数
            setThreshold(newLen);
            size = count;
            table = newTab;
        }

        /**
         * 清理Entry key过期或者为null的key
         */
        private void expungeStaleEntries() {
            Entry[] tab = table;
            int len = tab.length;
            for (int j = 0; j < len; j++) {
                Entry e = tab[j];
                if (e != null && e.get() == null)
                    expungeStaleEntry(j);
            }
        }
    }
}

参考文档:

   涉及其他知识文档:

  1. 文档涉及源码
  2. ThreadLocal之全链路跟踪

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值