ThreadLocal 源码深入分析

ThreadLocal--见名知意,线程本地变量,每个线程都有一个独立的副本,我们来看看JDK文档里面是怎么描述这个类的:

This class provides thread-local variables. These variables differ from their normal counterparts in that each thread that accesses one (via its get or set method) has its own, independently initialized copy of the variable. ThreadLocal instances are typically private static fields in classes that wish to associate state with a thread (e.g., a user ID or Transaction ID).

意思就是说:该类提供了线程本地变量。这些变量与正常的对应变量的不同之处在于,每个线程访问该变量时(通过其get或set方法)都有自己独立初始化的变量副本。ThreadLocal通常是类的private static 字段,来保存和线程相关的状态(例如,用户ID或事务ID)。

再来看JDK给的一个例子,给每个线程生成一个独一无二的标识:

 public class ThreadId {
     // Atomic integer containing the next thread ID to be assigned
     private static final AtomicInteger nextId = new AtomicInteger(0);

     // Thread local variable containing each thread's ID
     private static final ThreadLocal<Integer> threadId =
         new ThreadLocal<Integer>() {
             @Override protected Integer initialValue() {
                 return nextId.getAndIncrement();
         }
     };

     // Returns the current thread's unique ID, assigning it if necessary
     public static int get() {
         return threadId.get();
     }
 }

下面就进入源码分析的地步了,首先我们来看看成员变量和构造函数:

public class ThreadLocal<T> {
    /**
     * ThreadLocals rely on per-thread linear-probe hash maps attached
     * to each thread (Thread.threadLocals and
     * inheritableThreadLocals).  The ThreadLocal objects act as keys,
     * searched via threadLocalHashCode.  This is a custom hash code
     * (useful only within ThreadLocalMaps) that eliminates collisions
     * in the common case where consecutively constructed ThreadLocals
     * are used by the same threads, while remaining well-behaved in
     * less common cases.
     */
    private final int threadLocalHashCode = nextHashCode();
    /**
     * The next hash code to be given out. Updated atomically. Starts at
     * zero.
     */
    private static AtomicInteger nextHashCode =
        new AtomicInteger();

    /**
     * The difference between successively generated hash codes - turns
     * implicit sequential thread-local IDs into near-optimally spread
     * multiplicative hash values for power-of-two-sized tables.
     */
    private static final int HASH_INCREMENT = 0x61c88647;

    /**
     * Returns the next hash code.
     */
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

    /**
     * Creates a thread local variable.
     * @see #withInitial(java.util.function.Supplier)
     */
    public ThreadLocal() {
    }
}

可以看到,ThreadLocal类里面只有一个实例字段:threadLocalHashCode,从字面意思我们也可以看出来和HashCode有关,这个字段与后面的ThreadLocalMap有关(后面会详细分析),ThreadLocalMap是一个Map(ThreadLocalMap没有实现Map接口),且这个Map是使用线性探测来解决hash冲突的。这个Map里面的key就是ThreadLocal,value是我们要存储的值,既然是一个Map,那么当然需要一个hash值来确定其最终的位置了,该字段就是用来确定每一个ThreadLocal类的hash值的(一个线程可以有多个ThreadLocal类变量,就是通过这些变量的threadLocalHashCode来找到其对应的值)。另外两个静态变量AtomicInteger ,HASH_INCREMENT和一个静态函数nextHashCode是用来确定每一个ThreadLocal的hash值的。ThreadLocal只有一个最简单的无参构造函数。

下面我们再来看看ThreadLocal暴露了那些方法:

    public T get() {
        Thread t = Thread.currentThread();  // 获取当前线程
        ThreadLocalMap map = getMap(t);    // 得到和线程绑定的ThreadLocalMap
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);  // 根据key(ThreadLocal变量),找到对应的value
            if (e != null) {                       // 找到了
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();          // 没找到
    }

get方法,取出线程本地变量,首先获取当前线程,然后得到和线程绑定的ThreadLocalMap,根据key(this)去找value,找到了就直接返回,没找到则调用setInitialValue函数。

    private T setInitialValue() {
        T value = initialValue();           // 看是否设置了初始值
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);       //  得到和该线程绑定的map
        if (map != null)
            map.set(this, value);            // 设置值,这样下次就可以直接get到了
        else
            createMap(t, value);             // 如果map没有创建就创建map
        return value;                     // 返回默认值
    } 

    /**
     * Returns the current thread's "initial value" for this
     * thread-local variable.  This method will be invoked the first
     * time a thread accesses the variable with the {@link #get}
     * method, unless the thread previously invoked the {@link #set}
     * method, in which case the {@code initialValue} method will not
     * be invoked for the thread.  Normally, this method is invoked at
     * most once per thread, but it may be invoked again in case of
     * subsequent invocations of {@link #remove} followed by {@link #get}.
     *
     * <p>This implementation simply returns {@code null}; if the
     * programmer desires thread-local variables to have an initial
     * value other than {@code null}, {@code ThreadLocal} must be
     * subclassed, and this method overridden.  Typically, an
     * anonymous inner class will be used.
     *
     * @return the initial value for this thread-local
     */
    protected T initialValue() {
        return null;                // 默认为null
    }

setInitialValue就是看是否设置了初始值,如果设置了就返回初始值,并将该值set到ThreadLocalMap中,下次就可以直接get了。默认返回的是null,使用protected修饰,我们可以使用继承或者匿名内部类来设置初始化值。在JDK8中还提供了一个接口给我们更加方便的设置初始值:

    /**
     * Creates a thread local variable. The initial value of the variable is
     * determined by invoking the {@code get} method on the {@code Supplier}.
     *
     * @param <S> the type of the thread local's value
     * @param supplier the supplier to be used to determine the initial value
     * @return a new thread local variable
     * @throws NullPointerException if the specified supplier is null
     * @since 1.8
     */
    public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
        return new SuppliedThreadLocal<>(supplier);
    }

    /**
     * An extension of ThreadLocal that obtains its initial value from
     * the specified {@code Supplier}.
     */
    static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

        private final Supplier<? extends T> supplier;

        SuppliedThreadLocal(Supplier<? extends T> supplier) {
            this.supplier = Objects.requireNonNull(supplier);
        }

        @Override
        protected T initialValue() {
            return supplier.get();
        }
    }

初始值由传入的Supplier提供。比如我们可以这样:

ThreadLocal<String> threadLocal = ThreadLocal.withInitial(() -> "Hello Java");

这样就给ThreadLocal附上了初始值了,在没有set新值时,返回的就是“Hello Java”了。

下面我们再来看看前面提到的getMap和createMap长什么样呢?

    /**
     * Get the map associated with a ThreadLocal. Overridden in
     * InheritableThreadLocal.
     *
     * @param  t the current thread
     * @return the map
     */
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

    /**
     * Create the map associated with a ThreadLocal. Overridden in
     * InheritableThreadLocal.
     *
     * @param t the current thread
     * @param firstValue value for the initial entry of the map
     */
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

可以看到getMap返回的是线程里面的变量threadLocals,createMap也就是给Thread里的threadLocals设置,我们跟进去看看:

    /* ThreadLocal values pertaining to this thread. This map is maintained
     * by the ThreadLocal class. */
    ThreadLocal.ThreadLocalMap threadLocals = null;

上面是Thread中的源码。

到这里我们基本就已经知道了线程本地变量的实现原理了,首先要明确线程本地变量不是存储在ThreadLocal类里面,而是存储在Thread里面!!!,每一个Thread类都对应着操作系统的一个线程。ThreadLocal类只是提供了一层“皮”,用于统一访问这些线程本地变量。由于每次访问的都是和具体线程绑定的ThreadLocalMap,所以每个线程看到的值当然不一样了,因为他们在Thread类里面,是线程私有的。

再看看暴露的另外两个暴露的方法set和remove:

    /**
     * Sets the current thread's copy of this thread-local variable
     * to the specified value.  Most subclasses will have no need to
     * override this method, relying solely on the {@link #initialValue}
     * method to set the values of thread-locals.
     *
     * @param value the value to be stored in the current thread's copy of
     *        this thread-local.
     */
    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

    /**
     * Removes the current thread's value for this thread-local
     * variable.  If this thread-local variable is subsequently
     * {@linkplain #get read} by the current thread, its value will be
     * reinitialized by invoking its {@link #initialValue} method,
     * unless its value is {@linkplain #set set} by the current thread
     * in the interim.  This may result in multiple invocations of the
     * {@code initialValue} method in the current thread.
     *
     * @since 1.5
     */
     public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
     }

可以看到这两个方法都比较简单。

接下面我们再来分析一下,前面一直提到的ThreadLocalMap的实现原理:

static class ThreadLocalMap {
     static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        /**
         * The initial capacity -- MUST be a power of two.
         */
        private static final int INITIAL_CAPACITY = 16;

        /**
         * The table, resized as necessary.
         * table.length MUST always be a power of two.
         */
        private Entry[] table;

        /**
         * The number of entries in the table.
         */
        private int size = 0;

        /**
         * The next size value at which to resize.
         */
        private int threshold; // Default to 0
}

与HashMap里面的还是比较相似的,Entry,存储key, value。key为ThreadLocal类型, value就是我们要存储的值。Entry继承自WeakReference(弱引用)方便垃圾回收。hash表 table, 表内的元素数量:size, 默认初始容量16, 门限threshold,大于此值时进行rehash。

几个简单的函数:

        /**
         * Set the resize threshold to maintain at worst a 2/3 load factor.
         */
        private void setThreshold(int len) {
            threshold = len * 2 / 3;          // 门限为2/3 这与HashMap不同,HashMap内是3/4
        }

        /**
         * Increment i modulo len.
         */
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }

        /**
         * Decrement i modulo len.
         */
        private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
        }

构造函数:

        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }

由于哈希表的长度为2的幂次方,所有使用了与运算代替取模运算,这一点与HashMap中是一样的。

从Map中获取值:

        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }

直接根据key的threadLocalHashCode 作为hash值,计算哈希槽的位置,如果当前槽位就是我们要找的值的话,直接放回,或者执行getEntryAfterMiss。

        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)      // 找到了,返回
                    return e;
                if (k == null)      // 找到了一个entry的key为null,说明此key调用了remove函数
                    expungeStaleEntry(i);   // 清理这个无效的key
                else
                    i = nextIndex(i, len);   // 没有找到,继续去找下一个槽位
                e = tab[i];
            }
            return null;
        }

可以看到逻辑还是比较清晰的,就是从通过哈希值计算得到的槽位一直往后找,知道遇见了null,说明该key不在map中,则返回null。找到了就返回,如果在这个过程中遇到了key为null(entry不为null),则调用expungeStaleEntry清理这个无用的entry。从这里我们就可以发现了ThreadLocalMap中解决Hash冲突的方式是使用线性探测法,而不是HashMap中的拉链法!!

private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

expungeStaleEntry用于清理一个无效的entry(就是调用了remove方法的)然后接着从当前位置往后直到null,进行rehash或者清理无用entry。

我们举个例子:加入hash表内存储的元素为:null null  3  4  5  6  null null。其中3, 4, 5 ,6 对与运算后得到相同的值,由于解决hash冲突使用的线性探测法,所以他们会聚集在一起。如果此时4倍remove掉了,则hash表变成了:null null  3  null  5  6  null null。此时我们需要对4之后直到null的值进行rehash,否则之后我们就get不到他们了,rehash之后应该是:null null  3  5  6  null null null

        private void set(ThreadLocal<?> key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }

                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

逻辑也是比较清晰的,如果发现了相同的key就替换旧的值,发现无效的entry了就使用replaceStaleEntry来替换我们要设置的key,否则通过线性探测法找一个空的hash槽,将key,value插进去。set完后,还会调用cleanSomeSlots启发式的进行垃圾清理,如果size大于门限,会执行rehash。

        private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }

通过线性探测法找到对应的entry,调用其clear方法:就是将key置为null。然后调用expungeStaleEntry清理该无效的entry。

 private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

            // Back up to check for prior stale entry in current run.
            // We clean out whole runs at a time to avoid continual
            // incremental rehashing due to garbage collector freeing
            // up refs in bunches (i.e., whenever the collector runs).
            int slotToExpunge = staleSlot;
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            // Find either the key or trailing null slot of run, whichever
            // occurs first
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();

                // If we find key, then we need to swap it
                // with the stale entry to maintain hash table order.
                // The newly stale slot, or any other stale slot
                // encountered above it, can then be sent to expungeStaleEntry
                // to remove or rehash all of the other entries in run.
                if (k == key) {
                    e.value = value;

                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    // Start expunge at preceding stale entry if it exists
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                // If we didn't find stale entry on backward scan, the
                // first stale entry seen while scanning for key is the
                // first still present in the run.
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            // If key not found, put new entry in stale slot
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            // If there are any other stale entries in run, expunge them
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }

 具体实现不再深究,这替换过程里面也进行了不少的垃圾清理动作以防止引用关系存在而导致的内存泄露。

        private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }

启发式的进行垃圾清理,复杂度为O(lg(n))

        private void rehash() {
            expungeStaleEntries();

            // Use lower threshold for doubling to avoid hysteresis
            if (size >= threshold - threshold / 4)  // 可以看到实际的负载因子时 0.5(2/3 - 1/6 = 1/2)
                resize();
        }

        private void resize() {   // 扩容,数组长度为原来的两倍
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;

            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }

            setThreshold(newLen);
            size = count;
            table = newTab;
        }

        private void expungeStaleEntries() {  // 全表的垃圾清理
            Entry[] tab = table;
            int len = tab.length;
            for (int j = 0; j < len; j++) {
                Entry e = tab[j];
                if (e != null && e.get() == null)
                    expungeStaleEntry(j);
            }
        }

总结

每个 Thread 里都含有一个 ThreadLocalMap 的成员变量,这种机制将 ThreadLocal 和线程巧妙地绑定在了一起,即可以保证无用的 ThreadLocal 被及时回收,不会造成内存泄露,又可以提升性能。假如我们把 ThreadLocalMap 做成一个 Map<t extends Thread, ?> 类型的 Map,那么它存储的东西将会非常多(相当于一张全局线程本地变量表),这样的情况下用线性探测法解决哈希冲突的问题效率会非常差。而 JDK 里的这种利用 ThreadLocal 作为 key,再将 ThreadLocalMap 与线程相绑定的实现,完美地解决了这个问题。

总结一下什么时候无用的 Entry 会被清理:

  • Thread 结束的时候
  • 插入元素时,发现 staled entry,则会进行替换并清理
  • 插入元素时,ThreadLocalMap 的 size 达到 threshold,并且没有任何 staled entries 的时候,会调用 rehash 方法清理并扩容
  • 调用 ThreadLocalMap 的 remove 方法或set(null) 时

尽管不会造成内存泄露,但是可以看到无用的 Entry 只会在以上四种情况下才会被清理,这就可能导致一些 Entry 虽然无用但还占内存的情况。因此,我们在使用完 ThreadLocal 后一定要remove一下,保证及时回收掉无用的 Entry。

特别地,当应用线程池的时候,由于线程池的线程一般会复用,Thread 不结束,这时候用完更需要 remove 了。

总的来说,对于多线程资源共享的问题,同步机制采用了 以时间换空间 的方式,而 ThreadLocal 则采用了 以空间换时间 的方式。前者仅提供一份变量,让不同的线程排队访问;而后者为每一个线程都提供了一份变量,因此可以同时访问而互不影响。

 

参考:并发编程 | ThreadLocal 源码深入分析

 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值