线程安全-ThreadLocal
前言
共享资源被多个线程同时访问可能出现不安全的事情。线程安全一直是很重要的事情,没处理好线程安全的问题可能导致错误甚至很难复现排查。
常见的解决办法有定义不可变(immutable)变量:例如Java中的String类型、Guava库里的ImmutableCollections。还有就是对访问共享资源的线程加上锁:例如JUC包下的ReentrantLock和JDK自带的synchronized关键字,或者借助Unsafe类来实现的自旋锁。还有一种是保证所有执行的任务都一个线程里:比如多进程单线程模型部署的verticle(基于EventLoop),ThreadLocal等都是这个原理。
ThreadLocal简介
将每一个线程中的变量存在ThreadLocal里,而这些变量只属于一个线程,因此每次访问变量都是单线程的,于是解决了线程安全的问题。
基本使用
ThreadLocal通常也就是调用他的get,set方法。
public class ThreadLocalTest2 {
private static final ThreadLocal<Integer> tl = new ThreadLocal<>();
public static void main(String[] args) throws InterruptedException {
// 主线程设置值
tl.set(1);
// 主线程获取值
System.out.println("主线程get:" + tl.get());
Runnable task = () -> {
// 子线程获取值
System.out.println("子线程get:" + tl.get());
};
// 开启子线程
new Thread(task).start();
// 等待父子线程都结束
Thread.sleep(10000);
}
}
主线程get:1
子线程get:null
在主线程中往ThreadLocal中set设置了一个值,主线程能得到,子线程却无法得到,说明了ThreadLocal里存放的值只和当前线程绑定。
基本原理
核心分析
看看为啥能这么神奇,往同一个ThreadLocal变量set。结果却是只有set的线程能获取到结果,而别的线程却得不到结果。
当一个线程调用ThreadLocal.set方法的时候
public void set(T value) {
// 获取当前调用set方法的线程t
Thread t = Thread.currentThread();
// 获取当前线程t的map
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
可以看到getMap方法会获取到当前线程t内部的threadLocals这个map变量了。
public class Thread implements Runnable {
ThreadLocal.ThreadLocalMap threadLocals = null;
}
接下来看下这个map是啥
public class ThreadLocal<T> {
static class ThreadLocalMap {
// 这个entry是个弱引用
static class Entry extends WeakReference<ThreadLocal<?>> {
Object value;
Entry(ThreadLocal<?> k, Object v) {
// map的key是一个ThreadLocal,value是具体存放在ThreadLocal内部set进去的值
super(k);
value = v;
}
}
// ThreadLocal内的map
private Entry[] table;
}
}
这个map其实就是一个Entry构成的数组,map的key是ThreadLocal,value是具体值的结构。
由此可以看出当一个线程往ThreadLocal里set值的时候,会先找到当前线程的map放入,
所以多个线程不管往ThreadLocal里set多少次,都会对应到所属线程中,保证了线程安全,
本质思是将每个线程会用到的资源都存到当前线程的上下文中去。
其他原理
set()
继续回到set方法,当获取到map不为空时调用map.set(this, value)
private void set(ThreadLocal<?> key, Object value) {
// 获取Entry的数组
Entry[] tab = table;
// 获取数组长度,注意,null也算入统计
int len = tab.length;
/* 计算当前key在map中的位置,threadLocalHashCode是个魔数,减少hash冲突的可能(不写数学原理了),
每次newThreadLocal时候会自增一倍, 不过一般是只会new一次,然后引用一个Map,new ThreadLocal<Map<Object, Object>>()这样。
*/
int i = key.threadLocalHashCode & (len-1);
// 这个循环主要是在数组中找到第一个为null的位置(线性探测法)
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
// nextIndex(i, len)这个方法是个循环遍历数组过程,可以看出Entry[]数组是个循环的(节约空间)
ThreadLocal<?> k = e.get();
// 将WeakReference中引用的ThreadLocal这个key和当前遍历到的比较,如果相同就直接覆盖值
if (k == key) {
e.value = value;
return;
}
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// 将探测到的坐标i,设置为最新的值
tab[i] = new Entry(key, value);
int sz = ++size;
/*
cleanSomeSlots清理过期位置的元素
这里的意思是如果没有清空元素并且当前size大于等于扩容阈值,就要rehash扩容
只要清空过或者小于阈值都不会rehash扩容
*/
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
这就是set方法的主要流程了:通过线性探测法在循环数组中找到合适的位置插入即可。
接下来看看rehash这个方法
private void rehash() {
// 方法一
expungeStaleEntries();
// Use lower threshold for doubling to avoid hysteresis(用低扩容阈值避免延迟)
// 方法二
if (size >= threshold - threshold / 4)
resize();
}
方法一
/**
* Expunge all stale entries in the table.
*/
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
// 对每个WeakReference引用为null的元素进行清理
expungeStaleEntry(j);
}
}
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
// expunge entry at staleSlot
// 1.设置Entry[]数组中当前位置的元素(entry)的value为null
tab[staleSlot].value = null;
// 2.设置Entry[]数组中当前位置为null
tab[staleSlot] = null;
/*
3.这里其实应该还要设置WeakReference引用的ThreadLocal为null(Entry.clear()方法),才能彻底断开引用,避免内存泄露, 但是现在调用expungeStaleEntry()是从e != null && e.get() == null这个判断条件进来的,所以
e.get() == null代表着WeakReference的引用为空。
*/
size--;
// Rehash until we encounter null
Entry e;
int i;
// 遍历连续不为null的元素
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
// 通过2次设置null清理垃圾
e.value = null;
tab[i] = null;
size--;
} else {
/*
因为采用线性探测法解决hash冲突,所以这里的位置i有可能不是通过idx=k.threadLocalHashCode & (len - 1)
直接到的,而是将idx加上一定的偏移量得到的。
所以这里要比较h != i成立说明这个元素之前是冲突的,现在将其放到尽可能不冲突的位置,下次调用ThreadLocal.get()
的时候可以减少线性搜索的时间复杂度。
*/
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;
// Unlike Knuth 6.4 Algorithm R, we must scan until
// null because multiple entries could have been stale.
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
方法二
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
// 2倍扩容
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;
// 将老的元素一个一个地拷贝到新数组
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
int h = k.threadLocalHashCode & (newLen - 1);
// 这里还是用线性探测法找到新数组中的位置
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}
/*
注意:这里设置新的扩容阈值threshold=len * 2 / 3,并且在调用该方法前会判断size >= threshold - threshold / 4
所以实际扩容的阈值是 newLen * (2/3) * (1 - 1/4) = newLen * 0.5
就是超过一半就扩容,印证了作者那句话"Use lower threshold for doubling to avoid hysteresis"
*/
setThreshold(newLen);
size = count;
// 更改table的引用为最新
table = newTab;
}
get()
public T get() {
Thread t = Thread.currentThread();
// 和set方法一样先获取当前线程t的map
ThreadLocalMap map = getMap(t);
if (map != null) {
// 根据key:threadlocal获取map里的值
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
// 如果当前线程没有map,比如被删除remove()了,重新初始化
return setInitialValue();
}
看一下getEntry方法
private Entry getEntry(ThreadLocal<?> key) {
// 获取Entry[]数组中的坐标i
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
// 如果该位置有元素,并且该元素是要寻找的key,返回找到的元素
if (e != null && e.get() == key)
return e;
else
// 否则走未命中逻辑
return getEntryAfterMiss(key, i, e);
}
看一下getEntryAfterMiss方法
// (要寻找的key, key计算出来的坐标i, i位置的元素e)
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
/*
从i位置开始从左向右循环找第一个不为null的元素(线性探测法原理,key元素没发生hash冲突前应该放在i位置,如果发生了冲突,
则会将e元素放到i位置右边开始的第一个null的位置上:set方法中的那个循环做的就是这件事),现在实际上是找到原本hash冲突后的偏移位置
*/
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
// 找到返回e
return e;
if (k == null)
// 引用为空,当作元素过期处理,这个方法上面分析过
expungeStaleEntry(i);
else
// 设置新的下标
i = nextIndex(i, len);
// 移动到下一个位置
e = tab[i];
}
// 如果找不到,说明该key从来没存放进来过,返回null
return null;
}
remove()
remove方法挺简单的,主要就是断开map中引用的过程。
public void remove() {
// 和之前一样获取当前线程的map
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
// 调用map的remove方法断开引用
m.remove(this);
}
map的remove方法
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
// 同样的找到key在Entry[]数组中的位置
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
/* 这里调用e.clear(),因为e是Entry的实例,Entry继承了WeakReference类,所以相当于调用WeakReference#clear
相当于断开
1.WeakReference对ThreadLocal的引用
*/
e.clear();
/*
这个方法里会依次断开:
2. tab[staleSlot].value = null,断开每个Entry中对我们存放的value的引用
3. tab[staleSlot] = null,断开Entry[]数组中对Entry(也就是WeakReference)的引用
*/
expungeStaleEntry(i);
// 断开引用完成,结束
return;
}
}
// 如果跳出循环说明key根本不存在, 直接返回就好了
}
remove方法实际上就是断开1,2,3个步骤中的引用避免内存泄漏。
简单应用
最常见的方式就是使用ThreadLocal存放一个map,这个map用来存放该线程的数据。
可以封装一个Context上下文。
public static class Context {
private static Map<String, Object> infoMap = new ConcurrentHashMap<>();
private static ThreadLocal<Map<String, Object>> threadLocal = new ThreadLocal<>();
static {
threadLocal.set(infoMap);
}
public static Object get(String key) {
return threadLocal.get().get(key);
}
public static void set(String key, Object value) {
threadLocal.get().put(key, value);
}
}
也可以使用ThreadLocal#initialValue方法初始化ThreadLocal
// protected方法可以被ThreadLocal子类重写
protected T initialValue() {
return null;
}
这样每次该线程的threalLocalMap被清空后下次get的时候会触发ThreadLocal#initialValue方法进行二次初始化(具体可以看看initialValue被调用的时机)
private static final ThreadLocal<Map<String, Object>> threadLocal = new ThreadLocal<Map<String, Object>>() {
// 自定义初始化方法
@Override
protected Map<String, Object> initialValue() {
return new ConcurrentHashMap<>();
}
};
使用看看
import java.util.ArrayList;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class ThreadLocalTest0 {
public static void main(String[] args) {
Context.set("log1", new ArrayList<>());
Context.set("log2", "a");
Context.set("log3", "b");
System.out.println(Context.get("log1"));
System.out.println(Context.get("log2"));
System.out.println(Context.get("log3"));
}
public static class Context {
private static Map<String, Object> infoMap = new ConcurrentHashMap<>();
private static ThreadLocal<Map<String, Object>> threadLocal = new ThreadLocal<>();
static {
threadLocal.set(infoMap);
}
public static Object get(String key) {
return threadLocal.get().get(key);
}
public static void set(String key, Object value) {
threadLocal.get().put(key, value);
}
}
}
输出:
[]
a
b
扩展应用
ThreadLocal只能被一个线程独享,但是应用往往不是单线程的,会涉及到线程上下文的传递,那一个线程在ThreadLocal的数据怎么传递给另一个线程呢?其实就是在子线程中将父线程的数据一个一个拷贝过去。
可能疑问子线程应该获取不到父线程ThreadLocal的数据了吧,怎么实现的呢?其实就是子线程中持有父线程的引用构建类似闭包(closure)的东西就好了
ITL (InheritableThreadLocal)
常见的ThreadLocal子类InheritableThreadLocal他能传递父子间的上下文。
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class InheritableThreadLocalTest implements Runnable {
public static void main(String[] args) throws InterruptedException {
Context.set("main", "主线程");
System.out.println("----主线程获取值:" + Context.get("main"));
// 在主线程中通过 new Thread(...)方法将父线程的上下文拷贝到子线程Thread类的内部
Thread tt = new Thread(new InheritableThreadLocalTest());
tt.start();
Thread.sleep(6000);
System.out.println("----主线程结束");
// 父线程获取子线程再次设置的值:sub=子线程
System.out.println("----主线程获子线程值:" + Context.get("sub"));
}
@Override
public void run() {
// 子线程中直接获取父线程的值main=主线程(ThreadLocal就不能实现)
System.out.println("----子线程获取值:" + Context.get("main"));
// 子线程设置值sub=子线程,给父线程再次获取
Context.set("sub", "子线程");
}
public static class Context {
private static Map<String, Object> infoMap = new ConcurrentHashMap<>();
private static InheritableThreadLocal<Map<String, Object>> threadLocal = new InheritableThreadLocal<>();
static {
threadLocal.set(infoMap);
}
public static Object get(String key) {
return threadLocal.get().get(key);
}
public static void set(String key, Object value) {
threadLocal.get().put(key, value);
}
}
}
输出”
----主线程获取值:主线程
----子线程获取值:主线程
----主线程结束
----主线程获子线程值:子线程
InheritableThreadLocal它将数据存放在每个线程(Thread)内部,当创建一个子线程的时候:new Thread()
会在构造方法中(当前在主线程)将主线程的上下文(ThreadLocal)复制到子线程Thread类中的InheritableThreadLocal去,当子线程获取的时候再从InheritableThreadLocal取出就好了,就实现了跨线程传递数据。
public class Thread implements Runnable {
public Thread(....) {
// 将父线程father的上下文inheritableThreadLocals复制到子线程Thread实例中去
this.inheritableThreadLocals = ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
}
}
值得注意的是复制操作只在new Thread(),创建线程的时候进行,一个应用往往使用池化技术减少性能开销,就意味着一个线程一旦被创建出来,就很可能会不断地被复用,一般来说一个线程执行完任务都要清空上下文。当子线程执行完一个任务清空上下文,该线程又会去执行另一个任务这时候由于复用便不会再被创建(new)一次,所以会导致第二次获取不到父线程的上下文,造成数据丢失。
TTL(阿里的TransmittableThreadLocal)
TransmittableThreadLocal(TTL)
是阿里开源的,用于解决异步执行时上下文传递的问题的组件,在InheritableThreadLocal
基础上,实现了线程复用场景下的线程变量传递功能。
就使用了类似构建闭包的原理对Runnable类进行了一个包装:
以下是伪代码
public class MyRunnable implements Runnable{
FatherContext fc;
Runnable runnable;
// 在父线程创建MyRunnable实例的时候设置父线程的fc,同时传入runnable进行包装
public Father(FatherContext fc, Runnable runnable) {
this.fc = fc;
this.runnable = runnable;
}
@Override
public void run() {
// 持有闭包外的fc引用,现在在子线程复制父线程的上下文
copyFatherContext(fc);
runnable.run();
// 记得清空上下文防止上下文污染
clear();
}
}
public class ThreadLocalTest {
Context context = 初始化省略。。。
public static void main(String[] args) throws InterruptedException {
// 父线程设置值father=father-thread
contect.set("father", "father-thread");
Runnable runnable = () -> {
// 子线程获取父线程的值father=father-thread
context.get("father");
// ...
};
new Thread(new MyRunnable(context, runnable)).start();
Thread.sleep(10000);
}
}
总结
ThreadLocal主要运用他的get (放入值),set (取出值)以及不用了使用remove清空避免内存泄漏。
代码中不管get还是set都频繁调用expungeStaleEntry方法进行垃圾回收,同时将map的每个元素都设置为弱引用(WeakRefence:一旦触发JVM垃圾回收,如果WeakRefence内部的对象没有被其他对象强引用,那么该对象会回收),以及采用循环数组来节省空间等操作都可以看出作者设计的时候是比较看重内存空间的消耗。对比其他常见map就不写了,比较懒。
常见的跨线程传递上下文的方法挺多都是直接复制父线程的上下文,简单粗暴,同时定制化需要拷贝的上下文数据。