目录
一、简介
1.1 ThreadLocal是什么
参考JDK的原生注释:
- ThreadLocal用于提供线程粒度的局部变量,线程在访问ThreadLocal实例的时候(通过其get或set方法)有自己的、独立初始化的变量副本。
- ThreadLocal实例通常是类中的私有静态字段,使用它的目的是希望将一些状态信息(例如用户ID或事务ID)与线程关联起来。
总结来说,ThreadLocal的作用就是让每一个线程绑定自己的值,线程间不互相影响,从而避免了线程安全问题。这与我们熟知的通过锁的方式(例如synchronized、Lock等)来保证线程安全是两种不同的思路:
- 锁的方式是多个线程抢占一个共享资源,同步机制是以时间换空间,执行时间不同,类似于排队
- ThreadLocal则是每个线程都有独立的资源副本,同步机制是以空间换时间,同时执行且互不干扰
1.2 核心API
- static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier):创建一个ThreadLocal实例,并通过传入一个Supplier接口定义初始值
- T get():返回ThreadLocal中当前线程的副本值。如果没有当前线程的值,则首先将其赋为初始值,并返回。
- void set(T value):将ThreadLocal中当前线程的副本赋为指定的值
- void remove():将ThreadLocal中当前线程的副本值移除
二、ThreadLocal的应用场景
2.1 从一个小的demo说起
-
需求:5个销售卖房子,统计销售总数
public class ThreadLocalDemo { public static void main(String[] args) { House house = new House(); for(int i = 1;i <= 5;i++){ new Thread(()->{ int size = new Random().nextInt(5) + 1; System.out.println(Thread.currentThread().getName() + "\t"+"curr count:" + size); for(int j = 1;j <= size;j ++){ house.saleHouse(); } },"Thread" + i).start(); } try { TimeUnit.MILLISECONDS.sleep(300); } catch (InterruptedException e) { e.printStackTrace(); } System.out.println(Thread.currentThread().getName() + "\t"+"total count:"+ house.saleCount); } } class House{ int saleCount = 0; public synchronized void saleHouse(){ //通过加锁保证共享资源的线程安全 ++saleCount; } }
运行结果:
-
需求发生变化:统计五个销售每个人的销售数量
public class ThreadLocalDemo { public static void main(String[] args) { House house = new House(); for(int i = 1;i <= 5;i++){ new Thread(()->{ try { int size = new Random().nextInt(5) + 1; System.out.println(Thread.currentThread().getName() + "\t"+"expect count:" + size); for(int j = 1;j <= size;j ++){ house.saleHouse(); } System.out.println(Thread.currentThread().getName() + "\t"+"real count:" + house.getCount()); } finally { house.remove(); //回收ThreadLocal中当前现成的副本值 } },"Thread" + i).start(); } } } class House{ private static final ThreadLocal<Integer> threadLocal = ThreadLocal.withInitial(()->0); public void saleHouse(){ threadLocal.set(threadLocal.get() + 1); } public int getCount() { return threadLocal.get(); } public void remove() { threadLocal.remove(); } }
阿里开发规范:自定义的ThreadLocal变量需要进行回收,尤其在线程池场景下(线程经常会被复用),如果不回收,可能会影响后序业务逻辑,也有可能造成内存泄露。尽量使用try-finally块进行回收。
运行结果:
2.2 非线程安全的SimpleDateFormat
开发过程中,会高频使用SimpleDateFormat处理时间信息,且一般会定义一个静态的SimpleDateFormat变量,殊不知此种写法在多线程环境下的危险性!
下面这段代码在运行时会出现什么意外状况?
public class DateUtils {
public static final SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
public static Date parseDate(String stringDate) throws Exception {
return sdf.parse(stringDate);
}
public static void main(String[] args) throws Exception {
for (int i = 1; i <= 30; i++) {
//模拟并发场景下通过SimpleDateFormat处理时间的场景
new Thread(() -> {
try {
System.out.println(DateUtils.parseDate("2020-11-11 11:11:11"));
} catch (Exception e) {
e.printStackTrace();
}
}, "Thread" + i).start();
}
}
}
- 阅读SimpleDateFormat源码可以看出,SimpleDateFormat中的操作都没有加锁,这意味着静态的SimpleDateFormat变量(多个线程共享)在高并发场景中会存在线程安全问题
想要解决上述问题,我们能想到两种直接的方案:
- SimpleDateFormat的处理逻辑加锁,这会导致执行效率比较低
- SimpleDateFormat定义为局部变量,这种大量创建SimpleDateFormat对象的方式很不优雅
这时,借助ThreadLocal便可以即高效又优雅地解决SimpleDateFormat线程安全问题。
public class DateUtils {
public static final ThreadLocal<SimpleDateFormat> sdfThreadLocal =
ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"));
public static Date parseByThreadLocal(String stringDate) throws ParseException {
return sdfThreadLocal.get().parse(stringDate);
}
public static void main(String[] args) throws Exception {
for (int i = 1; i <= 30; i++) {
//模拟并发场景下通过SimpleDateFormat处理时间的场景
new Thread(() -> {
try {
System.out.println(parseByThreadLocal("2022-12-28 11:20:30"));
} catch (Exception e) {
e.printStackTrace();
}
}, "Thread" + i).start();
}
}
}
这时运行代码,不会由于多线程并发而报错
2.3 Springboot过滤器+ThreadLocal保存请求参数
请求Java后端的API服务时,http请求中经常会携带一些通用参数,比如token、uid、did这些,以token为例,我们看下如何在过滤器中利用ThreadLocal保存token。
-
过滤器代码
public class CommonParamFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { String token = HttpUtils.getCommonParamFromRequest(request, HTTP_PARAM_TOKEN); CommonThreadLocal.setToken(token); //解析token,存到ThreadLocal中 filterChain.doFilter(request, response); } }
-
工具类
public class HttpUtils { //从请求头或者请求参数中解析出具体的参数值 public static String getCommonParamFromRequest(HttpServletRequest request, String key) { return firstNonBlank(request.getHeader(key), request.getParameter(key)); } }
public class CommonThreadLocal { private static final ThreadLocal<String> TOKEN = new ThreadLocal<>(); @Nonnull public static String token() { return firstNonNull(TOKEN.get(), StringUtils.EMPTY); } public static void setToken(String token) { if (token != null) { TOKEN.set(token); } } public static void removeToken() { TOKEN.remove(); } }
-
拦截器中校验token
public class TokenCheckInterceptor extends HandlerInterceptorAdapter { @Override public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler){ CommonAssert.assertTrue(isNotBlank(token()), CommonCode.INVALID_PARAM); return true; } @Override public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception { removeToken(); }
public class CommonAssert { public static void assertTrue(boolean expression, CommonCode code) { if (!expression) { throw new XxxBaseException(code); } } }
说明和总结:
- Java的API服务收到http请求后,通常是需要Servlet容器(比如Tomcat)中为其调度线程(通常是线程池的方式)处理请求
- 在过滤器中,我们可以从http请求中解析出token等信息,并利用ThreadLocal进行存储,这样在高并发的场景下,每个处理线程都能存储自己的token副本,且互不影响
- 在请求的整个生命周期中,都可以通过CommonThreadLocal中的token()方法获取当前请求携带的token(例如在拦截器中进行token校验等)
- 在请求退出之前(例如在拦截器的afterCompletion()中),需要对当前线程的token副本进行清理,避免线程复用时产生问题
三、ThreadLocal的实现原理
ThreadLocal的底层实现中涉及到的类和接口的关系如下:
- Thread类中的实例变量threadLocals,其类型为ThreadLocalMap,是ThreadLocal中的内部类
- ThreadLocalMap中的数据存在了Entry数组里,通过k-v的方式存储了当前线程所有的ThreadLocal副本,其中k为ThreadLocal变量,v为副本值
- 正因为每个线程中都有自己的ThreadLocalMap,因此每个线程中都有自己的ThreadLocal变量副本
- Entry继承了WeakReference,Entry中的key是通过弱引用的方式指向了ThreadLocal变量
为了更直观的理解ThreadLocal的底层实现,我们假设有两个ThreadLocal变量:
private static final ThreadLocal<Car> TOKEN = new ThreadLocal<>();
private static final ThreadLocal<Bike> BIKE = new ThreadLocal<>();
如果有两个线程T1和T2均分别设置了TOKEN和BIKE对应的副本值,此时的内存简图如下:
3.1 源码分析
3.1.1 ThreadLocal#set
线程T调用ThreadLocal的set方法进行赋值,实际上就是构造当前ThreadLocal实例为key,值为value的Entry,往线程T的ThreadLocalMap中存放
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t); //当前线程的ThreadLocalMap
if (map != null)
// 存在则调用map.set设置此实体entry
map.set(this, value);
else
createMap(t, value); //不存在ThreadLocalMap则进行初始化
}
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
void createMap(Thread t, T firstValue) {
//t-当前线程 firstValue-副本值 this-当前ThreadLocal对象
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
//初始化table,默认容量为16
table = new ThreadLocal.ThreadLocalMap.Entry[INITIAL_CAPACITY];
//根据hash值计算索引
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new ThreadLocal.ThreadLocalMap.Entry(firstKey, firstValue);
size = 1;
//设置扩容触发阈值(容量的2/3)
setThreshold(INITIAL_CAPACITY);
}
set方法中主要由两个逻辑分支:
-
当前线程中不存在ThreadLocalMap,此时需要进行初始化逻辑
- 设置初始容量
- 以当前ThreadLocal实例和对应的值value构造第一个entry
- 基于ThreadLocal实例对应的hash值,计算索引,将第一个entry放到Entry数组的具体位置上
-
当前线程中已经存在ThreadLocalMap,此时调用ThreadLocalMap的set方法进行赋值
private void set(ThreadLocal<?> key, Object value) { Entry[] tab = table; int len = tab.length; int i = key.threadLocalHashCode & (len-1); //计算索引 for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { ThreadLocal<?> k = e.get(); if (k == key) { //新值覆盖旧值 e.value = value; return; } if (k == null) { //脏Entry replaceStaleEntry(key, value, i); return; } } tab[i] = new Entry(key, value); //构造新的entry放到具体位置上 int sz = ++size; if (!cleanSomeSlots(i, sz) && sz >= threshold) rehash(); }
若Entry中的key指向的ThreadLocal对象为null(后文会剖析为什么会出现key为null的情况),那么Entry中的value将无法访问到,这样的Entry称为脏Entry
ThreadLocalMap的set方法中,计算出当前ThreadLocal实例存放的索引位置i后,有四种场景:
-
位置i的Entry为null,即未被占用,则直接构造新的Entry放到位置i上
-
位置i的Entry不为null,且Entry的key正是当前的ThreadLocal实例,则新值覆盖旧值
-
位置i的Entry不为nul,且Entry的key为null,说明是脏数据,则调用replaceStaleEntry方法进行脏数据清理,并将当前k-v对应的Entry放到合适的位置上
-
位置i的Entry不为null,key不是null并且与当前ThreadLocal实例不同,此时说明发生了hash冲突,则需要调用nextIndex 方法为新的Entry找到一个新的位置
private static int nextIndex(int i, int len) { return ((i + 1 < len) ? i + 1 : 0); //下一个位置,如果到尾部从头开始 }
-
ThreadLocalMap的set方法最后,进行了扩容逻辑的处理,当cleanSomeSlots返回false(即没有找到脏Entry进行清理)并且当前size达到阈值,则会调用rehash()方法进行扩容。
-
cleanSomeSlots方法的主要作用是寻找脏Entry(即key=null的Entry),然后进行清理
private boolean cleanSomeSlots(int i, int n) { boolean removed = false; Entry[] tab = table; int len = tab.length; do { i = nextIndex(i, len); Entry e = tab[i]; if (e != null && e.get() == null) { //找到脏Entry n = len; removed = true; i = expungeStaleEntry(i); //清理脏Entry } } while ( (n >>>= 1) != 0); return removed; }
- while ( (n >>>= 1) != 0)说明要循环log2(n)次
- 没有发现脏Entry,会一直往后检查下一个位置的Entry;如果发现了脏Entry,则会重置n,重新循环log2(n) 次
-
expungeStaleEntry方法的主要作用是清理指定位置的脏Entry,并且会向后遍历(遍历到Entry为null终止)碰到脏Entry即进行清理
private int expungeStaleEntry(int staleSlot) { Entry[] tab = table; int len = tab.length; //清理staleSlot位置上的脏Entry tab[staleSlot].value = null; tab[staleSlot] = null; size--; Entry e; int i; for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) { //向后遍历,到Entry为null终止 ThreadLocal<?> k = e.get(); if (k == null) { //清理脏entry e.value = null; tab[i] = null; size--; } else { int h = k.threadLocalHashCode & (len - 1); if (h != i) { //说明存放k的时候发生了hash冲突,这时会尝试将k对应的entry放到hash值计算出来的索引位置,如果有冲突则尝试下一个位置 tab[i] = null; while (tab[h] != null) h = nextIndex(h, len); tab[h] = e; } } } return i; }
-
rehash方法会先清理脏Entry,然后进行扩容的容量判断,容量达到扩容值,则进行扩容
private void rehash() { expungeStaleEntries(); if (size >= threshold - threshold / 4) resize(); }
3.1.2 ThreadLocal#get
线程T调用ThreadLocal的get方法取值,实际上是以当前ThreadLocal实例为key,从线程T的ThreadLocalMap中找出对应的value值
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t); //当前线程的ThreadLocalMap
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this); //从ThreadLocalMap中取出Entry
if (e != null) {
T result = (T)e.value; //从Entry中取出value
return result;
}
}
return setInitialValue(); //当前线程还未设置副本值,此时设置初始值并返回
}
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value); //当前线程存在ThreadLocalMap,直接赋值
else
createMap(t, value); //当前线程不存在ThreadLocalMap,创建并赋值
return value;
}
下面我们剖析ThreadLocalMap的getEntry方法:
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1); //根据hash值计算目标Entry的索引值
Entry e = table[i];
if (e != null && e.get() == key) //找到目标Entry,直接返回
return e;
else
return getEntryAfterMiss(key, i, e); //目标Entry不在预期的位置,遍历后面索引位置查找元素
}
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key) //找到目标Entry,返回
return e;
if (k == null)
expungeStaleEntry(i); //清理脏Entry
else
i = nextIndex(i, len); //下一个索引值
e = tab[i];
}
return null;
}
- ThreadLocalMap的getEntry方法通过当前ThreadLocal的hash值计算出目标Entry的预期索引位置
- 如果预期索引位置上的Entry就是目标Entry,直接返回;如果目标Entry不在预期的位置(比如set的时候发生了hash冲突),则需要遍历ThreadLocalMap中的Entry数组进行搜索
- 搜索过程中,一旦发现脏Entry,会立即进行清理
3.1.3 ThreadLocal#remove
线程T调用ThreadLocal的remove方法,实际上是清除线程T的ThreadLocalMap中以当前ThreadLocal实例为key的Entry对象
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread()); //当前线程的ThreadLocalMap
if (m != null)
m.remove(this);
}
下面我们剖析下ThreadLocalMap的remove方法:
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1); //根据hash值计算目标Entry的索引值
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) { //找到了目标Entry
e.clear(); //调用Reference的clear方法,断开指向ThreadLocal对象的弱引用
expungeStaleEntry(i); //进行清理
return;
}
}
}
//Reference的clear方法
public void clear() {
this.referent = null;
}
3.2 ThreadLocal中的内存泄漏问题
3.2.1 Entry中为什么要用弱引用?
对于只有弱引用的对象来说,在JVM的下一次GC中,不管JVM的内存空间是否足够,都会回收该对象
我们考虑这样一段代码:
public void test() {
ThreadLocal<String> TL = new ThreadLocal<>();
TL.set(test);
System.out.println(TL.get());
}
- 当test方法执行完毕后,Java虚拟机栈中指向Java堆中ThreadLocal对象的强引用 TL被销毁,ThreadLocal对象不会在程序中继续使用,预期是能被垃圾回收机制清理掉
- 但此时线程的ThreadLocalMap里某个Entry的key引用还指向这个ThreadLocal对象(在线程池的场景下,线程对象极有可能一直不被销毁),如果这个key引用是强引用,这个不会在程序中继续使用的ThreadLocal对象就不能被GC回收,造成内存泄漏
- 如果通过弱引用指向ThreadLocal对象,在方法执行完毕后,ThreadLocal对象只存在弱引用,就可以顺利被GC回收
3.2.2 为什么要清理脏Entry
由于Entry中的key是通过弱引用指向ThreadLocal对象,那么当ThreadLocal对象被GC回收后,就产生了脏Entry。
- 上文我们已经分析了ThreadLocal的get方法,对于脏Entry中来说,其key为null,这意味着脏Entry中的value将无法访问
- 如果线程一直不被销毁(比如线程池场景中),那么Thread->ThreadLocalMap->Entry->value这样一条强引用链的存在将导致value对象不会被GC回收,造成内存泄漏
现在回头看下我们在 “ThreadLocal的应用场景”一节中的代码,使用ThreadLocal之后都会手动调用remove方法清理掉对应的Entry,就是为了防止脏Entry导致的内存泄漏问题。
再看ThreadLocal的get和set方法,都实现了清理脏Entry的逻辑,这也是为了防止内存泄露问题。
四、总结
- ThreadLocal并不解决线程间共享数据的问题,其适用于变量在线程间隔离且在方法间共享的场景
- 每个线程持有一个自己的专属Map并维护了ThreadLocal对象与具体实例的映射,该Map只能被持有它的线程访问,故不存在线程安全以及锁的问题
- ThreadLocalMap的Entry对ThreadLocal的引用为弱引用,避免了ThreadLocal对象无法被回收的问题
- ThreadLocal在get和set方法中都实现了清理脏Entry的逻辑,防止内存泄露问题的发生
- 使用ThreadLocal之后需要手动调用remove方法清理掉对应的Entry