解读ThreadLocal的应用场景和实现原理

一、简介

1.1 ThreadLocal是什么

参考JDK的原生注释:

在这里插入图片描述

  • ThreadLocal用于提供线程粒度的局部变量,线程在访问ThreadLocal实例的时候(通过其get或set方法)有自己的、独立初始化的变量副本
  • ThreadLocal实例通常是类中的私有静态字段,使用它的目的是希望将一些状态信息(例如用户ID或事务ID)与线程关联起来。

总结来说,ThreadLocal的作用就是让每一个线程绑定自己的值,线程间不互相影响,从而避免了线程安全问题。这与我们熟知的通过锁的方式(例如synchronized、Lock等)来保证线程安全是两种不同的思路:

  • 锁的方式是多个线程抢占一个共享资源,同步机制是以时间换空间,执行时间不同,类似于排队
  • ThreadLocal则是每个线程都有独立的资源副本,同步机制是以空间换时间,同时执行且互不干扰

1.2 核心API

  • static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier):创建一个ThreadLocal实例,并通过传入一个Supplier接口定义初始值
  • T get():返回ThreadLocal中当前线程的副本值。如果没有当前线程的值,则首先将其赋为初始值,并返回。
  • void set​(T value):将ThreadLocal中当前线程的副本赋为指定的值
  • void remove​():将ThreadLocal中当前线程的副本值移除

二、ThreadLocal的应用场景

2.1 从一个小的demo说起

  • 需求:5个销售卖房子,统计销售总数

    public class ThreadLocalDemo {
    
        public static void main(String[] args) {
            House house = new House();
            for(int i = 1;i <= 5;i++){
                new Thread(()->{
                    int size = new Random().nextInt(5) + 1;
                    System.out.println(Thread.currentThread().getName() + "\t"+"curr count:" + size);
                    for(int j = 1;j <= size;j ++){
                        house.saleHouse();
                    }
                },"Thread" + i).start();
            }
    
            try {
                TimeUnit.MILLISECONDS.sleep(300);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println(Thread.currentThread().getName() + "\t"+"total count:"+ house.saleCount);
        }
    }
    
    class House{
        int saleCount = 0;
        public synchronized void saleHouse(){ //通过加锁保证共享资源的线程安全
            ++saleCount;
        }
    }
    

    运行结果:

    在这里插入图片描述

  • 需求发生变化:统计五个销售每个人的销售数量

    public class ThreadLocalDemo {
    
        public static void main(String[] args) {
            House house = new House();
            for(int i = 1;i <= 5;i++){
                new Thread(()->{
                    try {
                        int size = new Random().nextInt(5) + 1;
                        System.out.println(Thread.currentThread().getName() + "\t"+"expect count:" + size);
                        for(int j = 1;j <= size;j ++){
                            house.saleHouse();
                        }
                        System.out.println(Thread.currentThread().getName() + "\t"+"real count:" + house.getCount());
                    } finally {
                        house.remove(); //回收ThreadLocal中当前现成的副本值
                    }
                },"Thread" + i).start();
            }
        }
    }
    
    class House{
        private static final ThreadLocal<Integer> threadLocal = ThreadLocal.withInitial(()->0);
    
        public void saleHouse(){
            threadLocal.set(threadLocal.get() + 1);
        }
    
        public int getCount() {
            return threadLocal.get();
        }
    
        public void remove() {
            threadLocal.remove();
        }
    }
    

    阿里开发规范:自定义的ThreadLocal变量需要进行回收,尤其在线程池场景下(线程经常会被复用),如果不回收,可能会影响后序业务逻辑,也有可能造成内存泄露。尽量使用try-finally块进行回收。

    运行结果:

    在这里插入图片描述

2.2 非线程安全的SimpleDateFormat

开发过程中,会高频使用SimpleDateFormat处理时间信息,且一般会定义一个静态的SimpleDateFormat变量,殊不知此种写法在多线程环境下的危险性

下面这段代码在运行时会出现什么意外状况?

public class DateUtils {

    public static final SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");

    public static Date parseDate(String stringDate) throws Exception {
        return sdf.parse(stringDate);
    }

    public static void main(String[] args) throws Exception {
        for (int i = 1; i <= 30; i++) {
            //模拟并发场景下通过SimpleDateFormat处理时间的场景
            new Thread(() -> {
                try {
                    System.out.println(DateUtils.parseDate("2020-11-11 11:11:11"));
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }, "Thread" + i).start();
        }
    }
}

在这里插入图片描述

  • 阅读SimpleDateFormat源码可以看出,SimpleDateFormat中的操作都没有加锁,这意味着静态的SimpleDateFormat变量(多个线程共享)在高并发场景中会存在线程安全问题

想要解决上述问题,我们能想到两种直接的方案:

  • SimpleDateFormat的处理逻辑加锁,这会导致执行效率比较低
    在这里插入图片描述
  • SimpleDateFormat定义为局部变量,这种大量创建SimpleDateFormat对象的方式很不优雅
    在这里插入图片描述

这时,借助ThreadLocal便可以即高效又优雅地解决SimpleDateFormat线程安全问题。

public class DateUtils {

    public static final ThreadLocal<SimpleDateFormat> sdfThreadLocal =
            ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"));
            
    public static Date parseByThreadLocal(String stringDate) throws ParseException {
        return sdfThreadLocal.get().parse(stringDate);
    }

    public static void main(String[] args) throws Exception {
        for (int i = 1; i <= 30; i++) {
            //模拟并发场景下通过SimpleDateFormat处理时间的场景
            new Thread(() -> {
                try {
                    System.out.println(parseByThreadLocal("2022-12-28 11:20:30"));
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }, "Thread" + i).start();
        }
    }
}

这时运行代码,不会由于多线程并发而报错

在这里插入图片描述

2.3 Springboot过滤器+ThreadLocal保存请求参数

请求Java后端的API服务时,http请求中经常会携带一些通用参数,比如token、uid、did这些,以token为例,我们看下如何在过滤器中利用ThreadLocal保存token。

  • 过滤器代码

    public class CommonParamFilter extends OncePerRequestFilter {
    
        @Override
        protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
                throws ServletException, IOException {
            String token = HttpUtils.getCommonParamFromRequest(request, HTTP_PARAM_TOKEN);
            CommonThreadLocal.setToken(token); //解析token,存到ThreadLocal中
            filterChain.doFilter(request, response);
        }
    }
    
  • 工具类

    public class HttpUtils {
    
    	//从请求头或者请求参数中解析出具体的参数值
        public static String getCommonParamFromRequest(HttpServletRequest request, String key) {
            return firstNonBlank(request.getHeader(key), request.getParameter(key));
        }
    }
    
    public class CommonThreadLocal {
    
        private static final ThreadLocal<String> TOKEN = new ThreadLocal<>();
    
        @Nonnull
        public static String token() {
            return firstNonNull(TOKEN.get(), StringUtils.EMPTY);
        }
    
        public static void setToken(String token) {
            if (token != null) {
                TOKEN.set(token);
            }
        }
    
    	public static void removeToken() {
        	TOKEN.remove();
    	}
    }
    
  • 拦截器中校验token

    public class TokenCheckInterceptor extends HandlerInterceptorAdapter {
    
        @Override
        public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler){
            CommonAssert.assertTrue(isNotBlank(token()), CommonCode.INVALID_PARAM);
            return true;
        }
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex)
            throws Exception {
        removeToken();
    }
    
    public class CommonAssert {
    
        public static void assertTrue(boolean expression, CommonCode code) {
            if (!expression) {
                throw new XxxBaseException(code);
            }
        }
    }
    

说明和总结:

  • Java的API服务收到http请求后,通常是需要Servlet容器(比如Tomcat)中为其调度线程(通常是线程池的方式)处理请求
  • 在过滤器中,我们可以从http请求中解析出token等信息,并利用ThreadLocal进行存储,这样在高并发的场景下,每个处理线程都能存储自己的token副本,且互不影响
  • 在请求的整个生命周期中,都可以通过CommonThreadLocal中的token()方法获取当前请求携带的token(例如在拦截器中进行token校验等)
  • 在请求退出之前(例如在拦截器的afterCompletion()中),需要对当前线程的token副本进行清理,避免线程复用时产生问题

三、ThreadLocal的实现原理

ThreadLocal的底层实现中涉及到的类和接口的关系如下:

在这里插入图片描述

  • Thread类中的实例变量threadLocals,其类型为ThreadLocalMap,是ThreadLocal中的内部类
  • ThreadLocalMap中的数据存在了Entry数组里,通过k-v的方式存储了当前线程所有的ThreadLocal副本,其中k为ThreadLocal变量,v为副本值
  • 正因为每个线程中都有自己的ThreadLocalMap,因此每个线程中都有自己的ThreadLocal变量副本
  • Entry继承了WeakReference,Entry中的key是通过弱引用的方式指向了ThreadLocal变量

为了更直观的理解ThreadLocal的底层实现,我们假设有两个ThreadLocal变量:

private static final ThreadLocal<Car> TOKEN = new ThreadLocal<>();
private static final ThreadLocal<Bike> BIKE = new ThreadLocal<>();

如果有两个线程T1和T2均分别设置了TOKEN和BIKE对应的副本值,此时的内存简图如下:

在这里插入图片描述

3.1 源码分析

3.1.1 ThreadLocal#set

线程T调用ThreadLocal的set方法进行赋值,实际上就是构造当前ThreadLocal实例为key,值为value的Entry,往线程T的ThreadLocalMap中存放

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t); //当前线程的ThreadLocalMap
    if (map != null)
        // 存在则调用map.set设置此实体entry
        map.set(this, value);
    else
        createMap(t, value); //不存在ThreadLocalMap则进行初始化
}

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

void createMap(Thread t, T firstValue) {
    //t-当前线程 firstValue-副本值 this-当前ThreadLocal对象
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    //初始化table,默认容量为16
    table = new ThreadLocal.ThreadLocalMap.Entry[INITIAL_CAPACITY];
    //根据hash值计算索引
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    table[i] = new ThreadLocal.ThreadLocalMap.Entry(firstKey, firstValue);
    size = 1;
    //设置扩容触发阈值(容量的2/3)
    setThreshold(INITIAL_CAPACITY);
}

set方法中主要由两个逻辑分支:

  • 当前线程中不存在ThreadLocalMap,此时需要进行初始化逻辑

    • 设置初始容量
    • 以当前ThreadLocal实例和对应的值value构造第一个entry
    • 基于ThreadLocal实例对应的hash值,计算索引,将第一个entry放到Entry数组的具体位置上
  • 当前线程中已经存在ThreadLocalMap,此时调用ThreadLocalMap的set方法进行赋值

    private void set(ThreadLocal<?> key, Object value) {
        Entry[] tab = table;
        int len = tab.length;
        int i = key.threadLocalHashCode & (len-1); //计算索引
        for (Entry e = tab[i];
             e != null;
             e = tab[i = nextIndex(i, len)]) {
            ThreadLocal<?> k = e.get();
    
            if (k == key) { //新值覆盖旧值
                e.value = value;
                return;
            }
    
            if (k == null) { //脏Entry
                replaceStaleEntry(key, value, i);
                return;
            }
        }
        tab[i] = new Entry(key, value); //构造新的entry放到具体位置上
        int sz = ++size;
        if (!cleanSomeSlots(i, sz) && sz >= threshold)
            rehash();
    }
    

    若Entry中的key指向的ThreadLocal对象为null(后文会剖析为什么会出现key为null的情况),那么Entry中的value将无法访问到,这样的Entry称为脏Entry

    ThreadLocalMap的set方法中,计算出当前ThreadLocal实例存放的索引位置i后,有四种场景:

    • 位置i的Entry为null,即未被占用,则直接构造新的Entry放到位置i上

    • 位置i的Entry不为null,且Entry的key正是当前的ThreadLocal实例,则新值覆盖旧值

    • 位置i的Entry不为nul,且Entry的key为null,说明是脏数据,则调用replaceStaleEntry方法进行脏数据清理,并将当前k-v对应的Entry放到合适的位置上

    • 位置i的Entry不为null,key不是null并且与当前ThreadLocal实例不同,此时说明发生了hash冲突,则需要调用nextIndex 方法为新的Entry找到一个新的位置

      private static int nextIndex(int i, int len) {
          return ((i + 1 < len) ? i + 1 : 0); //下一个位置,如果到尾部从头开始
      }
      

ThreadLocalMap的set方法最后,进行了扩容逻辑的处理,当cleanSomeSlots返回false(即没有找到脏Entry进行清理)并且当前size达到阈值,则会调用rehash()方法进行扩容。

  • cleanSomeSlots方法的主要作用是寻找脏Entry(即key=null的Entry),然后进行清理

    private boolean cleanSomeSlots(int i, int n) {
        boolean removed = false;
        Entry[] tab = table;
        int len = tab.length;
        do {
            i = nextIndex(i, len);
            Entry e = tab[i];
            if (e != null && e.get() == null) { //找到脏Entry
                n = len;
                removed = true;
                i = expungeStaleEntry(i); //清理脏Entry
            }
        } while ( (n >>>= 1) != 0);
        return removed;
    }
    
    • while ( (n >>>= 1) != 0)说明要循环log2(n)次
    • 没有发现脏Entry,会一直往后检查下一个位置的Entry;如果发现了脏Entry,则会重置n,重新循环log2(n) 次
  • expungeStaleEntry方法的主要作用是清理指定位置的脏Entry,并且会向后遍历(遍历到Entry为null终止)碰到脏Entry即进行清理

    private int expungeStaleEntry(int staleSlot) {
        Entry[] tab = table;
        int len = tab.length;
        //清理staleSlot位置上的脏Entry
        tab[staleSlot].value = null;
        tab[staleSlot] = null;
        size--;
        
        Entry e;
        int i;
        for (i = nextIndex(staleSlot, len);
             (e = tab[i]) != null;
             i = nextIndex(i, len)) { //向后遍历,到Entry为null终止
            ThreadLocal<?> k = e.get();
            if (k == null) { //清理脏entry
                e.value = null;
                tab[i] = null;
                size--;
            } else {
                int h = k.threadLocalHashCode & (len - 1);
                if (h != i) { //说明存放k的时候发生了hash冲突,这时会尝试将k对应的entry放到hash值计算出来的索引位置,如果有冲突则尝试下一个位置
                    tab[i] = null;
                    while (tab[h] != null)
                        h = nextIndex(h, len);
                    tab[h] = e;
                }
            }
        }
        return i;
    }
    
  • rehash方法会先清理脏Entry,然后进行扩容的容量判断,容量达到扩容值,则进行扩容

    private void rehash() {
        expungeStaleEntries();
    
        if (size >= threshold - threshold / 4)
            resize();
    }
    

3.1.2 ThreadLocal#get

线程T调用ThreadLocal的get方法取值,实际上是以当前ThreadLocal实例为key,从线程T的ThreadLocalMap中找出对应的value值

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t); //当前线程的ThreadLocalMap
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this); //从ThreadLocalMap中取出Entry
        if (e != null) {
            T result = (T)e.value; //从Entry中取出value
            return result;
        }
    }
    return setInitialValue(); //当前线程还未设置副本值,此时设置初始值并返回
}

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value); //当前线程存在ThreadLocalMap,直接赋值
    else
        createMap(t, value); //当前线程不存在ThreadLocalMap,创建并赋值
    return value;
}

下面我们剖析ThreadLocalMap的getEntry方法:

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1); //根据hash值计算目标Entry的索引值
    Entry e = table[i];
    if (e != null && e.get() == key) //找到目标Entry,直接返回
        return e;
    else
        return getEntryAfterMiss(key, i, e); //目标Entry不在预期的位置,遍历后面索引位置查找元素
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;
    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key) //找到目标Entry,返回
            return e;
        if (k == null)
            expungeStaleEntry(i); //清理脏Entry
        else
            i = nextIndex(i, len); //下一个索引值
        e = tab[i];
    }
    return null;
}
  • ThreadLocalMap的getEntry方法通过当前ThreadLocal的hash值计算出目标Entry的预期索引位置
  • 如果预期索引位置上的Entry就是目标Entry,直接返回;如果目标Entry不在预期的位置(比如set的时候发生了hash冲突),则需要遍历ThreadLocalMap中的Entry数组进行搜索
  • 搜索过程中,一旦发现脏Entry,会立即进行清理

3.1.3 ThreadLocal#remove

线程T调用ThreadLocal的remove方法,实际上是清除线程T的ThreadLocalMap中以当前ThreadLocal实例为key的Entry对象

public void remove() {
     ThreadLocalMap m = getMap(Thread.currentThread()); //当前线程的ThreadLocalMap
     if (m != null)
         m.remove(this);
}

下面我们剖析下ThreadLocalMap的remove方法:

private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1); //根据hash值计算目标Entry的索引值
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) { //找到了目标Entry
            e.clear(); //调用Reference的clear方法,断开指向ThreadLocal对象的弱引用
            expungeStaleEntry(i); //进行清理
            return;
        }
    }
}
//Reference的clear方法
public void clear() {
     this.referent = null;
}

3.2 ThreadLocal中的内存泄漏问题

3.2.1 Entry中为什么要用弱引用?

对于只有弱引用的对象来说,在JVM的下一次GC中,不管JVM的内存空间是否足够,都会回收该对象

我们考虑这样一段代码:

public void test() {
	ThreadLocal<String> TL  = new ThreadLocal<>();
	TL.set(test);
	System.out.println(TL.get());
}
  • 当test方法执行完毕后,Java虚拟机栈中指向Java堆中ThreadLocal对象的强引用 TL被销毁,ThreadLocal对象不会在程序中继续使用,预期是能被垃圾回收机制清理掉
  • 但此时线程的ThreadLocalMap里某个Entry的key引用还指向这个ThreadLocal对象(在线程池的场景下,线程对象极有可能一直不被销毁),如果这个key引用是强引用,这个不会在程序中继续使用的ThreadLocal对象就不能被GC回收,造成内存泄漏
  • 如果通过弱引用指向ThreadLocal对象,在方法执行完毕后,ThreadLocal对象只存在弱引用,就可以顺利被GC回收

3.2.2 为什么要清理脏Entry

由于Entry中的key是通过弱引用指向ThreadLocal对象,那么当ThreadLocal对象被GC回收后,就产生了脏Entry。

  • 上文我们已经分析了ThreadLocal的get方法,对于脏Entry中来说,其key为null,这意味着脏Entry中的value将无法访问
  • 如果线程一直不被销毁(比如线程池场景中),那么Thread->ThreadLocalMap->Entry->value这样一条强引用链的存在将导致value对象不会被GC回收,造成内存泄漏

现在回头看下我们在 “ThreadLocal的应用场景”一节中的代码,使用ThreadLocal之后都会手动调用remove方法清理掉对应的Entry,就是为了防止脏Entry导致的内存泄漏问题。

再看ThreadLocal的get和set方法,都实现了清理脏Entry的逻辑,这也是为了防止内存泄露问题。

四、总结

  • ThreadLocal并不解决线程间共享数据的问题,其适用于变量在线程间隔离且在方法间共享的场景
  • 每个线程持有一个自己的专属Map并维护了ThreadLocal对象与具体实例的映射,该Map只能被持有它的线程访问,故不存在线程安全以及锁的问题
  • ThreadLocalMap的Entry对ThreadLocal的引用为弱引用,避免了ThreadLocal对象无法被回收的问题
  • ThreadLocal在get和set方法中都实现了清理脏Entry的逻辑,防止内存泄露问题的发生
  • 使用ThreadLocal之后需要手动调用remove方法清理掉对应的Entry
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Tracy_hang

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值