前言
之前看ThreadLocal原理基本是博客,但对这个还是一知半解,趁着这几天有空看了一遍,印象深刻了很多。同时发现新大陆,原来ThreadLocal在进行set、get等操作,都会有槽位清理的逻辑,来防止内存泄漏,这也是之前一直没有关注的地方。
在看之前,希望大家先花亿分钟打开ThreadLocal的源码,跟着来一步一步的分析。
一、ThreadLocal介绍
1、ThreadLocal基本方法
ThreadLocal主要包含一个静态方法withInitial()
,和三个基本实例方法set()
、get()
、remove()
;
大家可能对withInitial()
方法会比较陌生,下面是这个方法的代码:
public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new SuppliedThreadLocal<>(supplier);
}
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {
private final Supplier<? extends T> supplier;
SuppliedThreadLocal(Supplier<? extends T> supplier) {
this.supplier = Objects.requireNonNull(supplier);
}
@Override
protected T initialValue() {
return supplier.get();
}
}
withInitial()
使用了java 1.8增加的函数式接口Supplier
,Supplier
接口通常用于延迟计算,即在需要值的时候才进行计算,它提供了一种延迟执行的方式。
withInitial()
其实就相当于new Thread()
,并且给set值,使用起来就类似于这样:
ThreadLocal<Integer> threadLocal = ThreadLocal.withInitial(()->100);
//相当于-------------------------------------------------
ThreadLocal<Integer> threadLocal = new ThreadLocal<>();
threadLocal.set(100);
2、ThreadLocal的哈希值
每个Threadlocal哈希值是通过调用nextHashCode()
方法生成的,最终调用的是AtomicInteger中的getAndAdd
方法,保证自增的原子性。
每当创建ThreadLocal实例时这个值都会累加 0x61c88647
,0x61c88647
是散列算法中常用的一个魔法值,用于将哈希码能均匀分布在2的N次方的数组里,降低冲突几率。
虽然0x61c88647
是一个比较大的值,但是即使AtomicInteger超出范围变为负数,也不影响计算索引位置,因为只用到了与运算。
最终得出的threadLocalHashCode
作为创建ThreadLocalMap实例化对象和计算ThreadLocalMap索引位置使用。
private final int threadLocalHashCode = nextHashCode();
private static AtomicInteger nextHashCode = new AtomicInteger();
/**
* 哈希码累加值
*/
private static final int HASH_INCREMENT = 0x61c88647;
/**
* 返回HashCode
*/
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
二、ThreadLocal内部结构
1、ThreadLocalMap
ThreadLocal主要结构是ThreadLocalMap,是ThreadLocal的静态内部类。但它不是基于Hashmap实现的,而是一个Entry数组,每一个Entry则是一个key-value元素。
Entry的key为ThreadLocal,并且为弱引用,value则为Object。
static class Entry extends WeakReference<ThreadLocal<?>> {
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
为什么key要设置为WeakReference弱引用呢?
这个在面试题上已经烂大街了。。。直接上传送门:https://juejin.cn/post/7126708538440679460。
2、初始化
初始化步骤也不难:先通过threadLocalHashCode
和INITIAL_CAPACITY
(初始容量为16),算出在table(Entry数组)的位置,再new一个Entry赋值。
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1; //当前的实际Entry数量
setThreshold(INITIAL_CAPACITY);
}
最后一行setThreshold方法,代码就简单一句:threshold = len * 2 / 3
,其实是计算当前需要扩容的阈值,这里表示的是达到容量的三分之二就要扩容了。
3、set方法
这个set方法跟初始化差不多,也是先算出索引位置,再向tab[i]
赋值。
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table; // 获取ThreadLocalMap的Entry数组
int len = tab.length; // 获取数组的长度
int i = key.threadLocalHashCode & (len - 1); // 计算初始索引
// 遍历数组,查找匹配的ThreadLocal
for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
if (e.refersTo(key)) {
// 如果找到匹配的ThreadLocal,更新其值并返回
e.value = value;
return;
}
if (e.refersTo(null)) {
// 如果遇到过期的ThreadLocal,替换它并返回
replaceStaleEntry(key, value, i);
return;
}
}
// 如果没有匹配的ThreadLocal,创建新的Entry并插入
tab[i] = new Entry(key, value);
int sz = ++size; // 当前的实际Entry数量增加1
// 进行重新散列
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
其中有一段for循环,通过循环遍历表中的槽位,查找是否存在相同键的Entry:
(1)如果找到相同键的Entry,更新其值为新值。
(2)如果找到一个空槽位(键为null),则替换为空槽位。
(3)如果循环结束仍未找到相同键的Entry,则在表中的当前位置创建一个新的Entry对象。
末尾的if (!cleanSomeSlots(i, sz) && sz >= threshold)
检查了是否有清理任何槽位,并且映射的大小大于或等于阈值threshold。如果这两个条件都为真,说明映射需要定期调整大小和重新散列,以保持其性能。
4、replaceStaleEntry替换槽位并清理
感觉看下来,最复杂就是replaceStaleEntry
方法,刚开始看一脸懵。 replaceStaleEntry
出现在set方法中,用来替换键为null的Entry,当e.refersTo(null)
时,会进入replaceStaleEntry
方法。
大致过程可以分为两步:
1、首先开始向前检索key为null的Entry,直到tab[i]为null停止,记录当前索引位置并赋值给slotToExpunge
2、开始向后检索元素,分为两种情况:
(1)找到匹配的key的Entry,更新值后,与tab[staleSlot]
进行交换,并清理槽位后返回。注意的是tab[staleSlot]
是一个key为null的Entry。
(2)检索后没有找到匹配,这时候就要在tab[staleSlot]
新增一个Entry了,并清理slotToExpunge
到len
范围内的槽位。
private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
Entry[] tab = table; // 获取ThreadLocalMap的Entry数组
int len = tab.length; // 获取数组的长度
Entry e;
int slotToExpunge = staleSlot; // 初始化要清理的槽位为staleSlot
// 从staleSlot往前查找过期Entry
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.refersTo(null))
slotToExpunge = i; // 更新要清理的槽位为找到的过期Entry的位置
// 从staleSlot往后查找Entry
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
// 如果找到匹配的ThreadLocal键的Entry
if (e.refersTo(key)) {
e.value = value; // 更新值
//将找到的Entry与staleSlot位置的Entry交换
//相当于匹配的元素往前移,将key为null元素往后移
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
// 如果前一个过期槽位存在,从前一个过期槽位开始清理
if (slotToExpunge == staleSlot)
slotToExpunge = i;
// 清理一些槽位,然后返回
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
// 如果在向后扫描中未找到过期Entry,
// slotToExpunge则取扫描键时看到的第一个key为null的Entry索引,为了后面的cleanSomeSlots清理
if (e.refersTo(null) && slotToExpunge == staleSlot)
slotToExpunge = i;
}
// 如果未找到匹配的ThreadLocal键的Entry,将新Entry放入staleSlot
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
// 如果运行中还有其他过期Entry,会清理它们
if (slotToExpunge != staleSlot)
//从slotToExpunge下标索引开始清理槽位
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
5、get方法
get方法很简单:
1、直接查找元素,找到直接返回
2、第一步没有找到,再利用线性探索进行查找Entry,直到匹配Entry。期间遇到key为null的Entry,顺便清理。
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1); // 计算初始索引
Entry e = table[i]; // 获取索引处的Entry
if (e != null && e.refersTo(key))
return e; // 如果找到匹配的Entry,直接返回
else
return getEntryAfterMiss(key, i, e); // 否则调用getEntryAfterMiss方法进行进一步处理
}
// 未找到匹配的ThreadLocal键时的处理
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table; // 获取ThreadLocalMap的Entry数组
int len = tab.length; // 获取数组的长度,即映射的容量
while (e != null) {
if (e.refersTo(key))
return e; // 如果找到匹配的Entry,直接返回
if (e.refersTo(null))
expungeStaleEntry(i); // 如果找到引用为null的过期Entry,则清理
else
i = nextIndex(i, len); // 否则计算下一个索引
e = tab[i]; // 获取新索引处的Entry
}
return null; // 如果循环结束仍未找到匹配的Entry,返回null表示未找到
}
6、remove方法
remove步骤如下:
1、获取entry索引位置i,如果tab[i]
与key不相等,则继续进行线性探测,直到找到与key相等的元素Entry。
2、找到元素后,调用clear方法进行清理,并且调用expungeStaleEntry
,从数组中删除任何过期的Entry和进行哈希调整。
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.refersTo(key)) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}
7、cleanSomeSlots和expungeStaleEntry清理槽位
cleanSomeSlots
会检查一部分数据进行清理,内部实际是调用expungeStaleEntry
。
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false; // 用于标记是否有Entry被移除
Entry[] tab = table; // 获取ThreadLocalMap的Entry数组
int len = tab.length; // 获取数组的长度
do {
i = nextIndex(i, len); // 计算下一个索引
Entry e = tab[i]; // 获取当前索引位置的Entry
if (e != null && e.refersTo(null)) {
n = len; // 如果找到符合条件的Entry,则重新设置n为数组长度
removed = true; // 设置标记为true,表示有Entry被移除
i = expungeStaleEntry(i); // 清理过期的Entry,并返回下一个有效索引
}
} while ((n >>>= 1) != 0); // 继续循环,直到n变为0,比如n为16,则进行4次循环
return removed; // 返回标记是否有Entry被移除
}
expungeStaleEntry
方法有三个作用:
1、及时清理key为null的Entry
2、重新计算Entry的索引位置,调整后如果遇到哈希冲突,则调用nextIndex
进行线性探测,直到获取空槽位索引
3、返回最后一个有效的索引(空槽位)
可以看到这个清理的过程只是覆盖了一段范围,并不是全部区间。
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table; // 获取ThreadLocalMap的Entry数组
int len = tab.length; // 获取数组的长度
// 清理过期Entry
tab[staleSlot].value = null; // 将过期槽位的值设为null
tab[staleSlot] = null; // 将过期槽位设为null
size--; // 减小映射的大小
// 重新散列直到遇到null
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
// 如果ThreadLocal为null,清理该槽位的Entry
e.value = null;
tab[i] = null;
size--; // 减小映射的大小
} else {
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
// 如果计算的哈希码与当前索引不同,说明需要重新散列
tab[i] = null;
// 重新哈希后,可能会遇到哈希冲突,使用线性探索法获取空槽位
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e; // 在新的位置存储Entry
}
}
}
return i; // 返回最后一个有效的索引(空槽位)
}
8、rehash再哈希
1、先扫描全表,清理key为null的Entry
2、遍历旧数组的每个Entry,计算新的哈希码,并将新位置Entry储存到新数组。
3、再哈希完成后,最后将新Entry[]替换旧Entry[]数组。
private void rehash() {
expungeStaleEntries(); // 清理过期的Entry
// 为了避免滞后,使用较小的阈值进行加倍
if (size >= threshold - threshold / 4)
resize(); // 调整映射的大小
}
private void resize() {
Entry[] oldTab = table; // 获取旧的Entry数组
int oldLen = oldTab.length; // 获取旧数组的长度
int newLen = oldLen * 2; // 计算新数组的长度为旧数组的两倍
Entry[] newTab = new Entry[newLen]; // 创建新的Entry数组
int count = 0; // 计数非空的Entry
// 遍历旧数组的每个Entry
for (Entry e : oldTab) {
if (e != null) {
ThreadLocal<?> k = e.get(); // 获取Entry中的ThreadLocal键
if (k == null) {
e.value = null; // 如果键为null,帮助垃圾回收
} else {
int h = k.threadLocalHashCode & (newLen - 1); // 计算新的哈希码
while (newTab[h] != null)
h = nextIndex(h, newLen); // 处理哈希冲突,找到新的位置
newTab[h] = e; // 在新位置存储Entry
count++; // 非空Entry计数加一
}
}
}
setThreshold(newLen); // 设置新的阈值
size = count; // 更新映射的大小
table = newTab; // 将映射的Entry数组指向新的数组
}
虚拟线程
因为用的是jdk21的最新版本,所以ThreadLocal出现虚拟线程的影子。
虚拟线程的目标是提供一种轻量级的线程模型,以便更好地支持大规模并发。与传统的本地线程(Native Thread)相比,虚拟线程更轻量,创建和销毁的成本更低,并且更容易扩展。其他的就不展开讲了,大家可以自行尝尝鲜。
下图代码主要判断当前是否虚拟线程:
结束
点个赞点个关注再走啦~