ThreadLocal
上节我们简单使用过ThreadLocal,本节我们讨论ThreadLocal实现的原理,如何做到每个线程有一份属于自己的数据备份
ThreadLocal类图
从类图我们可以很清楚的看到 ThreadLocal 的内部结构 (JDK8_201)
- 1个无参构造函数
- 4个public函数 其中withInitial()为类级别的静态方法
- 2个内部类 ThreadLocalMap 与 SuppliedThreadLocal
ThreadLocal
接下来我们重点分析 get() set() 与remove()方法
get()
public T get() {
Thread t = Thread.currentThread(); //获取当前线程
ThreadLocalMap map = getMap(t); // 取得当前线程的ThreadLocalMap实例,就是Thread类的属性 threadLocals
if (map != null) {//map为空调用初始化方法
// 这里可以看到最后操作的就是ThreadLocalMap 实例
ThreadLocalMap.Entry e = map.getEntry(this);//通过ThreadLocalMap获取存储的值,具体后续分析
if (e != null) {//环形数组中没有找到对应数据也会调用初始化方法
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
//如果从ThreadLocalMap未找到 获取初始值
return setInitialValue();
}
ThreadLocalMap getMap(Thread t) {
//获取Thread类的属性 threadLocals 类型为ThreadLocal.ThreadLocalMap
return t.threadLocals; //threadLocals中保存线程的所有的线程本地变量
}
private T setInitialValue() {
T value = initialValue();//调用initialValue方法获取,此方法默认返回null,可重写
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);//设置初始值 操作的是ThreadLocalMap
else
createMap(t, value);//初始化ThreadLocalMap
return value; //返回初始值
}
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);//调用ThreadLocalMap构造函数 参考下方构造函数
}
set()
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value); //操作的是ThreadLocalMap
else
createMap(t, value);
}
remove()
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);//操作的是ThreadLocalMap
}
withInitial()
从 get() 方法我们可以看到,当ThreadLocal实例第一次调用get()时,如果ThreadLocalMap没有初始化或者环形数组中没有找到数据时,最终会调用ThreadLocal的initialValue()方法
所以我们可以重写initialValue提供初始值
ThreadLocal<Integer> threadLocal = new ThreadLocal(){
@Override
protected Object initialValue() {
return 0;
}
};
withInitial()方法就是这么做的
public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new SuppliedThreadLocal<>(supplier);//返回一个SuppliedThreadLocal对象 SuppliedThreadLocal是一个内部类
}
//SuppliedThreadLocal对象内部继承ThreadLocal并重写了initialValue()方法
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {
private final Supplier<? extends T> supplier;
SuppliedThreadLocal(Supplier<? extends T> supplier) {
this.supplier = Objects.requireNonNull(supplier);//不能为Null
}
@Override
protected T initialValue() {
/**
* 调用supplier接口的get()方法
* supplier属性通过构造函数传入不能为空
* Supplier是一个FunctionalInterface 支持lambada表达式
*/
return supplier.get();
}
}
所以我们还可以这样提供初始值
ThreadLocal<Integer> threadLocal = ThreadLocal.withInitial(()-> 1);
ThreadLocalMap
从上面可以看到对ThreadLocal的操作都会作用到ThreadLocalMap上,我们具体分析ThreadLocalMap
ThreadLocalMap提供了一种为ThreadLocal定制的高效实现,并且自带一种基于弱引用的垃圾清理机制。
类图
从图中我们可以看到ThreadLocalMap中有一个Entry数组,数组里面存放的是Entry对象,
Entry对象可以放两个属性一个是ThreadLocal对象(弱引用),另一个就是要保存的线程本地变量
ThreadLocalMap可以简单地将它的key视作ThreadLocal,value为代码中放入的值(实际上key并不是ThreadLocal本身,而是它的一个弱引用)
ThreadLocalMap构造函数
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];//初始化一个长度为16的Entry数组
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);//定位元素的下标
table[i] = new Entry(firstKey, firstValue);//把值放到数组对应位置
size = 1;
setThreshold(INITIAL_CAPACITY);//设置阀值为数组长度2/3,超过这个值就要扩容数组 再Hash
}
可以看到在第一次get或者set()的时候会初始化ThreadLocalMap
另外我们要注意的是ThreadLocalMap内部的Entry数组是环形数组
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0); //环形数组的下一个索引
}
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1); //环形数组的上一个索引
}
set()
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);//定位元素下标
for (Entry e = tab[i];
e != null; //循环结束条件1 元素为空
e = tab[i = nextIndex(i, len)]) { //从位置i一直向后线性探测寻找,只要元素不为空
ThreadLocal<?> k = e.get();
if (k == key) { //元素都不为空,找到对应的key,更新value 循环结束条件2
e.value = value;
return;
}
if (k == null) { //元素都不为空, 弱引用失效,调用replaceStaleEntry替换旧值 循环结束3
replaceStaleEntry(key, value, i);
return;
}
}
tab[i] = new Entry(key, value);//循环条件1不满足,有空位值,把值放到第一个空位置上
int sz = ++size; //数量加1
if (!cleanSomeSlots(i, sz) && sz >= threshold) //未清理数据&数量大于阀值
rehash();
}
//替换失效位置
private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
int slotToExpunge = staleSlot; //staleSlot就是弱引用失效的位置索引
//从要删除的位置向前找,找到最前一个弱引用失效的位置,把并索引赋给slotToExpunge
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null; //此处为空会结束循环
i = prevIndex(i, len))
if (e.get() == null) //失效位置
slotToExpunge = i; //赋值,备注:如果都不为空并且没有失效的位置,i就是当前位置
for (int i = nextIndex(staleSlot, len); //向后线性探测
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == key) { //找到有效的key
e.value = value; //更新key对应的value
//下面两句交换失效位置与有效key的位置元素
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
/*
* 如果在整个扫描过程中(包括函数一开始的向前扫描与i之前的向后扫描)
* 找到了之前的失效位置则以那个位置作为清理的起点,
* 否则则以当前的i作为清理起点,此时的i为交换位置后的失效位置索引
*/
if (slotToExpunge == staleSlot)
slotToExpunge = i;
// 从slotToExpunge开始做一次删除失效元素,再做一次清理失效位置
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
// 如果当前的位置已经失效,并且向前扫描过程中没有失效位置,则更新slotToExpunge为当前位置索引
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
// 如果key在table中不存在,则在失效位置放一个替换掉失效元素
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
// 在探测过程中如果发现有任何失效位置,则做一次清理
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
//删除失效元素
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
tab[staleSlot].value = null;//失效位置value置为空
tab[staleSlot] = null; //失效位置为空
size--;
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) { //碰到失效位置
e.value = null;//失效位置value置为空
tab[i] = null;//失效位置为空
size--;
} else { //没有失效的位置
int h = k.threadLocalHashCode & (len - 1); //重新计算索引位置
if (h != i) { //如果与当前位置不相同
tab[i] = null;把当前位置置为空
while (tab[h] != null)//从h开始寻找下一个空位置
h = nextIndex(h, len);
tab[h] = e; 把当前元素放到空位置
}
}
}
//返回失效位置后第一个空位置索引
return i;
}
//清理失效位置 没有要清理的数据时返回false
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len); //从i开始向后连续探测
Entry e = tab[i];
if (e != null && e.get() == null) { //是否有失效的位置元素
n = len; //有失效元素重置n
removed = true;
i = expungeStaleEntry(i); //返回i后的第一个空索引位置
}
} while ( (n >>>= 1) != 0); //n 控制扫描次数 正常情况下如果log^n次扫描没有发现失效位置,函数就结束了
return removed;
}
//再Hash
private void rehash() {
expungeStaleEntries();//全面清理
// size会变小,因为清理会使size减少
if (size >= threshold - threshold / 4) //size只要大于3/4阀值,也就是len/2
resize();//扩容
}
//全面清理
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
//找到所有的失效位置,调用删除失效元素
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}
//扩容
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2; //扩容2倍
Entry[] newTab = new Entry[newLen];
int count = 0;
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null) //从h开始寻找下一个空位置
h = nextIndex(h, newLen);
newTab[h] = e;//把值放到空位置
count++;
}
}
}
setThreshold(newLen);
size = count;
table = newTab;
}
set()方法的过程
- 探测过程中元素都有效,并且顺利找到key所在的位置,直接替换即可
- 探测过程中发现有失效位置,调用replaceStaleEntry,效果是最终一定会把key和value放在这个位置,并且会尽可能清理无效元素
- 在replaceStaleEntry过程中,如果找到了key,value替换新值,把此元素与失效元素交换位置
- 在replaceStaleEntry过程中,没有找到key,直接在失效位置替换为新entry
- 探测没有发现key,有空位值,把值放到第一个空位置上,这也是线性探测法的一部分。做一次清理无效位置,如果没清理出去key,并且当前table大小已经超过阈值了,则做一次rehash,rehash函数会调用一次全量清理方法也即expungeStaleEntries,如果完了之后table大小超过了threshold - threshold / 4,则进行2倍扩容
get()
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);//定位数组标
Entry e = table[i];
if (e != null && e.get() == key) //元素未失效 key相同 返回数组当前位置元素
return e;
else //能进到这里说明此位置失效或者为空
return getEntryAfterMiss(key, i, e); //未找到继续线性探测
}
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
//只要不为空继续向后查找
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key) //找到相同key 返回
return e;
if (k == null) //位置失效
expungeStaleEntry(i); //删除失效位置 参考set()方法中的解释
else
i = nextIndex(i, len); //下一个索引
e = tab[i];
}
return null;
}
remove()
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear(); //断开引用
expungeStaleEntry(i);//删除失效位置
return;
}
}
}
//e.clear其实调用了父类 Reference 的方法
public void clear() {
//这个referent就是entry的key,也就是ThreadLocal实例
this.referent = null;
}