引言
ThreadLocal应该算是一个面试中问的比较多的类,这个类主要用途在于只使用一个对象的引用,便能够获取在多个线程中储存变量的副本,下面,便来详细剖析下这个类的实现原理,及其中需要注意到的一些地方。
对象的创建初始化
private final int threadLocalHashCode = nextHashCode();
private static AtomicInteger nextHashCode =
new AtomicInteger();
private static final int HASH_INCREMENT = 0x61c88647;
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
从上面可以看到,threadLocalHashCode保存了当前对象的hashcode,而hashcode是由nextHashCode()方法获取的,每次在原始值上增加了一个HASH_INCREMENT,之所以使用这个值是因为这样做能减少散列表的冲突。那么这个hashcode在哪里使用呢,请看接下来的代码。
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
源码解析
如图,在调用ThreadLocal里set()方法时候,如果当前Thread内部的threadLocals对象没有创建,则会new一个ThreadLocalMap对象,然后将当前ThreadLocal对象的内部的hashcode,对map的大小取余,然后放入map中。这样一看似乎很让人疑惑,既然一个线程中只有一个ThreadLocal的值,为什么非要用map进行存储呢,答案是,在一个线程中,可以new多个ThreadLocal对象,用它们保存不同的值,既然如此,为了快速找到某个ThreadLocal对象内存储的值,自然就使用散列表降低时间复杂度了。
接下来来看看内部类ThreadLocalMap的结构
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
private static final int INITIAL_CAPACITY = 16;
private Entry[] table;
private int size = 0;
private int threshold; // Default to 0
private void setThreshold(int len) {
threshold = len * 2 / 3;
}
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}
如图,此类继承了WeakReference类,也就是弱引用,它的特点是在JVM进行GC时,会专门有一条守护线程对这个类及其子类进行回收,对于弱引用的对象,每次GC当其他地方没有对它的引用时,无论内存空间足够充足,都会被进行回收。这里把ThreadLocal对象作为Entity对象的key,而set的值作为value,key弱引用对象,则进行垃圾回收之时,就会出现Entity对象存在,而内部key为null的结果,做出这种设计的主要原因还是为了帮助gc。
而从内部方法可以看出,保存value的是一个数组,且在数组内进行下标遍历之时采用的一个循环操作。其中threshold为下次扩容时的大小,即为数组size * factor,默认初始化为16,扩容因子为3分之2.
接下来看看private的set方法,看看具体如何实现的。
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
// 当前threadLocal在数组中应该处在的位置
int i = key.threadLocalHashCode & (len-1);
// 当前元素不为null时候进行遍历
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
// 找到了直接设置当前值,返回
if (k == key) {
e.value = value;
return;
}
// 若key为null,说明当前线程内有其他的ThreadValue对象的值被垃圾回收了,已无外部其他途径能访问到该value
// 对key及value进行替换
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// 这个时候i为数组中为null的位置,放入
tab[i] = new Entry(key, value);
int sz = ++size;
// 清理被垃圾回收了key的entity对象,若没有清理成功,且size超过了最小扩容大小
// 进行扩容
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
如图,本方法主要是以hash值进行散列表内的定位,若是发生了hash冲突,但冲突的ThreadLocal对象的key已经被垃圾回收时,对entity对象的key及value进行替换,这种情况举个例子,在一个方法里声明了一个局部变量ThreadLocal,进行了值的设定,当方法执行完成后,ThreadLocal变量已经无法进行访问,此时唯有ThreadLocalMap里存有对它的弱引用,但此时内部value已经无法通过外界进行访问了,则gc时会被回收,防止内存泄露。若是已经设置过值了,则直接重新设为最新值就行。
若是跳出循环,则在寻找过程中未有GC对象,且没有set过值,当前i值为hash位置后第一个空槽,在这里放入对象。
之后又一步检测,设计到了cleanSomeSlots方法,这个方法主要是清除被回收了key的entity对象,返回值为是否被回收,若是没有回收,且当前大小达到了扩容容量,则进行重新hash扩容。下面来看看这两个方法的具体实现。
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
// 进行lgn次的扫描,清理一些零散区段的key被gc回收的entity
// 若扫描到了,则n又会被重置为数组长度,继续扫描下去,直到回收完成
do {
i = nextIndex(i, len);
Entry e = tab[i];
if (e != null && e.get() == null) {
n = len;
removed = true;
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
// 返回是否进行了entity的移除
return removed;
}
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
// 把当前entity位置设为null帮助gc
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;
Entry e;
int i;
// 一直遍历到一个为null的槽
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
// 遍历过程中,若遇见了其他key被回收的entity,设为null,腾出空位并帮助垃圾回收
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
// 若为存在key的实体,则计算当前hash值
int h = k.threadLocalHashCode & (len - 1);
// 如果该实体不是保存在其hash值应放的位置,则从应放的位置开始,寻找到下一个空位,避免get或set改值时,前面有key为null的entity被回收了,导致get获取到null值或者set对值进行重复设置
// 并将该实体挪到其中
if (h != i) {
tab[i] = null;
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
// 返回下标为i的null槽
return i;
}
cleanSomeSlots()方法对底层数组进行了至少lgn次的扫描来进行key为null的entity的回收,key为null情况之前也说过,是外部已经没有那个key,也就是当前线程内ThreadLocal对象的引用的时候,会在gc时进行回收,n为数组长度,之所以说是至少,是因为若是回收成功了,会重置循环次数。这里涉及到了一个expungeStaleEntry()方法,这个方法参数为当前key为null的entity下标,方法内部对于被回收的entity对象的槽进行置null处理,并对中途碰到的没有回收key的非空槽进行了重新hash处理,因为不重新hash,很可能会当有key为null的entity被回收了,然后槽被置为null,这个时候get,set方法会出现问题。打个比方,若当前ThreadLocal对象的hashcode定位在数组中下标为1,但是下标为1,2的位置都被占用了,则只能放到3中,后来下标1,2的槽都被进行了回收,置为null,这个时候调用get(),就会出现1为null的情况,便会返回null值,调用set(),则会出现槽1,3中各有一个值的情况,很明显是错的。
最后方法返回了一个为null的槽,然后获取其下一个位置再次循环进行清除。
接下来让我们看看rehash()方法。源码如下
private void rehash() {
expungeStaleEntries();
// Use lower threshold for doubling to avoid hysteresis
if (size >= threshold - threshold / 4)
resize();
}
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
// 双倍扩容
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
// 保险起见,再次判定,进行key被GC的entity的清除
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
// 拉链法放置
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}
setThreshold(newLen);
size = count;
table = newTab;
}
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
// 循环整个数组,清除所有被GC的entity
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}
首先在rehash()方法中调用了expungeStaleEntries()方法,点进去可以看到,这个方法作用很明显,循环这个数组,把所有的被回收了key的entity清除。然后调用了resize() 方法,这个方法对当前数组进行了双倍扩容,然后使用了setThreshold()方法,threshold值设置为下次扩容时的size大小,也就是数组大小乘上扩容因子。
在扩容中又保险起见进行了对被回收的entity置空的操作,并且使用拉链法,也就是冲突了则放入下一个null槽的方法,对原先整个数组进行了重新hash。
实际上因为map本身是附属于一个thread,对map的所有操作均不会存在线程安全问题,唯一可能有问题的ThreadLocal的hashcode也使用了CAS进行运算,因此我们完全可以安心使用它,无需担心线程安全问题。
set()方法的整个流程我们已经看完了,接下来看看get方法是如何实现了,代码如下。
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}
在get()方法中,若是还没有设置值,则会调用setInitialValue()方法,若没初始化map,则会初始化ThreadLocalMap,这里面调用了initialValue()方法,这个方法本身是一个空方法,作用是我们可以重写这个方法,然后若不调用set方法,直接使用get就能够从threadLocal中拿出我们的初始值,当然以后调用了set方法也会将该值覆盖。这里面有一个getEntity()方法,让我们看下它是如何获取到当前线程的ThreadLocal对象保存的值。
代码如下
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
if (k == null)
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
}
其中从传入的当前线程中取出了ThreadLocalMap,然后计算当前ThreadLocal对象的hashCode,如果没有发生冲突,则直接返回,否则调用getEntryAfterMiss()方法进行遍历。
getEntryAfterMiss()中一直遍历数组,直到槽为null或找到停止,其中又调用expungeStaleEntry方法对遍历中的遇到的被回收了的Entity进行清除,已经冲突hashcode节点进行重新hash处理。这里大家会有一个问题,如果get的节点正好冲突了,然后被重新hash处理了,那么是不是找不到了。
实际上,我们当前找的节点本来就是被冲突的节点,并且扫过的节点都是实际存在,未被gc的节点,因此当发生gc后,当前循环所走的后面的节点必然有我们要找的节点。
最后看下remove()方法。
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}
方法很简单,也就是一路找到需要remove的节点,遍历路上顺带进行被gc节点的清除而已。
总结
- 对于ThreadLocal来说,它的本质是依赖于当前线程内的一个内部对象ThreadLocalMap来进行值的保存,因此不同线程内有不同的map,并且没有任何交集,不会有线程安全问题。
- 在set,get,remove方法中都有涉及到被GC节点的清除工作,所以基本上不可能会内存泄露严重。
- 之所以使用map进行值的保存,是因为考虑到一个线程中可能会使用多可ThreadLocal实例进行值的保存。
- map使用的散列方法是在冲突后,寻找后面第一个为null的槽,进行放入。
- map的生命周期与thread相绑定,因此对于内部保存的所有节点的value及key,都具有引用,但是key是弱引用对象,若在外界已不再持有key,即ThreadLocal对象,则会在下次gc时被垃圾回收掉,但是由于entity及内部value还存在,所以需要在代码中做额外的处理。
- 除非是在方法的栈中创造大量的ThreadLocal局部变量进行值的设置,否则这些回收相关代码几乎毫无意义。(会有这么傻的人吗)
- ThreadLocal基本是一个可以安心使用的方法。(除非用的人脑子太清奇,本来即使在局部方法内创建ThreadLocal,调用了的get,set方法也会进行部分垃圾回收,所以就是有这么做的人,也顶天就泄露一点点内存,唯一问题在于快进行垃圾回收之前大量调用此局部方法,之后再也不使用ThreadLocal,则会造成比较严重的内存泄露)。