并发编程-ThreadLocal
本篇我们要了解的是在实际使用比较少的,但在很多应用框架和中间件的源码中经常出现的这么一个工具——ThreadLocal,它的核心作用是实现线程的隔离,本篇我们就来探究它的实现原理。
概念
在Java中,ThreadLocal 是一个用于提供线程内的局部变量。这种变量不同于它们的常规实例变量,因为每一个访问该变量的线程都有其自己的独立初始化的变量副本。ThreadLocal 实例通常用于保存线程上下文信息,这样信息就可以在线程内部方便地传递,而无需作为参数显式传递。
我们可以这么理解 ThreadLocal 的概念,它可以帮助我们为每个线程创建并存储变量副本,使得每个线程都能访问自己独立的变量,而不会与其他线程产生干扰。这就像是每个线程都有自己的“小仓库”,里面存放着它自己的东西,其他线程是看不到也拿不到的。
使用
概念总是相对抽象一些,理解起来比较片面。下面我们通过几个例子来诠释ThreadLocal的作用
-
在没有ThreadLocal的情况,会出现的问题
public class ThreadLocalExample { private static int num = 0; public static void main(String[] args) { Thread[] threads = new Thread[5]; for (int i=0;i<5;i++){ threads[i] = new Thread(()->{ num+=5; System.out.println(Thread.currentThread().getName()+" "+num); }); } for(int i=0;i<5;i++){ threads[i].start(); } } }
在上述示例中,我们定义了一个变量num,通过五个线程分别进行 num+5 的操作,理论上我们想要的效果是每个线程读取的num都是0,每个线程运算的结果也都5 ,然而实际运行后我们发现
问题出来了,每个线程读取到的num的值都不为0,能不能让每个线程读取到的值一样,我们用ThreadLocal来改良一下
-
用ThreadLocal来解决上面的问题
public class ThreadLocalExample { // 使用 ThreadLocal 存储每个线程的 num 副本 private static final ThreadLocal<Integer> threadLocalNum = ThreadLocal.withInitial(() -> 0); public static void main(String[] args) { Thread[] threads = new Thread[5]; for (int i = 0; i < 5; i++) { threads[i] = new Thread(() -> { // 获取当前线程的 num 副本,并增加 5 int num = threadLocalNum.get() + 5; // 设置当前线程的 num 副本为新的值 threadLocalNum.set(num); // 输出当前线程的名称和 num 值 System.out.println(Thread.currentThread().getName() + " " + num); }); } for (int i = 0; i < 5; i++) { threads[i].start(); } } }
在上面的代码中,我们创建了一个 ThreadLocal 实例 threadLocalNum,它会在每个线程第一次调用 get() 方法时初始化为 0。然后,在每个线程的执行体中,我们通过 threadLocalNum.get() 获取当前线程的 num 副本,将其增加 5,并通过 threadLocalNum.set(num) 将更新后的值设置回当前线程的 num 副本中。这样,每个线程都会操作自己的 num 副本,而不会互相干扰。
需求分析
在了解了ThreadLocal的基本使用之后,我们不妨来推导一下ThreadLocal它是如何实现隔离功能的,便于我们后续源码分析的时候能够加深理解
-
多个线程对一个共享变量进行set() 操作时,并没有做其他任何的处理,而是直接进行set。所以每一个线程应该有一个与ThreadLocal相关联的容器来存储共享变量的副本
-
调用ThreadLocal.get() 方法时,get()方法没有任何的参数,所以 这个容器中存储的数据可以通过ThreadLocal来区分。我们很容易就会联想到这个容器的存储结果应该<K,V> 结构也就是Map,ThreadLocal作为key
-
基于以上两点我们大致能猜到它的结构,如下图所示
源码分析
那ThreadLocal具体是怎么样进行存储,接下来我们将对源码进行分析,我们从set() 方法开始
set(T value)
public void set(T value) {
//获取当前线程
Thread t = Thread.currentThread();
/**
getMap(t) 表示通过当前线程获取到ThreadLocalMap
而ThreadLocalMap 是 线程t(Thread)的一个成员变量
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
ThreadLocal.ThreadLocalMap threadLocals = null;
*/
ThreadLocalMap map = getMap(t);
//如果map不为空则进行赋值
if (map != null)
//this代表当前的ThreadLocal也是说 在这个map中ThreadLocal 就是key
map.set(this, value);
else
//如果为空则创建一个ThreadLocalMap 并赋值
createMap(t, value);
}
createMap(t, value)
void createMap(Thread t, T firstValue) {
//核心就是如何初始化ThreadLocalMap
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
//我们来看ThreadLocalMap是如何初始化的
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
//创建了一个数组长度为16的 Entry 数组
table = new Entry[INITIAL_CAPACITY];
//计算数组下标的位置
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
//对table[i]进行赋值,放入Entry数组的指定位置
table[i] = new Entry(firstKey, firstValue);
//由于此时ThreadLocalMap中只有一个键值对,所以将size设置为1。
size = 1;
/**下一个扩容的阈值,细心的小伙伴会发现
* threshold = len * 2 / 3 这里的为什么这么计算是兼顾了空间与时间的最优解,
* 也就是说性能与存储容量的一个平衡
*/
setThreshold(INITIAL_CAPACITY);
}
我们回过头来看当 ThreadLocalMap不为空的时候是如何赋值的
map.set(this, value)
private void set(ThreadLocal<?> key, Object value) {
//获取 Entry 数组
Entry[] tab = table;
//获取数组长度
int len = tab.length;
//计算数组下标
int i = key.threadLocalHashCode & (len-1);
//从i开始一直遍历到数组最后一个Entry(这里用到了线程探索,稍后在详细解释,这里先关注主体逻辑)
for (Entry e = tab[i];
e != null;
//nextIndex(i, len) 计算的是i+1
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
//如果key相等,则覆盖value
if (k == key) {
e.value = value;
return;
}
/**如果k为null (这里涉及到了弱引用,这个位置的ThreadLocal可能已经被回收),
* 用新的key、value覆盖同时清理k==null的老数据
*/
if (k == null) {
//替换新值并清理无效的key
replaceStaleEntry(key, value, i);
return;
}
}
//不存在无效的key,也不存在数组下标相同并且key相同的,在该数组下标位置插入数据
tab[i] = new Entry(key, value);
//统计数组元素个数
int sz = ++size;
//判断是否超过阈值
if (!cleanSomeSlots(i, sz) && sz >= threshold)
//扩容方法
rehash();
}
到这里其实整个set()方法的核心逻辑我们就已经梳理完了,大逻辑层面比较简单就是寻址和赋值,在我们展开细节之前我们得补充两个概念——线性探索和弱引用
线性探索
这里用到的线性探索是为了解决hash冲突,它是一种开放寻址策略。问题出现是当一个key通过hash函数计算出一个下标位置,而这个位置已经存在别的key,为了解决这个问题我们通过线性探索一直往下找直到找到离存在key最近的空闲位置,并把新的key插入到这个位置。结合我们之前看到的源码
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
//代码省略。。。。。。。。
}
- 通过key来计算出数组下标i
- 如果tab[i] != null ,通过循环从i开始继续查找,如果 i 这个位置已经存在其他的value,并且该位置的key和当前的key不相等,则继续往下判断 i+1 的位置
- 如果key计算的 i 的位置上的 key == null,清理并赋值新的数据
弱引用
在Java中,引用关系决定了对象的生存周期。Java提供了四种类型的引用:强引用(Strong Reference)、软引用(Soft Reference)、弱引用(Weak Reference)和虚引用(Phantom Reference)这里我们先理解强引用和弱引用,感兴趣的小伙伴可自行查阅相关资料
-
强引用(Strong Reference)
强引用是最普遍的一种引用关系。如果一个对象具有强引用,那么垃圾收集器就永远不会回收它,即使系统内存空间不足导致抛出 OutOfMemoryError 错误,使程序异常终止,也不会回收这种引用所指向的对象。我们通过一段代码示例来演示下强引用
public class StrongReferenceDemo { private static Object obj = new Object(); public static void main(String[] args) { Object strongObj = obj; obj = null; System.gc(); System.out.println("GC回收之后"+strongObj); } }
在上面的代码中,obj 就是一个强引用,它指向一个新的Object 实例。只要 obj 存在并且没有被赋值为 null ,这个Object 实例就不会被垃圾收集器回收。所以我们打印出 strongObj 的引用,它仍然有效,并指向那个 Object 实例。运行结果如下图所示
-
弱引用(Weak Reference)
弱引用是用来描述一些非必需对象的引用,被弱引用关联的对象只能生存到下一次垃圾收集发生之前。当垃圾收集器工作时,无论当前内存是否足够,都会回收只被弱引用关联的对象。我们通过一段代码来演示弱引用
public class WeakReferenceDemo { private static Object obj = new Object(); public static void main(String[] args) { WeakReference<Object> weakReference = new WeakReference<>(obj); obj = null; System.gc(); System.out.println("GC回收之后"+weakReference.get()); } }
在上面的代码中,我们创建了一个 Object 实例,并通过WeakReference 创建了一个弱引用 weakRef。当我们清除强引用 obj 后,在下一次垃圾收集发生时,Object 实例就可能被回收,此时通过 weakRef.get() 会返回 null。运行结果如下图所示
通过以上概念的铺垫,我们知道了在set() 方法中 k == null 的判断表明 k 已经被 JVM 回收 也就是说 k 是弱引用。总的来说,当 JVM 进行垃圾回收时,无论内存是否充足,都会回收弱引用关联的对象。 k == null 中 k为什么为空 我们就已经知道了。我们还有个问题当 k == null 还要执行 replaceStaleEntry(key, value, i)方法即清理并覆盖,那这又如何做的呢,我们详细来看 replaceStaleEntry(key, value, i) 这个方法
replaceStaleEntry(key, value, i)
/**
* 初步看到这段代码可能看着比较长,阅读不方便,我们先梳理一下这个方法要干什么
* 1.当k == null 并且Entry不为空的情况下,说明 key 指向的对象已经被回收了,所以要清理这个 key
* 2.还要将新的 key 和 value 进行赋值
*/
private void replaceStaleEntry(ThreadLocal<?> key, Object value,int staleSlot) {
//获取 Entry数组
Entry[] tab = table;
//获取数组长度
int len = tab.length;
Entry e;
// staleSlot 就是当前k == null 的下标
int slotToExpunge = staleSlot;
/**
* prevIndex(staleSlot, len) 是当前位置向前查找
* 该方法的 核心是 i-1
*/
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
//通过循环遍历找到前面那个无效的位置
slotToExpunge = i;
//从当前k == null 的下标向后查找
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == key) {
//如果key 相等更新value值
e.value = value;
//然后进行换位
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
//判断前面那个无效的slotToExpunge 和当前的staleSlot是否相等,若相等赋值i并进行一次清理
if (slotToExpunge == staleSlot)
slotToExpunge = i;
//清理动作
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
//如果key对应的value在Entry中不存在,则直接放入一个新的Entry
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
//这里是统一做一次清理
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
这里有些小伙伴可能看不太明白,我们详细展开来说,我们需要明确两点
- replaceStaleEntry(key, value, i) 这个方法中有两个 for 循环,一个是向前查找,一个是向后查找。
- cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); 这个方法我们先不管细节是如何处理的,我们统一认为它是一个清理动作。
在明确了以上两点之后,我们开始分析查找过程中最关键的逻辑,在 key 查找过程中分四种情况
-
向前查找有 key == null 的 Entry, 向后查找有 k == key 的可覆盖 value 的 Entry
第一个 for 循环代表从当前位置向前查找 ,查找截止的条件是(e = tab[i]) != null 也就是 Entry 为 null ,如果 e.get() == null 也就是 key == null 则表示找到了 key 为 null 的 Entry 。并且记录当前的位置 slotToExpunge = i ,由于是循环操作,所以能找到最前面那个 key 为 null 的 Entry。
if (e.get() == null) slotToExpunge = i;
第二个 for 循环代表从当前位置向后查找,查找的截止条件是也是 Entry 为 null 当 k == key 时则说明向后查找找到了可以覆盖value 值的 Entry 并覆盖 value。
if (k == key) { e.value = value;
注意,这里接下来的操作需要进行换位,也就是把查找到的可覆盖的 Entry 与当前 Entry 的进行交换
tab[i] = tab[staleSlot]; tab[staleSlot] = e;
我们还是通过图形来理解下这个过程
此时 slotToExpunge != staleSlot 执行一次清理,从slotToExpunge位置开始
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
-
向前查找有 key == null 的 Entry, 向后查找没有可覆盖 value 的 Entry
在理解了第一种情况之后,这第二种情况就很好理解了,向前查找 key 为 null 的 Entry 。并且记录当前的位置 slotToExpunge = i 由于向后没有找到可以覆盖 value 的 Entry,则直接覆盖当前 staleSlot 位置的 Entry 也就是执行代码
tab[staleSlot].value = null; tab[staleSlot] = new Entry(key, value);
转化为图形
最后 判断 slotToExpunge 与 staleSlot 是否相等,不相等进行一次清理
if (slotToExpunge != staleSlot) cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
-
向前查找没有 key == null 的 Entry, 向后查找有 k == key 的可覆盖 value 的 Entry
在没有进行任何查找操作之前,当前位置和向前查找的位置是一样的
int slotToExpunge = staleSlot;
向前查找没有 key == null 的 Entry 证明第一个 for 循环中不会执行
if (e.get() == null) slotToExpunge = i;
而向后查找到有 k == key 的可覆盖 value 的 Entry 依旧进行覆盖和换位操作。 此时 slotToExpunge == staleSlot 并且查找到的可以替换的 Entry 已经换位完成,所以 key == null 将这个位置赋值给 slotToExpunge 。再从slotToExpunge 位置开始进行清理
if (slotToExpunge == staleSlot) slotToExpunge = i; cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); return;
转化为图形
-
向前查找没有 key == null 的 Entry, 向后查找没有 k == key 的可覆盖 value 的 Entry
如果前后查找都没有满足条件可覆盖的 Entry,则直接在当前位置 new Entry(key,value) 进行赋值。
tab[staleSlot].value = null; tab[staleSlot] = new Entry(key, value);
转化为图形
至此,整个查找的逻辑我们就分析完了,这个线性探索的设计还是比较有意思的,初次看时比较难以理解。多根据图形去思考场景就能捋顺查找流程了。趁热打铁,我们继续看具体是怎么进行清理的
expungeStaleEntry(int staleSlot)
//清理方法比较简单,就是从指定位置往下查找到 key == null
//并且将 key 和 value 都赋值为 null
private int expungeStaleEntry(int staleSlot) {
//获取到 Entry 数组
Entry[] tab = table;
//获取数组长度
int len = tab.length;
//把当前位置也就是开始查找的位置先进行清理
tab[staleSlot].value = null;
tab[staleSlot] = null;
//数组元素数量减 1
size--;
Entry e;
int i;
//从当前位置继续向后查找
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
//查找到为key 为null 的同样进行清理 赋值为null
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
//这里else 是判断那些特殊情况下引用还未完全清除的ThreadLocal
//重新计算下标并比较,不相等的则赋值为null
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
完整的set() 方法到这我们就分析完了,最后我们看一下get() 方法
get()
//get方法就比较简单了,就是根据当前的ThreadLocal去取值
public T get() {
//获取到当前线程
Thread t = Thread.currentThread();
//获取到当前线程的 ThreadLocalMap
ThreadLocalMap map = getMap(t);
if (map != null) {
//this 代表当前的 ThreadLocal
//通过 ThreadLocal 找到ThreadLocalMap 对应的Entry 对象
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
//取出 Entry 对象的 value 值并返回
T result = (T)e.value;
return result;
}
}
//如果 ThreadLocalMap 为空则去设置一个初始值
return setInitialValue();
}
setInitialValue()
private T setInitialValue() {
//获取初始化的值,这个我们在使用示例时定义的初始值
T value = initialValue();
//剩余代码我们就不做解释了,跟 set() 方法一样
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}
关于内存泄漏
在ThreadLocal使用不当的情况下会导致内存泄漏,ThreadLocalMap中的key采用弱引用,在ThreadLocal实例被回收之后,可以引用指向为null,虽然有replaceStaleEntry 中有线性探索及清理,但在极端情况下,会存在key == null 但是无法被探索到的可能性,从而导致内存泄漏。那要如何补救呢?总的来说,在使用 ThreadLocal 的地方,每个线程用完后,最终需要调用 remove() 方法防止出现内存泄漏
remove()
private void remove(ThreadLocal<?> key) {
//获取entry数组
Entry[] tab = table;
//获取数组长度
int len = tab.length;
//计算Entry 数组下标
int i = key.threadLocalHashCode & (len-1);
//从i的位置向后查找
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
//如果找到key相等的entry 则调用clear() 清除该entry持有的ThreadLocal引用和值
if (e.get() == key) {
e.clear();
//用于清理可能存在的无效的entry
expungeStaleEntry(i);
return;
}
}
}
总结
本篇我们对ThreadLocal进行了深入的分析,我们知道了它是利用了ThreadLocalMap实现的线程隔离。通过源码分析我们也了解它是如何通过线性探索去发现无效的且需要清理的对象。虽然平常使用ThreadLocal比较少,但是它的实现思想还是值得借鉴和学习的。