来自一个学长大佬的分享,在此谢谢学长。
JavaSE Thread JUC
前言
以前虽然也接触过ThreadLocal,但是每当细问它的实现原理时,总是有一些地方说不清,归根到底还是对ThreadLocal的理解不够深入,只是停留在表面,今天就来总结一下ThreadLocal。
接下来的源码分析会用到,所以先放在前面温故一下。
简述
java.lang.ThreadLocal可用来存放线程的局部变量,每个线程都有单独的局部变量,彼此之间不会共享。它并不是一个Thread,而是一个线程的局部变量,也许把他命名为ThreadLocalVariable更合适
方法
public T get() //返回当前线程的局部变量
public void set(T value)
protected T initialValue() //返回当前线程的局部变量的初始值
public void remove()
这里只解释一下initialValue()方法,其他几个方法比较容易理解。 initialValue()方法为protected类型,它是为了被子类覆盖而特意提供的,该方法返回当前线程的局部变量的初始值。这个方法是一个延迟调用方法,当线程第一次调用ThreadLocal对象的get()或者set()方法时才执行,并且仅执行一次。在ThreadLocal类本身的实现中,initialValue()方法直接返回一个null,所以一般使用ThreadLocal,最好重写initialValue()方法,否则先调用set()方法会报空指针异常。
protected T initialValue() {
return null;
}
原理探究
我们看到ThreadLocal的类图(注:ThreadLocalMap是ThreadLocal的内部类,Entry是ThreadLocalMap的内部类。)
Thread有一个类型为ThreadLocalMap的变量threadLocals,用于存储该Thread拥有的用户变量,变量具体怎么存储交给了ThreadLocalMap,我们再看ThreadLocalMap,它有一个类型为Entry的数组变量table,用于存储用户变量,这里是数组,自然是可以存储多个用户变量了。一个Entry带代表了一个变量 一个用户变量又怎么存储的呢?我们还要再看看Entry,看他的构造函数就能一目了然。
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
其中,key类型为ThreadLocal,值就是用户设置的值。
set方法源码
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
getMap的源码
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
//Thread类中的属性
/* ThreadLocal values pertaining to this thread. This map is maintained by the ThreadLocal class. */
ThreadLocal.ThreadLocalMap threadLocals = null;
由于Thread类中ThreadLocal.ThreadLocalMap threadLocals没有修饰符,所以为default,即同包下可以使用。故在getMap中直接返回的t.threadLocals,因为ThreadLocal和Thread均在java.lang包下。
set方法的逻辑:当前Thread的threadLocals变量是否为null?,假设不为null,进入if,看到ThreadLocalMap#set方法.
ThreadLocalMap源码分析
存储结构
ThreadLocalMap中存储的是ThreadLocalMap.Entry(为了书写简单,后面直接写成Entry对象)对象。因此,在ThreadLocalMap中管理的也就是Entry对象。也就是说,ThreadLocalMap里面的大部分函数都是针对Entry的。
首先ThreadLocalMap需要一个“容器”来存储这些Entry对象,ThreadLocalMap中定义了Entry数组实例table,用于存储Entry。
private Entry[] table;
也就是说,ThreadLocalMap维护一张哈希表(一个数组),表里面存储Entry。既然是哈希表,那肯定就会涉及到加载因子,即当表里面存储的对象达到容量的多少百分比的时候需要扩容。ThreadLocalMap中定义了threshold属性,当表里存储的对象数量超过threshold就会扩容。
/**
* The next size value at which to resize.
*/
private int threshold; // Default to 0
/**
* Set the resize threshold to maintain at worst a 2/3 load factor.
*/
private void setThreshold(int len) {
threshold = len * 2 / 3;
}
从上面代码看出,加载因子设置为2/3。即每次容量超过设定的len的2/3时,需要扩容。
存储Entry对象
首先看看数据是如何被放入到哈希表里面:
ThreadLocalMap#set()
private void set(ThreadLocal<?> key, Object value) {
// We don't use a fast path as with get() because it is at
// least as common to use set() to create new entries as
// it is to replace existing ones, in which case, a fast
// path would fail more often than not.
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
if (k == key) {
e.value = value;
return;
}
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
从上面代码中看出,通过key(ThreadLocal类型)的hashcode来计算存储的索引位置i。如果i位置已经存储了对象,那么就往后挪一个位置依次类推,直到找到空的位置,再将对象存放。另外,在最后还需要判断一下当前的存储的对象个数是否已经超出了阈值(threshold的值)大小,如果超出了,需要重新扩充并将所有的对象重新计算位置(rehash函数来实现)。那么我们看看rehash函数如何实现的:
private void rehash() {
expungeStaleEntries();//青村掉废弃的实体
// Use lower threshold for doubling to avoid hysteresis
if (size >= threshold - threshold / 4)
resize();
}
/**
* Expunge all stale entries in the table.
*/
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
//遍历tab数组,如果entry的key为null,则清除
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}
//上面的get方法
public T get() {
return this.referent;
}
//具体的清除方法
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
// expunge entry at staleSlot
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;
// Rehash until we encounter null
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;
// Unlike Knuth 6.4 Algorithm R, we must scan until
// null because multiple entries could have been stale.
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
看到,rehash函数里面先调用了expungeStaleEntries函数,然后再判断当前存储对象的大小是否超出了阈值的3/4。如果超出了,再扩容。看的有点混乱。为什么不直接扩容并重新摆放对象?为啥要搞成这么复杂?
其实,ThreadLocalMap里面存储的Entry对象本质上是一个WeakReference。也就是说,ThreadLocalMap里面存储的对象本质是一个对ThreadLocal对象的弱引用,该ThreadLocal随时可能会被回收!即导致ThreadLocalMap里面对应的Value的Key是null。我们需要把这样的Entry给清除掉,不要让它们占坑。
弱引用:用来描述非必需的对象。但是它的强度比软引用更弱一些。被弱引用关联的对象只能生存到下一次垃圾回收之前。当垃圾收集器工作时,无论当前内存是否足够,都会回收掉只被弱引用关联的对象。
expungeStaleEntries函数就是做这样的清理工作,清理完后,实际存储的对象数量自然会减少,这也不难理解后面的判断的约束条件为阈值的3/4,而不是阈值的大小。
那么如何判断哪些Entry是需要清理的呢?其实很简单,只需把ThreadLocalMap里面的key值遍历一遍,为null的直接删了即可。可是,前面我们说过,ThreadLocalMap并没有实现java.util.Map接口,即无法得到keySet。其实,不难发现,如果Key值为null,此时调用ThreadLocalMap的get(ThreadLocal)相当于get(null),get(null)返回的是null,这也就很好的解决了判断问题。也就是说,无需判断,直接根据get函数的返回值是不是null来判定需不需要将该Entry删除掉。注意,get返回null也有可能是key的值不为null,但是对于get返回为null的Entry,也没有占坑的必要,同样需要删掉,这么一来,就一举两得了。
获取Entry对象getEntry
Entry#getEntry()
/**
* Get the entry associated with key. This method
* itself handles only the fast path: a direct hit of existing
* key. It otherwise relays to getEntryAfterMiss. This is
* designed to maximize performance for direct hits, in part
* by making this method readily inlinable.
*
* @param key the thread local object
* @return the entry associated with key, or null if no such
*/
private Entry getEntry(ThreadLocal key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}
getEntry函数很简单,直接通过哈希码计算位置i,然后把哈希表中对应i位置的Entry对象拿出来。如果对应位置的值为null,这就存在如下几种可能:
key对应的值确实为null
由于位置冲突,key对应的值存储的位置并不在i位置上,即i位置上的null并不属于key的值。
因此,需要一个函数再次去确认key对应的value的值,即getEntryAfterMiss函数:
故从中可以看出,ThreadLocalMap解决哈希冲突的方式并不是向HashMap一样采用链地址发,而是采用开放地址法。
/**
* Version of getEntry method for use when key is not found in
* its direct hash slot.
*
* @param key the thread local object
* @param i the table index for key's hash code
* @param e the entry at table[i]
* @return the entry associated with key, or null if no such
*/
private Entry getEntryAfterMiss(ThreadLocal key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
while (e != null) {
ThreadLocal k = e.get();
if (k == key)
return e;
if (k == null)
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
}
ThreadLocalMap.Entry对象
前面很多地方都在收ThreadLocalMap里面存储的是ThreadLocalMap.Entry对象,那么ThreadLocalMap.Entry对象到底是如何存储键值对的?同时又是如何做到对ThreadLocal对象进行弱引用?
先看看Entry类的源码:
static class Entry extends WeakReference<ThreadLocal> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal k, Object v) {
super(k);
value = v;
}
}
从源码的继承关系可以看到,Entry 是继承WeakReference。即Entry 本质上就是WeakReference,换言之,Entry就是一个弱引用,具体讲,Entry实例就是对ThreadLocal某个实例的弱引用。只不过,Entry同时还保存了value。
ThreadLocalRandom
即使对象是线程安全的,使用ThreadLocal也可以减少竞争,比如对于Random类来说,Random是线程安全的,但是如果并发访问竞争激烈的化,性能会下降,所以Java并发包提供了类TheadLocalRandom,它是Random的子类,利用了ThreadLocal,它没有public的构造方法,通过静态方法current获取对象。(jdk1.7)
public static void main(String[] args) {
ThreadLocalRandom rnd = ThreadLocalRandom.current();
System.out.println(rnd.nextInt());
}
current方法的实现为:
public static ThreadLocalRandom current() {
return localRandom.get();
}
localRandom就是一个ThreadLocal变量:
private static final ThreadLocal<ThreadLocalRandom> localRandom =
new ThreadLocal<ThreadLocalRandom>() {
protected ThreadLocalRandom initialValue() {
return new ThreadLocalRandom();
}
};
线程池与ThreadLocal
线程池中的线程是会重用的,如果异步任务使用了ThreadLocal,会出现什么情况呢?
看一个简单的示例:
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
public class ThreadPoolProblem {
static ThreadLocal<AtomicInteger> sequencer = new ThreadLocal<AtomicInteger>(){
@Override
protected AtomicInteger initialValue() {
return new AtomicInteger(0);
}
};
static class Task implements Runnable{
@Override
public void run() {
AtomicInteger s = sequencer.get();
int initial = s.getAndIncrement();
//期望初始值为0
System.out.println(initial);
}
}
public static void main(String[] args) {
ExecutorService executor = Executors.newFixedThreadPool(2);
executor.execute(new Task());
executor.execute(new Task());
executor.execute(new Task());
executor.shutdown();
}
}
对于异步任务Task而言,它期望的初始值应该为0,但运行程序,结果却为:
第三步执行异步任务,结果就不对了,为什么呢?因为线程池中的线程在执行完一个任务时,其中的ThreadLocal对象并不会清空,修改后的值带到了下一个异步的任务中。
解决思路:
第一次使用ThreadLocal对象时,总是先调用set设置初始值,或者TheadLocal重写了initialValue方法,先调用其remove方法
使用完ThreaLlocal对象后,总是调用其remove方法。
使用自定义线程池。
第一种:
static class Task implements Runnable{
@Override
public void run() {
sequencer.set(new AtomicInteger(0));
//或者sequencer.remove();
AtomicInteger s = sequencer.get();
int initial = s.getAndIncrement();
//期望初始值为0
System.out.println(initial);
}
}
第二种:
static class Task implements Runnable{
@Override
public void run() {
try{
sequencer.set(new AtomicInteger(0));
//或者sequencer.remove();
AtomicInteger s = sequencer.get();
int initial = s.getAndIncrement();
//期望初始值为0
System.out.println(initial);
}finally {
sequencer.remove();
}
}
}
以上两种方法都比较麻烦,需要更改所有异步任务的代码,另一种方法是扩展线程池ThreadPoolExecutor,它有一个可扩展的方法:
protected void beforeExecute(Thread t,Runnable r){}
在线程池将任务r交个线程之前,会在线程t中先执行beforeExecute(),可以在这个方法中重新初始化ThreadLocal变量,可以显式初始化,如果不知道,也可以通过反射,重置所有的ThreadLocal。
小结
ThreadLocal使得每线程对同一个变量有自己的拷贝,是实现线程安全、减少竞争的一种方案。
ThreadLocal经常用于存储上下文信息,避免在不同代码间来回传递,简化代码。
每个线程都有一个Map,调用ThreadLocal对象的get/set实际就是以ThreadLocal对象键读当前线程的Map。
在线程池中使用ThreadLocal,需要注意,确保初始值是符合期望的。