深入理解ThreadLocal

来自一个学长大佬的分享,在此谢谢学长。
JavaSE Thread JUC

前言
以前虽然也接触过ThreadLocal,但是每当细问它的实现原理时,总是有一些地方说不清,归根到底还是对ThreadLocal的理解不够深入,只是停留在表面,今天就来总结一下ThreadLocal。

接下来的源码分析会用到,所以先放在前面温故一下。

简述
java.lang.ThreadLocal可用来存放线程的局部变量,每个线程都有单独的局部变量,彼此之间不会共享。它并不是一个Thread,而是一个线程的局部变量,也许把他命名为ThreadLocalVariable更合适

方法
public T get()  //返回当前线程的局部变量
public void set(T value)
protected T initialValue() //返回当前线程的局部变量的初始值
public void remove()
这里只解释一下initialValue()方法,其他几个方法比较容易理解。 initialValue()方法为protected类型,它是为了被子类覆盖而特意提供的,该方法返回当前线程的局部变量的初始值。这个方法是一个延迟调用方法,当线程第一次调用ThreadLocal对象的get()或者set()方法时才执行,并且仅执行一次。在ThreadLocal类本身的实现中,initialValue()方法直接返回一个null,所以一般使用ThreadLocal,最好重写initialValue()方法,否则先调用set()方法会报空指针异常。

protected T initialValue() {
    return null;
}
原理探究
 

我们看到ThreadLocal的类图(注:ThreadLocalMap是ThreadLocal的内部类,Entry是ThreadLocalMap的内部类。)
Thread有一个类型为ThreadLocalMap的变量threadLocals,用于存储该Thread拥有的用户变量,变量具体怎么存储交给了ThreadLocalMap,我们再看ThreadLocalMap,它有一个类型为Entry的数组变量table,用于存储用户变量,这里是数组,自然是可以存储多个用户变量了。一个Entry带代表了一个变量 一个用户变量又怎么存储的呢?我们还要再看看Entry,看他的构造函数就能一目了然。

static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
    Object value;
     Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}
其中,key类型为ThreadLocal,值就是用户设置的值。

set方法源码
public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}
getMap的源码

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}
//Thread类中的属性
/* ThreadLocal values pertaining to this thread. This map is maintained by the ThreadLocal class. */
ThreadLocal.ThreadLocalMap threadLocals = null;
由于Thread类中ThreadLocal.ThreadLocalMap threadLocals没有修饰符,所以为default,即同包下可以使用。故在getMap中直接返回的t.threadLocals,因为ThreadLocal和Thread均在java.lang包下。

set方法的逻辑:当前Thread的threadLocals变量是否为null?,假设不为null,进入if,看到ThreadLocalMap#set方法.

ThreadLocalMap源码分析
存储结构
ThreadLocalMap中存储的是ThreadLocalMap.Entry(为了书写简单,后面直接写成Entry对象)对象。因此,在ThreadLocalMap中管理的也就是Entry对象。也就是说,ThreadLocalMap里面的大部分函数都是针对Entry的。

首先ThreadLocalMap需要一个“容器”来存储这些Entry对象,ThreadLocalMap中定义了Entry数组实例table,用于存储Entry。

 private Entry[] table;
也就是说,ThreadLocalMap维护一张哈希表(一个数组),表里面存储Entry。既然是哈希表,那肯定就会涉及到加载因子,即当表里面存储的对象达到容量的多少百分比的时候需要扩容。ThreadLocalMap中定义了threshold属性,当表里存储的对象数量超过threshold就会扩容。

/**
 * The next size value at which to resize.
 */
private int threshold; // Default to 0
/**
 * Set the resize threshold to maintain at worst a 2/3 load factor.
 */
private void setThreshold(int len) {
    threshold = len * 2 / 3;
}
从上面代码看出,加载因子设置为2/3。即每次容量超过设定的len的2/3时,需要扩容。

存储Entry对象
首先看看数据是如何被放入到哈希表里面:

ThreadLocalMap#set()
    private void set(ThreadLocal<?> key, Object value) {
        // We don't use a fast path as with get() because it is at
        // least as common to use set() to create new entries as
        // it is to replace existing ones, in which case, a fast
        // path would fail more often than not.
        Entry[] tab = table;
        int len = tab.length;
        int i = key.threadLocalHashCode & (len-1);
        for (Entry e = tab[i];
            e != null;
            e = tab[i = nextIndex(i, len)]) {
            ThreadLocal<?> k = e.get();
            if (k == key) {
                e.value = value;
                return;
            }
            if (k == null) {
                replaceStaleEntry(key, value, i);
                return;
            }
        }
        tab[i] = new Entry(key, value);
        int sz = ++size;
        if (!cleanSomeSlots(i, sz) && sz >= threshold)
            rehash();
    }
从上面代码中看出,通过key(ThreadLocal类型)的hashcode来计算存储的索引位置i。如果i位置已经存储了对象,那么就往后挪一个位置依次类推,直到找到空的位置,再将对象存放。另外,在最后还需要判断一下当前的存储的对象个数是否已经超出了阈值(threshold的值)大小,如果超出了,需要重新扩充并将所有的对象重新计算位置(rehash函数来实现)。那么我们看看rehash函数如何实现的:

private void rehash() {
    expungeStaleEntries();//青村掉废弃的实体
    // Use lower threshold for doubling to avoid hysteresis
    if (size >= threshold - threshold / 4)
        resize();
}
  /**
    * Expunge all stale entries in the table.
    */
    private void expungeStaleEntries() {
        Entry[] tab = table;
        int len = tab.length;
        for (int j = 0; j < len; j++) {
            Entry e = tab[j];
            //遍历tab数组,如果entry的key为null,则清除
            if (e != null && e.get() == null)
                expungeStaleEntry(j);
        }
    }
    //上面的get方法
    public T get() {
        return this.referent;
    }
    //具体的清除方法
    private int expungeStaleEntry(int staleSlot) {
        Entry[] tab = table;
        int len = tab.length;
        // expunge entry at staleSlot
        tab[staleSlot].value = null;
        tab[staleSlot] = null;
        size--;
        // Rehash until we encounter null
        Entry e;
        int i;
        for (i = nextIndex(staleSlot, len);
            (e = tab[i]) != null;
            i = nextIndex(i, len)) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null;
                tab[i] = null;
                size--;
            } else {
                int h = k.threadLocalHashCode & (len - 1);
                if (h != i) {
                    tab[i] = null;
                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                while (tab[h] != null)
                    h = nextIndex(h, len);
                    tab[h] = e;
                }
            }
        }
        return i;
    }
看到,rehash函数里面先调用了expungeStaleEntries函数,然后再判断当前存储对象的大小是否超出了阈值的3/4。如果超出了,再扩容。看的有点混乱。为什么不直接扩容并重新摆放对象?为啥要搞成这么复杂?

其实,ThreadLocalMap里面存储的Entry对象本质上是一个WeakReference。也就是说,ThreadLocalMap里面存储的对象本质是一个对ThreadLocal对象的弱引用,该ThreadLocal随时可能会被回收!即导致ThreadLocalMap里面对应的Value的Key是null。我们需要把这样的Entry给清除掉,不要让它们占坑。

弱引用:用来描述非必需的对象。但是它的强度比软引用更弱一些。被弱引用关联的对象只能生存到下一次垃圾回收之前。当垃圾收集器工作时,无论当前内存是否足够,都会回收掉只被弱引用关联的对象。
expungeStaleEntries函数就是做这样的清理工作,清理完后,实际存储的对象数量自然会减少,这也不难理解后面的判断的约束条件为阈值的3/4,而不是阈值的大小。

那么如何判断哪些Entry是需要清理的呢?其实很简单,只需把ThreadLocalMap里面的key值遍历一遍,为null的直接删了即可。可是,前面我们说过,ThreadLocalMap并没有实现java.util.Map接口,即无法得到keySet。其实,不难发现,如果Key值为null,此时调用ThreadLocalMap的get(ThreadLocal)相当于get(null),get(null)返回的是null,这也就很好的解决了判断问题。也就是说,无需判断,直接根据get函数的返回值是不是null来判定需不需要将该Entry删除掉。注意,get返回null也有可能是key的值不为null,但是对于get返回为null的Entry,也没有占坑的必要,同样需要删掉,这么一来,就一举两得了。

获取Entry对象getEntry
Entry#getEntry()
/**
 * Get the entry associated with key.  This method
 * itself handles only the fast path: a direct hit of existing
 * key. It otherwise relays to getEntryAfterMiss.  This is
 * designed to maximize performance for direct hits, in part
 * by making this method readily inlinable.
 *
 * @param  key the thread local object
 * @return the entry associated with key, or null if no such
 */
private Entry getEntry(ThreadLocal key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}
getEntry函数很简单,直接通过哈希码计算位置i,然后把哈希表中对应i位置的Entry对象拿出来。如果对应位置的值为null,这就存在如下几种可能:

key对应的值确实为null
由于位置冲突,key对应的值存储的位置并不在i位置上,即i位置上的null并不属于key的值。
因此,需要一个函数再次去确认key对应的value的值,即getEntryAfterMiss函数:

故从中可以看出,ThreadLocalMap解决哈希冲突的方式并不是向HashMap一样采用链地址发,而是采用开放地址法。

/**
 * Version of getEntry method for use when key is not found in
 * its direct hash slot.
 *
 * @param  key the thread local object
 * @param  i the table index for key's hash code
 * @param  e the entry at table[i]
 * @return the entry associated with key, or null if no such
 */
private Entry getEntryAfterMiss(ThreadLocal key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;
    while (e != null) {
        ThreadLocal k = e.get();
        if (k == key)
            return e;
        if (k == null)
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}
ThreadLocalMap.Entry对象
前面很多地方都在收ThreadLocalMap里面存储的是ThreadLocalMap.Entry对象,那么ThreadLocalMap.Entry对象到底是如何存储键值对的?同时又是如何做到对ThreadLocal对象进行弱引用?

先看看Entry类的源码:

static class Entry extends WeakReference<ThreadLocal> {
    /** The value associated with this ThreadLocal. */
    Object value;
    Entry(ThreadLocal k, Object v) {
        super(k);
        value = v;
    }
}
从源码的继承关系可以看到,Entry 是继承WeakReference。即Entry 本质上就是WeakReference,换言之,Entry就是一个弱引用,具体讲,Entry实例就是对ThreadLocal某个实例的弱引用。只不过,Entry同时还保存了value。

ThreadLocalRandom
即使对象是线程安全的,使用ThreadLocal也可以减少竞争,比如对于Random类来说,Random是线程安全的,但是如果并发访问竞争激烈的化,性能会下降,所以Java并发包提供了类TheadLocalRandom,它是Random的子类,利用了ThreadLocal,它没有public的构造方法,通过静态方法current获取对象。(jdk1.7)

public static void main(String[] args) {
    ThreadLocalRandom rnd = ThreadLocalRandom.current();
    System.out.println(rnd.nextInt());
}
current方法的实现为:
public static ThreadLocalRandom current() {
    return localRandom.get();
}
localRandom就是一个ThreadLocal变量:
private static final ThreadLocal<ThreadLocalRandom> localRandom =
    new ThreadLocal<ThreadLocalRandom>() {
        protected ThreadLocalRandom initialValue() {
            return new ThreadLocalRandom();
        }
};
线程池与ThreadLocal
线程池中的线程是会重用的,如果异步任务使用了ThreadLocal,会出现什么情况呢?
看一个简单的示例:

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

public class ThreadPoolProblem {
    static ThreadLocal<AtomicInteger> sequencer = new ThreadLocal<AtomicInteger>(){
        @Override
        protected AtomicInteger initialValue() {
            return new AtomicInteger(0);
        }
    };
    static class Task implements Runnable{
        @Override
        public void run() {
            AtomicInteger s = sequencer.get();
            int initial = s.getAndIncrement();
            //期望初始值为0
            System.out.println(initial);
        }
    }
    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(2);
        executor.execute(new Task());
        executor.execute(new Task());
        executor.execute(new Task());
        executor.shutdown();
    }
}
对于异步任务Task而言,它期望的初始值应该为0,但运行程序,结果却为:
 

第三步执行异步任务,结果就不对了,为什么呢?因为线程池中的线程在执行完一个任务时,其中的ThreadLocal对象并不会清空,修改后的值带到了下一个异步的任务中。

解决思路:

第一次使用ThreadLocal对象时,总是先调用set设置初始值,或者TheadLocal重写了initialValue方法,先调用其remove方法
使用完ThreaLlocal对象后,总是调用其remove方法。
使用自定义线程池。
第一种:

static class Task implements Runnable{
        @Override
        public void run() {
            sequencer.set(new AtomicInteger(0));
            //或者sequencer.remove();
            AtomicInteger s = sequencer.get();
            int initial = s.getAndIncrement();
            //期望初始值为0
            System.out.println(initial);
        }
    }
第二种:

 static class Task implements Runnable{
        @Override
        public void run() {
            try{
                sequencer.set(new AtomicInteger(0));
                //或者sequencer.remove();
                AtomicInteger s = sequencer.get();
                int initial = s.getAndIncrement();
                //期望初始值为0
                System.out.println(initial);
            }finally {
                sequencer.remove();
            }
        }
    }
以上两种方法都比较麻烦,需要更改所有异步任务的代码,另一种方法是扩展线程池ThreadPoolExecutor,它有一个可扩展的方法:
protected void beforeExecute(Thread t,Runnable r){}

在线程池将任务r交个线程之前,会在线程t中先执行beforeExecute(),可以在这个方法中重新初始化ThreadLocal变量,可以显式初始化,如果不知道,也可以通过反射,重置所有的ThreadLocal。

小结
ThreadLocal使得每线程对同一个变量有自己的拷贝,是实现线程安全、减少竞争的一种方案。
ThreadLocal经常用于存储上下文信息,避免在不同代码间来回传递,简化代码。
每个线程都有一个Map,调用ThreadLocal对象的get/set实际就是以ThreadLocal对象键读当前线程的Map。
在线程池中使用ThreadLocal,需要注意,确保初始值是符合期望的。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值