1. 前言
我们知道,ThreadLocal是一个线程内部的数据存储类,通过它可以在指定的线程中存储数据,数据存储以后,只有在指定线程中可以获取到存储的数据,对于其它线程来说无法获取到数据。这么神奇的东西有没有想知道是怎么实现的呢?
2. 代码使用
我比较懒哈,直接把网上一个博客的代码例子拷贝过来演示一下:
mBooleanThreadLocal.set(true);
Log.d(TAG, "[Thread#main]mBooleanThreadLocal=" + mBooleanThreadLocal.get());
new Thread("Thread#1") {
@Override
public void run() {
mBooleanThreadLocal.set(false);
Log.d(TAG, "[Thread#1]mBooleanThreadLocal=" + mBooleanThreadLocal.get());
};
}.start();
new Thread("Thread#2") {
@Override
public void run() {
Log.d(TAG, "[Thread#2]mBooleanThreadLocal=" + mBooleanThreadLocal.get());
};
}.start();
我们可以看出,我们只有一个ThreadLocal变量mBooleanThreadLocal。在主线程中我们调用了set(true),在Thread1子线程中调用了set(false),在Thread2子线程中我们没有调用。大家能猜到结果是什么吗?
D/TestActivity(8676):[Thread#main]mBooleanThreadLocal=true
D/TestActivity(8676):[Thread#1]mBooleanThreadLocal=false
D/TestActivity(8676):[Thread#2]mBooleanThreadLocal=null
按照我们以前的理解,在主线程设置为true,的确应该返回true。在Thread1子线程设置为false,也的确应该返回false,毕竟此时将值改变了嘛。但是在Thread2子线程我们明明什么都没做,为什么返回的结果不应该是在Thread1子线程中被改变成的false,而是返回null呢?这不废话吗?如果返回了false,也就不会有这篇文章了。下面开始源码讲解。
3. ThreadLocal源码解析
我们先从set()方法入手:
public void set(T value) {
Thread currentThread = Thread.currentThread();
Values values = values(currentThread);
if (values == null) {
values = initializeValues(currentThread);
}
values.put(this, value);
}
我们可以大概看出逻辑:先得到当前所在的线程,然后利用当前线程取出来一个Values值,如果这个值为空,则用当前线程初始化,然后调用values的put()方法将当前ThreadLocal作为key,value作为value存储起来。重点在于values(currentThread)方法是怎么回事儿:
Values values(Thread current) {
return current.localValues;
}
非常简单明了,直接返回的当前线程的一个叫做localValues的成员变量。看到这里我们就明白了上面的疑问,因为每个线程都维护了一个成员变量作为数据结构来存储,当然不同线程会呈现不同的结果。然后上面有个初始化方法也非常简单,直接new一个对象避免空指针。
Values initializeValues(Thread current) {
return current.localValues = new Values();
}
接着看get()方法:
public T get() {
// Optimized for the fast path.
Thread currentThread = Thread.currentThread();
Values values = values(currentThread);
if (values != null) {
Object[] table = values.table;
int index = hash & values.mask;
if (this.reference == table[index]) {
return (T) table[index + 1];
}
} else {
values = initializeValues(currentThread);
}
return (T) values.getAfterMiss(this);
}
和set()方法一样,先找到当前Thread的Values变量,然后通过hash和values.mask计算出当前ThreadLocal在Values数据结构中的index,如果table[index]和this.reference相等的话,则表示匹配成功,直接返回table[index+1]数据源。这里有三个变量需要注意一下:hash、values.mask、reference分别是什么。因为values.mask是属于Values的变量,我们暂时只关注另外两个。
/** Weak reference to this thread local instance. */
private final Reference<ThreadLocal<T>> reference
= new WeakReference<ThreadLocal<T>>(this);
/** Hash counter. */
private static AtomicInteger hashCounter = new AtomicInteger(0);
/**
* Internal hash. We deliberately don't bother with #hashCode().
* Hashes must be even. This ensures that the result of
* (hash & (table.length - 1)) points to a key and not a value.
*
* We increment by Doug Lea's Magic Number(TM) (*2 since keys are in
* every other bucket) to help prevent clustering.
*/
private final int hash = hashCounter.getAndAdd(0x61c88647 * 2);
可以看出,reference就是当前ThreadLocal的一个弱引用而已,其作用就是指当前ThreadLocal。回想刚才的table[index]和this.reference判断匹配可以看出,Values就是拿当前的ThreadLocal作为key来存储value的。
然后注意hashCounter这个AtomicInteger,它是static的,也就意味着所有的ThreadLocal共用这一个计数器。hash变量是直接将当前计算器自动增长0x61c88647 * 2这个数,为什么是这个数好像是Doug Lea’s Magic Number(TM)这个东西,有兴趣的同学可以去查阅相关资料。反正结果就是利用当前的hash和values.mask能够计算出当前ThreadLocal在Values的索引。
ThreadLocal的部分只研究到这儿就可以了,我们发现大部分的核心操作还是在Values的table这个Object[]数组的操作上,并且我们能够知道是前一个存储ThreadLocal,后一个储存value的数据结构。接下来我们来看ThreadLocal的内部类Values的源码。
4. ThreadLocal.Values源码解析
因为对Values的读取在ThreadLocal是直接对Object[] table进行读取的,所以我们只需要看储存方法put():
void put(ThreadLocal<?> key, Object value) {
cleanUp();
// Keep track of first tombstone. That's where we want to go back
// and add an entry if necessary.
int firstTombstone = -1;
for (int index = key.hash & mask;; index = next(index)) {
Object k = table[index];
if (k == key.reference) {
// Replace existing entry.
table[index + 1] = value;
return;
}
if (k == null) {
if (firstTombstone == -1) {
// Fill in null slot.
table[index] = key.reference;
table[index + 1] = value;
size++;
return;
}
// Go back and replace first tombstone.
table[firstTombstone] = key.reference;
table[firstTombstone + 1] = value;
tombstones--;
size++;
return;
}
// Remember first tombstone.
if (firstTombstone == -1 && k == TOMBSTONE) {
firstTombstone = index;
}
}
}
可以看出,首先是遍历table数组,如果找到了key.reference与table[index]匹配的话,就返回对应的value。如果发现当前key为null的话,表示这是一个空白区域,此时需要判断firstTombstone是否等于-1,如果等于-1,则表示没有设置需要优先存储的空白区域,此时直接存储到这个区域;如果firstTombstone不等于-1,则需要存储到firstTombstone所对应的空白区域。然后往下走,走到这儿就意味着k不等于null,此时判断这个k是否等于TOMBSTONE,如果等于并且firstTombstone等于-1则设置firstTombstone为当前的索引index。
那么问题来了,这个TOMBSTONE是个什么鬼?翻译出来意思是墓碑……,还是得从其他代码来分析。
/**
* Placeholder for deleted entries.
*/
private static final Object TOMBSTONE = new Object();
直接就是一个最普通的Object而已,注释告诉我们是“已删除实体的占位符”,也就是说,当有ThreadLocal从这个Values中remove()掉后,会用TOMBSTONE作为key,null作为value来占位这块区域。因此在put()的时候,会优先将数据储存到已删除的被TOMBSTONE占位的区域,嗯,很好的思路。我们查看remove()方法来验证这个想法。
void remove(ThreadLocal<?> key) {
cleanUp();
for (int index = key.hash & mask;; index = next(index)) {
Object reference = table[index];
if (reference == key.reference) {
// Success!
table[index] = TOMBSTONE;
table[index + 1] = null;
tombstones++;
size--;
return;
}
if (reference == null) {
// No entry found.
return;
}
}
}
的确没错,所谓的remove操作,就是将对应的区域的key替换为TOMBSTONE,value替换为null。说明我们上面的猜想是完全正确的。最后还遗留了一个地方需要注意,就是在上面的put()方法中的遍历算法:
for (int index = key.hash & mask;; index = next(index)) {
Object k = table[index];
…………
}
index的初始值是key.hash & mask,这个与ThreadLocal的get()方法是一样的,即指明了当前ThreadLocal默认的存储位置,然后循环的条件判断是无限制的,即无限循环,递增条件是next(index)方法,我们来看一下这个方法:
/**
* Gets the next index. If we're at the end of the table, we wrap back
* around to 0.
*/
private int next(int index) {
return (index + 2) & mask;
}
从注释上讲是指得到下一个key的index,并且在遇到table的末尾时,能够自动滚到0再次循环。就这么一行代码就能实现这种功能,果然是牛逼。实现的关键肯定是mask这个变量,我们来看看这是怎么回事儿。搜索发现全局只有下面这一处是对mask的赋值。
private void initializeTable(int capacity) {
this.table = new Object[capacity * 2];
this.mask = table.length - 1;
this.clean = 0;
this.maximumLoad = capacity * 2 / 3; // 2/3
}
Values() {
initializeTable(INITIAL_SIZE);
this.size = 0;
this.tombstones = 0;
}
这个方法是在构造函数中调用的,而且我们可以知道mask的值为2 * INITIAL_SIZE - 1。看到这里依然还是不明不白的,下面是关键的地方
/**
* Size must always be a power of 2.
*/
private static final int INITIAL_SIZE = 16;
如果仅仅看代码而不看注释,估计看多久都不会明白。因此代码的注释是多么的重要的。这个INITIAL_SIZE只能是2的指数。那么也就是说,
INITIAL_SIZE的范围是
1,2,4,8,16……
mask的范围是
1,3,7,15,31
这个又能说明什么问题呢?聪明的小伙伴儿应该已经想到了。mask如果化为二进制的话:
1:1
3:11
7:111
15:1111
31:11111
全是1,然后我们再回过头来看next()方法,可以看出,对比mask小的数,就返回这个数,对比mask大的数,就截断高位数。我们举个例子,如果INITIAL_SIZE等于4,则mask等于7,最后一个key的index为6,此时调用next()方法,index加上2等于8,然后执行&操作:
1000 8
0111 &7
0000 =0
果然就直接滚到0了,然后再从头开始循环。这种操作非常炫酷有木有。
5. 总结
- Thread维护了一个成员变量Values,用于存储不同线程不同的数据结构
- Values内部其实是以对象数组来存储的数据
- Values将ThreadLocal作为key值来存储的,并且key值在前一个索引,value值在后一个索引
- Values的初始容量,必须是2的指数再乘以2,这一点是为了满足mask的值二进制为全1
- 真正数据存储是在ThreadLocal的set()方法执行的,在不同线程调用这个方法,会将不同的数据储存在不同的线程中。
- 调用ThreadLocal的remove()方法时,其实是将对应区域的key设置为默认值TOMBSTONE,value设置为null。TOMBSTONE只是一个普通的Object。