InheritableThreadLocal 使用的问题及解决办法

问题背景

最近公司负责的项目总是出现莫名其妙的串号的问题,具体表现是一次请求里日志却记录的多个userId(一次请都只会有一个userId),看了很久的报文,分析了很久的代码,终于有了点眉目。

问题分析

这类问题主要出现在一个主任务启动的多个线程并发执行造成的,执行分析使用的是
InheritableThreadLocal

private static class InheritableThreadLocalContext extends InheritableThreadLocal<ThreadLocalContext> {
        private InheritableThreadLocalContext() {
        }

        protected ThreadLocalContext initialValue() {
            return ThreadLocalContext.init();
        }

        protected ThreadLocalContext childValue(ThreadLocalContext parentValue) {
            ThreadLocalContext childContext = ThreadLocalContext.init();
            if (parentValue != null) {
                childContext.logContext = parentValue.logContext;
                childContext.scopeContext = parentValue.scopeContext;
            }

            return childContext;
        }
    }

InheritableThreadLocal是ThreadLocal的子类,比ThreadLocal优秀一点就是可以进行主子线程间ThreadLocalMap上下文拷贝。
源码如下:

public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    /**
     * Computes the child's initial value for this inheritable thread-local
     * variable as a function of the parent's value at the time the child
     * thread is created.  This method is called from within the parent
     * thread before the child is started.
     * <p>
     * This method merely returns its input argument, and should be overridden
     * if a different behavior is desired.
     *
     * @param parentValue the parent thread's value
     * @return the child thread's initial value
     */
    protected T childValue(T parentValue) {
        return parentValue;
    }

    /**
     * Get the map associated with a ThreadLocal.
     *
     * @param t the current thread
     */
    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    /**
     * Create the map associated with a ThreadLocal.
     *
     * @param t the current thread
     * @param firstValue value for the initial entry of the table.
     */
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

但是有一个特点,就是把主线程ThreadLocalMap拷贝到子线程是在线程初始化的情况下进行的。
在这里插入图片描述
InheritableThreadLocal 赋值是在Thread.init方法中执行的。
如果主线程一直存且里面ThreadLocalMap内容变化了,则不会拷贝到子线程的ThreadLocalMap中。

而我们的代码在开启子线程的时候使用了线程池。线程创建后一般不会销毁,因此也就造成InheritableThreadLocal 传递上下文串号了。

验证问题

下面我写了一个简单的例子验证以下上面的分析:

public class InheritableThreadLocalTest {

    public static void main(String[] args) throws InterruptedException {

        final InheritableThreadLocal<Span> inheritableThreadLocal=new InheritableThreadLocal<>();
        inheritableThreadLocal.set(new Span("xiexiexie"));
       //输出xiexiexie
        Span span1 = inheritableThreadLocal.get();
        System.out.println(span1.name);
        ExecutorService es= Executors.newFixedThreadPool(1);
        es.execute(()->{
            System.out.println("====1====");
            //输出 xiexiexie
            Span span2 = inheritableThreadLocal.get();
            System.out.println(span2.name);
            inheritableThreadLocal.set(new Span("qiqiqi"));
            //输出 qiqiqi
            span2 = inheritableThreadLocal.get();
            System.out.println(span2.name);
        });
        TimeUnit.SECONDS.sleep(1);

        inheritableThreadLocal.set(new Span("setsetset"));
        es.execute(()->{
                System.out.println("====2====");
                //输出qiqiqi
            Span span2 = inheritableThreadLocal.get();
            System.out.println(span2.name);
                inheritableThreadLocal.set(new Span("xxx"));
                //输出qiqiqi
            span2 = inheritableThreadLocal.get();
            System.out.println(span2.name);
            });

        TimeUnit.SECONDS.sleep(1);
        System.out.println("====0====");
        //输出xiexiexie
        Span span = inheritableThreadLocal.get();
        System.out.println(span.name);

    }

    static class Span{
        public String name;
        public int age;
        public Span(String name){
            this.name=name;
        }
    }
}

执行一下结果如下:

xiexiexie
====1====
xiexiexie
qiqiqi
====2====
qiqiqi
xxx
====0====
setsetset

简单分析,主线程初始值是"xiexiexie",子线程1的inheritableThreadLocal打出来的也是"xiexiexie";子线程1改成了"qiqiqi";子线程2的inheritableThreadLocal打出来的也是"qiqiqi",这个主要是inheritableThreadLocal赋值是浅复制,子线程1变化了,其实也是变化了主线程的ThreadLocalMap。这是主线程把值换成了“setsetset”,但是子线程2的打出来的依旧是上一次的结果"qiqiqi"。
造成这个结果的原因:只有创建子线程的时候才会设置子线程的inheritableTreadLocals值,假如第一次提交的任务是A,第二次是B,B任务提交任务时使用的是A的任务的缓存线程,A任务执行时已经重新set了InheritableThreadLocals,值已经变为qiqiqi,B任务再次获取时候直接从t.inheritableThreadLocals中获取,所以获得的是A任务提交的值,而不是父线程的值(父线程值没有改变的原因是子线程set的值,只会set到子线程对应的t.inheritableThreadLocals中,不会影响父线程的inheritableThreadLocals)

解决思路

在submit新任务的时候在重新copy父线程的所有的Entry,然后重新给t.inheritableThreadLocals赋值,这样就解决线程池中每一个新的任务都能获得父线程中的ThreadLocal的值,而不受其他任务影响,因为在生命周期完成时候会自动clear所有数据。

解决方案

自定义RunTask类

自定一个RunTask类,使用反射加代理的方式来实现业务,主线程存在InheritableThreaadLocal中的值间接复制,详细如下:

定义一个InheritableTask抽象类,这个类实现了Runable接口,并定义一个runTask抽象方法,当开发者需要面对线程池,获取InheritableThreadLocal值的场景提交任务只需要集成InheritableTask类,实现runTask方法即可。
在创建任务时,也就是InheritableTask构造方法中,通过反射获取提交任务的业务线程的inheritableLocals属性,然后复制一份,暂存到当前的task的inheritableThreadLocalsObj属性找那个
线程池在执行该任务时,其实就是去掉用run()方法,在执行run方法时,先将inheritableThreadLocalsObj属性复制给当前执行任务的那个业务线程的inheritableThreadLocals属性值,然后再去执行runTask()方法,就是真正的业务逻辑,最后finally清理掉执行当前业务的线程的inheritableThreadLocals属性。
详细代码如下:

public abstract class InheritableTask implements Runnable {
    private Object inheritableThreadLocalsObj;

    public InheritableTask() {
        try {
            //获取当前业务线程
            Thread currentThread = Thread.currentThread();
            //获取inheritableThreadLocals属性值
            Field inheritableThreadLocalsField = Thread.class.getDeclaredField("inheritableThreadLocals");
            inheritableThreadLocalsField.setAccessible(true);
            //得到当前线程inheritableThreadLocals的属性值
            Object threadLocalMapObj = inheritableThreadLocalsField.get(currentThread);
            if (null != threadLocalMapObj) {
                //获取字段的类型
                Class<?> threadLocalMapClazz = inheritableThreadLocalsField.getType();
                //获取ThreadLocal中的createInheritedMap方法
                Method method = ThreadLocal.class.getDeclaredMethod("createInheritedMap", threadLocalMapClazz);
                method.setAccessible(true);
                //调用createInheritedMap方法,重新创建一个新的inheritableThreadLocals,并且将这个值保存
                this.inheritableThreadLocalsObj = method.invoke(ThreadLocal.class, threadLocalMapObj);
            }

        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    @Override
    public void run() {
        //此处获取处理当前业务的线程,也就是线程池中的线程
        Thread currentThread = Thread.currentThread();
        Field field = null;
        try {
            //获取inheritableThreadLocals属性
            field = Thread.class.getDeclaredField("inheritableThreadLocals");
            //设置权限
            field.setAccessible(true);
            if (this.inheritableThreadLocalsObj != null) {
                //将暂存值,赋值给currentThread
                field.set(currentThread, this.inheritableThreadLocalsObj);
                inheritableThreadLocalsObj = null;
            }
            //执行任务
            runTask();
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            try {
                //最后将线程的InheritableThreadLocals设置为null
                if (field != null) {
                    field.set(currentThread, null);
                }
            } catch (IllegalAccessException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    /**
     * 代理方法这个方法处理业务逻辑
     */
    public abstract void runTask();

}

下面是测试用例:

public class ThreadLocalTest {

    public static void main(String[] args) throws InterruptedException {
        InheritableThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();
        inheritableThreadLocal.set("qis");
        System.out.println(inheritableThreadLocal.get());
        ExecutorService executorService = Executors.newFixedThreadPool(1);
        executorService.execute(() -> {
            System.out.println(inheritableThreadLocal.get());
            inheritableThreadLocal.set("qishuo");
            System.out.println(inheritableThreadLocal.get());
        });
        Thread.sleep(1000);
        System.out.println("----------------");
        executorService.execute(() -> {
            System.out.println(inheritableThreadLocal.get());
        });
        Thread.sleep(1000);
        System.out.println("--------------分隔新以上是没有使用InheritableTask----------------");

        executorService.submit(new InheritableTask() {
            @Override
            public void runTask() {
                System.out.println(inheritableThreadLocal.get());
                inheritableThreadLocal.set("qishuo");
                System.out.println(inheritableThreadLocal.get());
            }
        });
        Thread.sleep(1000);
        System.out.println("----------------");
        executorService.submit(new InheritableTask() {
            @Override
            public void runTask() {
                System.out.println(inheritableThreadLocal.get());
            }
        });
    }
}

输出结果:

qis
qis
qishuo
----------------
qishuo
--------------分隔新以上是没有使用InheritableTask----------------
qis
qishuo
----------------
qis

这样就解决了在线程池场景下的InheritableThreadLocal无效的问题,然而反射比较耗性能,一般优化反射的两种方式,一种使用缓存,一种使用性能较高的反射工具比如RefelectASM类。

下面展示使用缓存的实现:
public abstract class InheritableTaskWithCache implements Runnable {
    private Object threadLocalsMapObj;
    private static volatile Field inheritableThreadLocalsField;
    private static volatile Class threadLocalMapClazz;
    private static volatile Method createInheritedMapMethod;
    private static final Object accessLock = new Object();

    public InheritableTaskWithCache() {
        try {
            Thread currentThread = Thread.currentThread();
            Field field = getInheritableThreadLocalsField();
            //得到当前线程的inheritableThreadLocals的值ThreadLocalMap
            Object threadLocalsMapObj = field.get(currentThread);
            if (null != threadLocalsMapObj) {
                Class threadLocalMapClazz = getThreadLocalMapClazz();
                Method method = getCreateInheritedMapMethod(threadLocalMapClazz);
                //创建一个新的ThreadLocalMap
                this.threadLocalsMapObj = method.invoke(ThreadLocal.class, threadLocalsMapObj);
            }
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    private Field getInheritableThreadLocalsField() {
        if (null == inheritableThreadLocalsField) {
            synchronized (accessLock) {
                if (null == inheritableThreadLocalsField) {
                    try {
                        Field field = Thread.class.getDeclaredField("inheritableThreadLocals");
                        field.setAccessible(true);
                        inheritableThreadLocalsField = field;
                    } catch (Exception e) {
                        throw new IllegalStateException(e);
                    }
                }
            }
        }
        return inheritableThreadLocalsField;
    }

    private Method getCreateInheritedMapMethod(Class threadLocalMapClazz) {
        if (null != threadLocalMapClazz && null == createInheritedMapMethod) {
            synchronized (accessLock) {
                if (null == createInheritedMapMethod) {
                    try {
                        Method method = ThreadLocal.class.getDeclaredMethod("createInheritedMap", threadLocalMapClazz);
                        method.setAccessible(true);
                        createInheritedMapMethod = method;
                    } catch (Exception e) {
                        throw new IllegalStateException(e);
                    }
                }
            }
        }
        return createInheritedMapMethod;
    }

    private Class getThreadLocalMapClazz() {
        if (null == inheritableThreadLocalsField) {
            return null;
        }
        if (null == threadLocalMapClazz) {
            synchronized (accessLock) {
                if (null == threadLocalMapClazz) {
                    threadLocalMapClazz = inheritableThreadLocalsField.getType();
                }
            }
        }
        return threadLocalMapClazz;
    }

    /**
     * 代理方法这个方法处理业务逻辑
     */
    protected abstract void runTask();

    @Override
    public void run() {
        Thread currentThread = Thread.currentThread();
        Field field = getInheritableThreadLocalsField();
        try {
            if (null != threadLocalsMapObj && null != field) {
                field.set(currentThread, threadLocalsMapObj);
                threadLocalsMapObj = null;
            }
            runTask();
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            try {
                if (field != null) {
                    field.set(currentThread, null);
                }
            } catch (Exception e) {
                System.out.println(e.toString());
            }
        }
    }
}

综上,通过一个抽象的InheritableTask解决了线程池场景下InheritableThreadLocal失效问题。

总结

InheritableThreadLocal在线程池中无效的原因是只有在创建线程Thread时才会去赋值父线程的InheritableThreadLocal中的值,而线程池场景下,主业务线程仅仅是提交任务的队列中的
如果要解决这个问题,可以自定义一个RunTask类,通过反射加代理的方式来实现业务主线程存在InheritableThreadLocal中值的间接复制,或者使用阿里开源的transmittable-thread-local。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值