关于SafeThreadLocal的一点思考

关于ThreadLocal,我在博客里已经有过一些记述 参见 https://blog.csdn.net/dlf123321/article/details/43153717

前言

它的作用呢?很简单就是能做到在同一个线程内的不同业务模块内保存一组信息,让各个模块都能修改&读取。
想想如果没有她,只能使用参数传递的方法,从头传到尾,多麻烦。
OK它的作用与好处说完了,那它还有一个很多的问题就是,每次使用完都必须清理。
为什么?我线程结束了,即使不清理,哪能怎么样呢?
是的,你的线程是结束了,但是我们的应用一般都是跑在容器里面的(例如Tomcat),而容器里面的服务线程是从线程池里面取的,它是可以服用的。那如果不清理,造成的影响就是

用户张三放到Threadlocal里面的数据,有可能被后面李四这个用户使用的线程读到!!

这自然是不可接受的。
所以一般使用ThreaLocal的时候都是使用Aop在所有服务的入口处进行初始化,Aop结束的时候,会清理掉ThreadLocal。
理论上是OK的,但是还是有问题,假如我新写的方法不在Aop的监控范围内呢?

即使你知道ThreadLoca的这个坑,但是代码写的多了,总有疏忽的时候,万一忘了咋办?
这是一个很麻烦的问题。

转角遇到SafeThreadLocal

大家看看这个思路的

package org.example.service;

import lombok.extern.slf4j.Slf4j;
import org.example.domain.UserData;
import org.example.util.AutoClean;
import org.example.util.SafeThreadLocal;
import org.springframework.stereotype.Component;

/**
 * @program: parent_pro
 * @description:
 * @author: 渭水
 * @create: 2023/11/03
 */
@Component
@Slf4j
public class ThreadLocalService {

    @AutoClean
    public void serviceA(Long userId) {
        log.info("serviceA first " + userId + " " + SafeThreadLocal.get());

        UserData userData = new UserData();
        userData.setId(15);
        userData.setName("albaba");
        SafeThreadLocal.put(userData);
        serviceB();
        serviceC();

        log.info("serviceA " + SafeThreadLocal.get());
    }

    @AutoClean
    public void serviceB() {

        log.info("serviceB " + SafeThreadLocal.get());
        UserData userData = SafeThreadLocal.get();
        userData.setId(16);
        SafeThreadLocal.put(userData);
    }

    public void serviceC() {
        log.info("serviceC " + SafeThreadLocal.get());
    }

}


上面的是service,下面的是对应的controller

package org.example.controller;

import lombok.extern.slf4j.Slf4j;
import org.example.service.ThreadLocalService;
import org.example.util.SafeThreadLocal;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

/**
 * @program: parent_pro
 * @description:
 * @author: 渭水
 * @create: 2023/11/03
 */
@RestController
@Slf4j
public class ThreadLocalController {

    @Autowired
    private ThreadLocalService threadLocalService;

    @RequestMapping("/threadLocal")
    public String threadLocal() {
        threadLocalService.serviceA(33L);

        try {
            log.info("threadLocal " + SafeThreadLocal.get());
        } catch (Exception e) {
            log.error("threadLocal getUserData error", e);
        }
        return "success";
    }

}

核心的类来了,对应的注解

package org.example.util;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * 在方法上添加此注解, 可在范围内使用SafeThreadLocal存取上下文
 * 可多层嵌套使用, 当离开最外层的时候, 会自动清理上下文, 避免内存泄露
 *
 * @author zhuomu
 * @date 2023/8/9
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface AutoClean {
}

下面是切面

package org.example.util;

import java.lang.reflect.Method;

import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;

/**
 * 添加此注解, 以便在离开方法时执行一些自动清理操作
 * 在实现上, 会使用SafeThreadLocal.call 来执行目标方法, 以便在离开时自动清理相关thread local上下文
 * 详细可见SafeThreadLocal类的说明
 *
 * @author zhuomu
 * @date 2023/8/9
 */
@Aspect
@Component
public class AutoCleanAspect {

    @Around("myAutoClean()")
    public Object doAround(ProceedingJoinPoint joinPoint) throws Throwable {
        // 获得注解
        final MethodSignature methodSignature = (MethodSignature)joinPoint.getSignature();
        final Class<?> targetClass = joinPoint.getTarget().getClass();
        final Method method = targetClass.getDeclaredMethod(methodSignature.getName(),
            methodSignature.getParameterTypes());
        final AutoClean autoClean = method.getAnnotation(AutoClean.class);
        if (autoClean == null) {
            return joinPoint.proceed();
        }

        return SafeThreadLocal.callWithThrow(joinPoint::proceed);
    }

    @SuppressWarnings("unused")
    @Pointcut("@annotation(org.example.util.AutoClean)")
    private void myAutoClean() {
    }
}

最最核心的处理类

package org.example.util;

import org.example.domain.UserData;

/**
 * *
 *
 * @date 2023/8/9
 */
public class SafeThreadLocal {

    static final ThreadLocal<UserData> CONTEXT_THREAD_LOCAL = new ThreadLocal<>();
    /**
     * 嵌套层数, 用来决策是否要清理
     */
    static final ThreadLocal<Integer> NESTED_LEVEL = ThreadLocal.withInitial(() -> 0);

    /**
     * 调用callable方法
     *
     * @param callable 目标方法
     * @param <T>      结果类型
     * @return 目标方法的返回结果
     * @throws Throwable 对底层的异常不会捕获, 会无条件抛出
     */
    public static <T> T callWithThrow(Callable<T> callable) throws Throwable {
        try {
            nestNextLevel();
            return callable.call();
        } finally {
            cleanUp();
        }
    }

    public static UserData get() {
        check();
        return CONTEXT_THREAD_LOCAL.get();
    }

    public static void put(UserData userData) {
        check();
        CONTEXT_THREAD_LOCAL.set(userData);
    }

    /**
     * 调用runnable方法
     *
     * @param runnable 目标方法
     */
    public static void run(Runnable runnable) {
        try {
            nestNextLevel();
            runnable.run();
        } finally {
            cleanUp();
        }
    }

    private static void check() {
        if (isOutOfScope()) {
            throw new IllegalStateException(
                " direct calls to SafeThreadLocal.get is forbidden, "
                    + " please use SafeThreadLocal.callWithThrow or SafeThreadLocal.run ");
        }
    }

    private static void cleanUp() {
        NESTED_LEVEL.set(NESTED_LEVEL.get() - 1);
        if (isOutOfScope()) {
            CONTEXT_THREAD_LOCAL.remove();
            NESTED_LEVEL.remove();
        }
    }

    /**
     * 判断是否是最外层
     *
     * @return 是否是最外层
     */
    private static boolean isOutOfScope() {
        return NESTED_LEVEL.get() <= 0;
    }

    private static void nestNextLevel() {
        NESTED_LEVEL.set(NESTED_LEVEL.get() + 1);
    }

    @FunctionalInterface
    public interface Callable<V> {
        /**
         * Computes a result, or throws an exception if unable to do so.
         *
         * @return computed result
         * @throws Throwable if unable to compute a result
         */
        V call() throws Throwable;
    }
}

调用一次后打印的结果是:

2023-11-06 15:32:13.074 [http-nio-8080-exec-1] INFO  o.e.s.ThreadLocalService - serviceA first 33 null
2023-11-06 15:32:13.075 [http-nio-8080-exec-1] INFO  o.e.s.ThreadLocalService - serviceB UserData(id=15, name=albaba, objectMap=null, otherInfo=null)
2023-11-06 15:32:13.075 [http-nio-8080-exec-1] INFO  o.e.s.ThreadLocalService - serviceC UserData(id=16, name=albaba, objectMap=null, otherInfo=null)
2023-11-06 15:32:13.075 [http-nio-8080-exec-1] INFO  o.e.s.ThreadLocalService - serviceA UserData(id=16, name=albaba, objectMap=null, otherInfo=null)
2023-11-06 15:32:13.075 [http-nio-8080-exec-1] ERROR o.e.c.ThreadLocalController - threadLocal getUserData error
java.lang.IllegalStateException:  direct calls to SafeThreadLocal.get is forbidden,  please use SafeThreadLocal.callWithThrow or SafeThreadLocal.run 
	at org.example.util.SafeThreadLocal.check(SafeThreadLocal.java:62)
	at org.example.util.SafeThreadLocal.get(SafeThreadLocal.java:37)
	at org.example.controller.ThreadLocalController.threadLocal(ThreadLocalController.java:28)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)

咱们还可以跑一下单测

package org.example.util;

import org.example.domain.UserData;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class SafeThreadLocalTest {

    @Test
    public void test() {
        SafeThreadLocal.run(() -> {
            UserData userData = new UserData();
            userData.setId(16);
            // 放入上下文变量
            SafeThreadLocal.put(userData);
            SafeThreadLocal.run(() -> {
                // true; 内嵌层可访问外层上下文变量
                Assertions.assertEquals(16, SafeThreadLocal.get().getId());

                userData.setId(17);
                // 内层放入上下文变量
                SafeThreadLocal.put(userData);
            });

            // true
            Assertions.assertEquals(17, SafeThreadLocal.get().getId());
            // true; 内层存入的上下文变量, 在外层也可访问
        });
        // 抛IllegalStateException; 离开scope后, 变量自动清理, 无法访问到
        Assertions.assertThrows(IllegalStateException.class, SafeThreadLocal::get);
    }

}

大家思考一下,为什么需要NESTED_LEVEL的逻辑呢?
如果没有的逻辑,在ThreadLocalService里面,serviceA调用了serviceB之后,还能访问serviceA里面的数据么?
准确来说,就是在下面serviceA 标注的地方 还能拿到数据么?
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值