ThreadLocal使用背景:
ThreadLocal的作用是提供线程内的局部变量,这种变量在线程的生命周期内起作用,减少同一个线程内多个函数或者组件之间一些公共变量的传递的复杂度。
在多个线程对一个对象进行频繁操作,而且此时并不希望看到其他线程的修改痕迹,而是希望每个线程都能有自己的局部变量时,ThreadLocal就充当了对这个变量的包装,使得针对每个线程都是局部变量。
最常见的还是一些工具类对象,比如数据库连接,我们希望每个线程在其工作期间,都使用自己的数据库连接,互相之间不能干扰。如果使用spring默认单例配置,数据库连接除非作为方法的局部变量获取,否则只要是成员变量或者静态变量,那么肯定会出现线程安全问题。
当然可以通过对使用连接的部分代码进行同步操作,但是这显然效率太低,是不可用的。
首先看ThreadLocal用法
public class TestThreadLocal {
// 官方建议static final,这样也可以保证在线程期间,各个方法比较方便的获取这个变量
private static final ThreadLocal<Integer> value = new ThreadLocal<Integer>() {
// 设置初始值,在第一次get时会调用
@Override
protected Integer initialValue() {
return 0;
}
};
public static void main(String[] args) {
for (int i = 0; i < 5; i++) {
new Thread(new MyThread(i)).start();
}
}
static class MyThread implements Runnable {
private int index;
public MyThread(int index) {
this.index = index;
}
public void run() {
System.out.println("线程" + index + "的初始value:" + value.get());
// 5个线程对静态变量value进行等差加法运算,如果没有threadlocal,显然会出现线程安全问题
for (int i = 0; i < 10; i++) {
value.set(value.get() + i);
}
System.out.println("线程" + index + "的累加value:" + value.get());
}
}
}
运行结果都是45,可以看出每个线程像是在操作自己的局部变量value一样,互不干扰
下面通过ThreadLocal的部分源码进行原理分析
// 获取保存在ThreadLocal中的变量
public T get() {
Thread t = Thread.currentThread();
// 通过当前线程对象引用获取ThreadLocalMap
ThreadLocalMap map = getMap(t);
// 如果map不为空,则根据当前ThreadLocal对象的引用获取entry,也即是对应的节点,因为一个线程可能有多个ThreadLocal变量
// 返回entry.value,也就是目标对象
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}
// 根据当前运行此代码的线程,获取该线程对象本身持有的ThreadLocalMap对象
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
// ThreadLocalMap是ThreadLocal的静态内部类
static class ThreadLocalMap{
// Entry继承了弱引用
static class Entry extends WeakReference<ThreadLocal<?>> {
Object value;
Entry(ThreadLocal<?> k, Object v) {
// 这里值得注意,ThreadLocal对象引用是被 WeakReference<ThreadLocal>(k)
// 弱引用构造方法包装了,也就是说这个ThreadLocalMap的key其实都是ThreadLocal对象
// 的弱引用
super(k);
value = v;
}
}
}
ThreadLocalMap的实例其实保存在线程中,他是线程的成员变量,也就是说,每个线程维护一个属于线程本身的ThreadLocalMap,key是ThreadLocal对象的弱引用,而值就是需要“局部化”的变量。
注明:弱引用的用法
// productA是最常用的,它显然是一个强引用,强引用指向的对象只有在引用被置为空,或者作为局部变量过了其作用域时
// 其指向的对象在堆区才有可能被回收,具体回收得看gc,人为不能控制
Product productA = new Product(...);
// weakProductA是一个指向new Product(...)对象的弱引用,若只有它指向该对象,那么gc也会在某个时刻回收该对象
// 这也是为了避免强引用消失后,hashmap作为缓存一直运行在jvm中,key还指向这个对象,如果key是强引用的话,
// 那么这个对象将始终得不到回收,久而久之,内存就溢出了
WeakReference<Product> weakProductA = new WeakReference<>(productA);
Product product = weakProductA.get();
分析到目前为止,这个弱引用key只是保证了,ThreadLocal对象作为key时,没有强引用指向他时,ThreadLocalMap的key引用不会阻止gc对它的回收,如果发生了gc,就会留下一个key为空的entry,此时显然无法访问value了。
// 当ThreadLocalMap在线程中还没有创建或者value为空时,get会执行设置初始值的方法
private T setInitialValue() {
// initialValue方法通常需要被重写,给定线程持有的初始化变量
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
// 如果map没有创建,则创建map,如果已经创建,则设置初始值
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}
ThreadLocalMap作为一种map,实现方式也就是数据结构跟hashmap是不一样的,hashmap采用数组加链表的方式,用“分离链表法”解决key的hash冲突问题。
而此处的ThreadLocalMap采用“开放定址法”,简单来说就是如果key的hash发生碰撞,那么就尝试获取下一个索引位置,如果不是空元素,就继续,一直到最后都没有空的就返回头部查找,这种查找数组空元素的方式称为“线性探测法”。
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
// 这里就使用线性探测法来寻找对应的元素,如果hash位置的entry不为空,首先判断key是否相等,
// 如果相等,说明就是同一个元素,把key(弱引用)置为空,
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
// clear方法是reference中的,在这里就是把弱引用置为空,expungeStaleEntry(i)
// 清除key为空的陈旧entry
// 如果key不相等,那么就nextIndex,继续判断下一个,如果已经是最后一个,则从0开始
e.clear();
expungeStaleEntry(i);
return;
}
}
}
private void set(ThreadLocal<?> key, Object value) {
// We don't use a fast path as with get() because it is at
// least as common to use set() to create new entries as
// it is to replace existing ones, in which case, a fast
// path would fail more often than not.
// 首先根据ThreadLocal对象引用计算索引位置
// len因为始终是2的n次方,那么-1的结果就会把低位都变成0,实际上就是等于最大索引内全是1,
// 如果跟一个二进制做&运算,可想而知,留下的是这个二进制低位小于等于len-1的部分
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
// 放入元素显然也要线性探测
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
// 如果key相同,则直接替换value
if (k == key) {
e.value = value;
return;
}
// 如果key为空,则用新的key替代陈旧元素
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// 代码运行到此处说明找到了数组中的空位置,实际元素数量size+1,清除陈旧元素(如果失败),如果size仍然
// 大于阈值,进行rehash操作,rehash会清除陈旧元素,如果size>3/4*threshold,那么就扩容
// 为2倍
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}