ThreadLocal源码解读

ThreadLocal

多线程访问同一个共享变量时特别容易出现并发问题,在多个线程需要对一个共享变量进行写入时。为了保证线程安全,在访问共享变量时需要进行适当的同步。

说一下什么是线程安全

如果你的代码在某个进程中有多个线程同时执行,这些线程可能会同时运行这段代码,如果每次运行结果和单线程运行的结果一致,并且得到的结果和预期也是一致的,那么就是线程安全。
而我们为了线程安全,同步的措施一般是加锁来保证线程安全,来控制不同线程对临界区的访问。但使用了锁,性能肯定会有所下降,加重了使用者的负担。

那ThreadLocal可以用来避免线程之间的竞争,使用它创建一个变量后,每个线程对其进行访问的是自己线程的变量

什么是ThreadLocal

ThreadLocal从名字来看,就是线程本地的意思,但可以从官方注释上看,大致意思是说ThreadLocal可以给我们提供一个线程内的局部变量,而且这个变量与一般的变量还不同,这个变量每个线程独有的,与其他线程互不干扰。

如何使用ThreadLocal

以下代码创建了2个线程,使用ThreadLocal去存取值,看看两个线程间会不会相互影响

对于A线程拿到的值肯定是"我是A线程",对于B线程拿到的值肯定是"我是B线程"。他们是怎么拿到呢? 他们不是共用一个threadLocal吗? 那肯定是get方法的问题,但是再看get方法之前先看一下ThreadLocal大致的结构。

public class ThreadLocalDemo {
    // 创建一个ThreadLocal对象,这里泛型指定为String
    private static ThreadLocal<String> threadLocal = new ThreadLocal<>();

    public static void main(String[] args) throws InterruptedException {
        Thread t1 = new Thread(()->{
            // 线程A设置的变量,那么此线程在调用get的方法时,就是将设置的变量拿出来。跟操作hashmap差不多
            threadLocal.set("我是A线程"); 
            System.out.println(Thread.currentThread().getName()+" -> " +threadLocal.get());
            removeLocalVariiable();
            System.out.println(Thread.currentThread().getName()+" -> " +threadLocal.get());
        },"A");


        Thread t2 = new Thread(()->{
            // 线程B设置的变量,
            threadLocal.set("我是B线程");
            System.out.println(Thread.currentThread().getName()+" -> " +threadLocal.get());
            removeLocalVariiable();
            System.out.println(Thread.currentThread().getName()+" -> " +threadLocal.get()); // 通过get方法拿到我们设置的值
        },"B");


        t1.start();
        TimeUnit.SECONDS.sleep(1);
        t2.start();
    }
    // 删除本地变量的方法,每个线程调用完后,如果不用之前设置的值后,记得要清理。
    public static void removeLocalVariiable(){
        threadLocal.remove();
    }
}

结构

Thread

Thread类中有两个属性,只有在线程调用ThreadLocal的set方法和get方法时候才会使用到它们,这个在源码中就会看到的。

ThreadLocal.ThreadLocalMap threadLocals = null;

ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
ThreadLocal

ThreadLocal的静态内部类ThreadLocalMap为每个Thread维护了一个数组table,ThreadLocal确定了一个数组下标,而这个下标就是value存储的对应位置。

ThreadLocalMap类似于map结构,当我们创建ThreadLocal时,我们通过set方法设置的本地变量,没有存放在ThreadLocal中,而是存放在了ThreadLocalMap中,当我们的线程在使用get方法时,再从当前线程的ThreadLocalMap里面将其拿出来使用。

说类似于map结构,那么其实就是维护了一张哈希表,也就是一个数组,代码中有一个Entry类型table表,这个表里面存储的就是我们的Entry对象。

那我们来看看Entry对象

// WeakReference为弱引用
static class ThreadLocalMap {
    //  类似于hashmap中的Entry
    static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
    }

WeakReference:弱引用,当垃圾回收器看到此引用没有被其他对象引用,那么就自动回收了,避免内存泄漏。但是只是key被回收了,value并没有被回收,value依然是强引用

以上就是ThreadLocal大致的结构,那么下面看看ThreadLocal的实现原理。

实现原理

ThreadLocal是如何做到各自线程只能看到自己本地的变量值,需要去看看代码中的set、get等一些方法,才能进一步了解。先看看get方法

get方法

代码中当前线程会返回一个ThreadLocalMap,这个ThreadLocalMap结构上文已经说明。代码中也做了详细的注释。

 public T get() {
        Thread t = Thread.currentThread(); // 获取当前线程
        ThreadLocalMap map = getMap(t); // 拿到当前线程对应的ThreadLocalMap
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this); // 获取我们要查找的entry对象
            if (e != null) {
                @SuppressWarnings("unchecked")
                // 上文如何使用ThreadLocal,如何拿到各自线程中的局部变量,就是在这里体现出来的
                T result = (T)e.value;  // 获取entry对象里面的value值,然后return返回
                return result;
            }
        }
        return setInitialValue();  //  map没有初始化的话,需要进行初始化
    }
    
    
// 此方法上文结构中,已经说了,是Thread类的属性,返回的是线程自己的threadLocals,其实也就是线程自己对应的ThreadLocalMap
ThreadLocalMap getMap(Thread t) {    return t.threadLocals;}

   
private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key) // 找打了对应的线程对应的key,直接返回该Entry。
                return e;
            else  // e=null   或者e.get() != key      如果总是访问存在的key,这个方法会永远进不来。这个else分支可以先忽略,下文中会有介绍。
                return getEntryAfterMiss(key, i, e);
        }
        
// 初始化的方法        
private T setInitialValue() {
        T value = initialValue(); // 这个方法返回的null,也就说当我们没有使用set方法设置变量,而是上来直接使用get方法,那么返回的一定是null,是这里起的作用
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t); 
        if (map != null)
            map.set(this, value);
        else
           // 如果我们没有使用set方法,那么在我们使用get方法时,也需要去初始化map,不过map中存储的是当前threadLocal作为key,value自然就是null了,前提是我们没有使用set方法
            createMap(t, value); // 初始化方法,下文set方法源码有所介绍。
        return value;
    }
set方法

对于上文中各个线程拿到各自的值,前提是我们使用了set方法,看看值是如何存储到ThreadLocalMap中的。

从下面代码可以看出每一个线程持有了一个ThreadLocalMap对象,且每一个新的线程都会去创建一个新的ThreadLocalMap,之后ThreadLocalMap存在了,直接使用。

public void set(T value) {
        Thread t = Thread.currentThread(); // 获取当前线程
        ThreadLocalMap map = getMap(t);// 获取当前线程的map
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value); // 线程如果第一次调用set方法就初始化当前线程对应的threadlocalmap。 Thread中threadLocals属性默认为null,赋值给threadlocals
    }
    
    
    
    
private void set(ThreadLocal<?> key, Object value) {

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);  // 下标

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;  // key相同,覆盖旧值
                    return;
                }

                if (k == null) {
                    replaceStaleEntry(key, value, i); // 清除脏entry
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size
            // cleanSomeSlots方法返回false,说明没有槽位可以清理了,进一步判断size是否超过了阈值,如果超过了需要进行扩容
            if (!cleanSomeSlots(i, sz) && sz >= threshold)  // 扩容
                rehash();
        }
            
初始化方法createMap()
// 初始化theadlocals
// createMap方法
void createMap(Thread t, T firstValue) {
		// this为当前线程对象, firstValue为value值  传入构造方法
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }


ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY]; // 初始化一个长度为16的map
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1); // 计算下标值
            table[i] = new Entry(firstKey, firstValue);  // 将entry放入数组当中
            size = 1;
            setThreshold(INITIAL_CAPACITY); // 阈值  默认为16,也就说超过10就会进行扩容
        }

private void setThreshold(int len) {
            threshold = len * 2 / 3; //当哈希表的容量超过了总容量的2/3的时候就需要对哈希表进行扩容了
        }

上文中提到了弱引用,当不存在外部引用的时候,就会自动被回收,但是Entry中的value是强引用。只有当Thread被回收时,这个value才有被回收的机会,如果线程不退出,value总是会存在的。如果对于线程池来说,大部分线程会一直存在,这样,就会造成内存泄漏的问题,而以下几个方法专门用来清理这些脏Entry(key为null,但value没有回收),在TheadLocal中,使用set、get、remove方法都会进行清理。而清理的方法就是下面这几个方法。

我们来举个例子:
比如,ThreadLocalMap哈希表长度为8,其实默认初始化长度为16,这里举例为8。当前数组元素为{7,22,13,67,12,45}。此时key=22,key=13,key=45已经过期了。其他空白的地方为空,可以直接存放数据。

这时候来了一个新的数据,也就是我们使用set方法,比如key=67,value=“新的值”,通过计算下标,应该存放在下标为3的这个地方。此时会进入set方法中的这个if语句中。

if (k == null) {
                    replaceStaleEntry(key, value, i); // 清除脏entry
                    return;
 }
replaceStaleEntry方法
int slotToExpunge = staleSlot; 

            // 如果staleSlot为0,那么从最后一个长度往前遍历,如果staleSlot不为0,那么从staleSlot往前遍历
            // 前面或许有的key已经被回收,但是value以及entry还没有被释放,需要释放空间
            // 为了避免存在很多过期的对象占用,导致来了一个新的元素达到了阈值而触发一次新的rehash
            for (int i = prevIndex(staleSlot, len);  
                 (e = tab[i]) != null;
                 i = prevIndex(i, len)) 
                if (e.get() == null) key==null
                    slotToExpunge = i; // 记录数组左边第一个空的entry到staleSlot之间key过期最小的index下标

上面代码第一个for循环是向前遍历数据,遍历到Entry为空的时候就停止遍历。通过上文中的例子,当遍历到下标为0的地方就停止了。向前遍历的过程同时会找出过期的key,也就是下标为2这个key,记录下来。

 if (e.get() == null) key==null
                    slotToExpunge = i; // 记录数组左边第一个空的entry到staleSlot之间key过期最小的index下标

此时slotToExpunge为2,staleSlot=3

第二个for循环是向后遍历数据,找出是否有当前匹配的key,如果有重新设置值,并清理过期的对象,上文中的例子遍历到下标为4的位置,匹配到了当前的key。进入到这个代码中。

将旧值进行覆盖,并进行数据交换,此时slotToExpunge为2,staleSlot为3,i=4,这里会把3和4的位置的元素进行交换。交换后的样子如下:

为什么要交换呢?

如果不交换,直接清理下标为3的这个位置,下标为3的这个位置为空之后,可以直接放入数据,样子如下:

这个时候,我们把我们要设置的值直接放入里面

 if (k == key) { 找到了key覆盖旧值
                    e.value = value;

                    // 和之前过期的对象的进行交换位置
                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                 
                    return;
                }

这样,整个数组就存在两个key=67的数据了,所以一定要交换数据。

(看下方代码)slotToExpunge == staleSlot 说明for循环往前查找的时候没有找到过期的,由于前面过期的对象已经通过交换位置的方式放到了i上了,所以需要清理的位置是i,而不是传过来的staleSlot。

                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                    // 清理过期数据
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);

i=4,直接清理这个位置的数据,因为key为null,value还存在引用,会造成内存泄漏问题。

下方的代码的意思是:往前的for循环,没有找到过期的,整个数组也没有找到key,那么会直接设置到staleSlot这个位置上。

可以看出,不管数组是否找到了key,最终都会将key交换到staleSlot的位置上,不管如何,staleSlot位置上存放的都是有效的值,不需要进行清理。

                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i; // 这里的i 是向后遍历for循环拿到的第一个过期对象的位置
            }
            // 如果key在数组中没有存在,那么直接在当前位置创建这个entry对象
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);
            // 如果有其他已经过期的对象,那么需要清理
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }
        
cleanSomeSlots方法

进行一定次数的循环,从当前位置i开始往后循环的去寻找脏Entry,也就是key=null的脏entry,然后进行删除。

private boolean cleanSomeSlots(int i, int n) {
           boolean removed = false;
           Entry[] tab = table;
           int len = tab.length;
           do {
               i = nextIndex(i, len);
               Entry e = tab[i];
               if (e != null && e.get() == null) {
                   n = len;
                   removed = true;
                   i = expungeStaleEntry(i); // 这个清理过程只是覆盖了一段范围,并不是全部区间。
               }
               // n >>>= 1 说明要循环log2N次。在没有发现脏Entry时,会一直往后找下个位置的entry是否是脏的,如果是的话,就会使 n = 数组的长度。然后继续循环log2新N 次。
           } while ( (n >>>= 1) != 0);
           return removed;
}
expungeStaleEntry方法

这个方法是帮助垃圾回收的,在set、get、remove方法都会见到这个方法,这个方法是专门用来回收value的方法,用来检查key是否被回收,如果被回收了,进一步才回收它的value。正常情况下并不会出现内存溢出,但是如果我们没有调用get和set的时候就会面临着内存溢出。所以当我们不在使用变量时记得调用remove方法,避免内存溢出。

比如在经过replaceStaleEntry方法后,进入expungeStaleEntry方法时,map结构如下:

staleSlot为4

   // expunge entry at staleSlot
            tab[staleSlot].value = null;  // 清理当前位置
            tab[staleSlot] = null;
            size--;

经过上方代码后,会把下标为4的位置置为null,结构如下:

下面的代码当遍历到下标为5的位置,经过hashcode计算下标,得到下标为4的话,h!=i,说明之前冲突过,那么将下标为5的entry放入下标为4的位置上,也就是这样的结构:

然后继续遍历,发现key为null的元素,直接清除,结构如下:

之后就退出循环了。

 // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) { // 置null,防止内存泄漏。gc回收
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                
                    int h = k.threadLocalHashCode & (len - 1); // 下标
                    // 这里如果不相等,说明之前冲突过了
                    if (h != i) {
                        tab[i] = null;  
                        while (tab[h] != null)
                            h = nextIndex(h, len); 
                        tab[h] = e;  // 这里体现出开放地址法
                    }
                }
            }
            return i;
清理value的方法

此方法在get方法中有所体现

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i); // 清除value
                else
                    i = nextIndex(i, len); // 遇到冲突 使用开放地址法,处理冲突
                e = tab[i];
            }
            return null;
        }
resize扩容方法
private void rehash() {
            expungeStaleEntries(); // 清理过期的Entry,也就是脏Entry

            // Use lower threshold for doubling to avoid hysteresis
            // 10-10/4=8,在清理过期Entry后如果长度大于等于8,则进行扩容
            if (size >= threshold - threshold / 4)
                resize();
        }
        
private void expungeStaleEntries() {
            Entry[] tab = table;
            int len = tab.length;
            for (int j = 0; j < len; j++) {
                Entry e = tab[j];
                if (e != null && e.get() == null)
                    expungeStaleEntry(j);  // 清除脏Entry
            }
        }

private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;  // 扩容两倍
            Entry[] newTab = new Entry[newLen];
            int count = 0;

            for (int j = 0; j < oldLen; ++j) { // 旧数组的元素往新数组上进行转移
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }

            setThreshold(newLen); // 设置新的阈值
            size = count;
            table = newTab;
        }
remove方法

当我们不使用变量,记得调用一下remove方法,将key为null,value不为null的Entry回收掉,这样可以避免内存溢出。 就跟我们平常使用锁时,上了锁也要记得解锁。

public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());  // 当前线程的map
         if (m != null)
             m.remove(this);
     }

private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear(); // 将弱引用置为null,回收
                    expungeStaleEntry(i); // 清理value
                    return;
                }
            }
        }    
hash冲突处理

TheadLocalMap中使用的是线性探测法,如果发生了元素冲突,那么就使用下一个槽位存放,只要散列表足够大,空的槽位总会找到的。
例如:比如当前下标为3的位置上已经有对象了,那么如果遇到了冲突,就使用下一个槽位存放,也就是4这个位置。

为什么使用线程探测法呢?
  • ThreadLocal中有一个属性 private static final int HASH_INCREMENT = 0x61c88647; 这个HASH_INCREMENT是一个神奇的魔数,可以让哈希码能够均匀的分布在2的N次方的数组里。这个魔数百度有很多博客讲解,可以自行查询。
  • ThreadLocal数据量不会很大,调用各种方法随时都会清除key为null的脏Entry,会节省空间,数组的查询效率也非常高,并且因为第一点冲突的概率可以说是很低很低。
ThreadLocal不支持继承性

先看例子:

public class ThreadLocalDemo01 {
    private static ThreadLocal<String> threadLocal = new ThreadLocal<String>();

    public static void main(String[] args) throws InterruptedException {
        threadLocal.set("我是main线程");


        Thread sonThread = new Thread(()->{
            System.out.println(Thread.currentThread().getName()+"->"+threadLocal.get());
        },"子线程");

        sonThread.start();  // 启动子线程

        TimeUnit.SECONDS.sleep(1);

        System.out.println(Thread.currentThread().getName()+"->"+threadLocal.get());
    }
}

运行结果:
子线程->null
main->我是main线程

同一个ThreadLocal变量在父线程中被设置值后,在子线程中是获取不到的。因为在子线程thread里面调用get方法时是当前线程,而调用set方法的线程变量是main线程,两者是不同的线程。自然子线程是返回null的。如果想让子线程能够访问到父线程中的值可以使用InheritableThreadLocal这个类。

InheritableThreadLocal

Thead类中的一个属性:

ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
InheritableThreadLocal代码

继承于ThreadLocal类,重写了ThreadLocal类中的三个方法

public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    
    protected T childValue(T parentValue) {
        return parentValue;
    }

  
    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }


    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}
测试代码

这样就可以拿到了

public class InheritableThreadLocal01 {
    private static InheritableThreadLocal<String> threadLocal = new InheritableThreadLocal<>();

    public static void main(String[] args) {
        threadLocal.set("main线程");

        Thread sonThread = new Thread(()->{ // 子线程
            System.out.println(Thread.currentThread().getName()+"->"+threadLocal.get());
        },"子线程");
        sonThread.start();

        System.out.println(Thread.currentThread().getName()+"->"+threadLocal.get());
    }
}

Thread中的init方法中代码如下:

具体看一下代码:

private ThreadLocalMap(ThreadLocalMap parentMap) {
            Entry[] parentTable = parentMap.table; // 父线程的table
            int len = parentTable.length;
            setThreshold(len);
            table = new Entry[len];

            for (int j = 0; j < len; j++) { // 父线程的值复制到子线程中
                Entry e = parentTable[j];
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                    if (key != null) {
                        Object value = key.childValue(e.value);
                        Entry c = new Entry(key, value);
                        int h = key.threadLocalHashCode & (len - 1);
                        while (table[h] != null)
                            h = nextIndex(h, len);
                        table[h] = c;
                        size++;
                    }
                }
            }
        }
总结
  • ThreadLocal用来提供线程局部变量的,可以在线程内随时随地存取数据,线程之间互不干扰。
  • ThreadLOcal实际上在每个线程内部维护了一个TheadLocalMap,ThreadLocalMap每个线程独有,里面存储的是Entry对象,Entry对象实际上是一个ThreadLocal的实例的弱引用,同时还保存了强引用的value,存储的键值对的形式的值。key就是ThreadLocal实例本身,value是要存储的数据。
  • 关于内存泄漏问题,在使用set、get、remove方法时都会清除ThreadLocal中key为null的Entry。如果不清除,会造成内存泄漏的问题。用完记得要使用remove方法清理一下。
  • 关于继承,每个线程都可以访问到父线程传递过来的一个数据,但是变量的传递发生在线程创建的时候,如果不是新建线程,而是复用了线程池里的线程,就不行了。
    public static void main(String[] args) 
      // 不管使用哪种方式,都需要在每个线程执行完成时,应该调用remove方法清理ThreadLocal
        private static ThreadLocal local = new ThreadLocal();
        InheritableThreadLocal local = new InheritableThreadLocal();

        ExecutorService fixed = Executors.newFixedThreadPool(4);
        for(int i = 0; i< 5; i++){
            int num = i;
            local.set("父线程"+num);
            fixed.execute(()->{
                System.out.println(Thread.currentThread().getId()+" "+Thread.currentThread().getName()+"拿到的线程变量为:"+local.get());
                //local.remove(); // 调用remove方法
            });
        }
        fixed.shutdown();
    }

没有调用remove方法

调用remove方法

参考资料:《Java并发编程之美》

上述存在问题,还请指出,谢谢。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值