一窥TransmittableThreadLocal

ThreadLocal与InheritableThreadLocal

ThreadLocal在我们平时的开发中很常见,拥有线程级别的变量共享,但是现在的项目都是跨线程的调用,如果主线程创建了另一个线程(父子线程),另一个线程还能拿到主线程的数据吗?这时候ThreadLocal就力不从心了,还好jdk提供了InheritableThreadLocal类,我们稍微讲下InheritableThreadLocal在跨线程间变量传递的原理。

在Thread类里,除了threadLocals 变量,还有一个inheritableThreadLocals变量,两者类型一模一样。

    /* ThreadLocal values pertaining to this thread. This map is maintained
     * by the ThreadLocal class. */
    ThreadLocal.ThreadLocalMap threadLocals = null;

    /*
     * InheritableThreadLocal values pertaining to this thread. This map is
     * maintained by the InheritableThreadLocal class.
     */
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

inheritableThreadLocals在我们使用ThreadLocal时是用不上的,但是在新建一个Thread的时候,我们可以看下Thread的构造函数,有一行很关键的代码:

先获取当前执行线程,也就是我们所说的父线程。然后判断父线程的inheritableThreadLocals变量是否为空,不为空?那就把父线程的inheritableThreadLocals变量拷贝一份给子线程

Thread parent = currentThread();

if (parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals = ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);

经过这么一遭,当你使用InheritableThreadLocal这个类的时候,父子线程就能共享同一变量了。

但是现在又有另一个问题,现在的多线程编程很少自己去new 一个 Thread, 而是使用了线程池,线程池里的线程是多次复用的,InheritableThreadLocal是通过Thread的构造函数完成变量传递,显然线程池的情况无法满足功能。这时候,就需要TransmittableThreadLocal登场了。

TransmittableThreadLocal

TransmittableThreadLocal是阿里推出的工具库,专门解决线程复用情况下变量传递问题。

我随便写了个测试类,先看下效果

public class TransmittableThreadTest {
    static ThreadPoolExecutor threadPoolExecutor;

    public static void main(String[] args) {

        threadPoolExecutor = new ThreadPoolExecutor(1,
                1,
                0,
                TimeUnit.HOURS,
                new ArrayBlockingQueue(100),
                (x) -> {
                    return new Thread(x);
                });

        Thread.yield();

        TransmittableThreadLocal<String> context = new TransmittableThreadLocal<>();

// =====================================================

        threadPoolExecutor.submit(() -> {
            System.out.println("线程池只有真正submit的时候才初始化线程池(懒加载),所以这里先初始化线程池");
        });

        Thread.yield();

        // 在父线程中设置
        context.set("value-set-in-parent-1");

        Runnable task = () -> {
            String value = context.get();
            System.out.println("我在子线程中拿到了值: " + value);
        };
        // 额外的处理,生成修饰了的对象ttlRunnable
        Runnable ttlRunnable = TtlRunnable.get(task);
        threadPoolExecutor.submit(ttlRunnable);
// =====================================================
        Thread.yield();
    }
}

执行结果

线程池只有真正submit的时候才初始化线程池(懒加载),所以这里先初始化线程池
我在子线程中拿到了值: value-set-in-parent-1

在讲源码前,我们先介绍下框架重要的概念:

CRR(Capture/Replay/Restore)是一个面向上下文传递设计的流程,通过这个流程的分析可以保证/证明 正确性。 简单来说,CRR操作实现了线程池状态下,上下文的传递和逻辑结束完之后的还原现场(还原现场很重要,避免可能存在的bug),下面会对各个阶段逐一分析下原理。

capture方法:抓取线程(线程A)的所有TTL值。
replay方法:在另一个线程(线程B)中,回放在capture方法中抓取的TTL值,并返回 回放前TTL值的备份
restore方法:恢复线程B执行replay方法之前的TTL值(即备份)
线程池初始化阶段

接下来我们就是看源码是怎么执行的。先从包装一个线程池开始。

要对线程池有效,初始化线程池的时候肯定要调用TTL的方法TtlExecutors.getTtlExecutor() 包装一下。

	@Bean("sendPushDeliveryTaskExecutor")
	public Executor sendPushDeliveryTaskExecutor() {
		ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
		// 设置核心线程数
		executor.setCorePoolSize(8);
		// 设置最大线程数
		executor.setMaxPoolSize(8);
		// 设置队列容量
		executor.setQueueCapacity(1000);
		// 设置默认线程名称
		executor.setThreadNamePrefix("xxx-thread-");
		// 设置拒绝策略
		executor.setRejectedExecutionHandler(new ThreadPoolExecutor.AbortPolicy());
		executor.initialize();
		return TtlExecutors.getTtlExecutor(executor);
	}

不过这个方法倒没有做什么事,只是用一个ExecutorTtlWrapper 包装了一下 ThreadPoolTaskExecutor ,初始化过程其他事没有做。注意,这个ExecutorTtlWrapper 实现了Executor接口,也就是说往线程池丢任务都其实执行的是ExecutorTtlWrapperexecute方法,给TTL实现跨线程池传递上下文功能提供了可行性。

class ExecutorTtlWrapper implements Executor, TtlWrapper<Executor>, TtlEnhanced {
    private final Executor executor;
    protected final boolean idempotent;

    ExecutorTtlWrapper(@NonNull Executor executor, boolean idempotent) {
        this.executor = executor;
        this.idempotent = idempotent;
    }
    ....
}
set阶段
TransmittableThreadLocal<Object> threadLocal = new TransmittableThreadLocal<>();
threadLocal.set("a");

正常使用来说父线程会set一个上下文value,里面关键的就来了

首先 super.set(value)方法会因为TransmittableThreadLocal父类是InheritableThreadLocal,所以会将值存在当前线程的 inheritableThreadLocals 变量。

接着,addThisToHolder方法会将TransmittableThreadLocal 对象加入到 TransmittableThreadLocal 类里的一个静态变量 holder 中。 这个 holder 专门存放 TransmittableThreadLocal 对象,为了方便,下面直接把TransmittableThreadLocal称为 TTL。 看到这里可能有点绕,这里的步骤其实就是做了两件事:

  1. 给当前线程的InheritableThreadLocal变量赋值。
  2. TTL放到一个静态变量 holder map中。
@Override
    public final void set(T value) {
        if (!disableIgnoreNullValueSemantics && null == value) {
            // may set null to remove value
            remove();
        } else {
            super.set(value);
            addThisToHolder();
        }
    }
线程池execute阶段(capture阶段)

接下来父线程塞值进去后,开始往线程池提交Runnable任务,之前初始化阶段我们知道线程池被包装成ExecutorTtlWrapper对象,那执行的也是ExecutorTtlWrapperexecute方法

    @Override
    public void execute(@NonNull Runnable command) {
        executor.execute(TtlRunnable.get(command, false, idempotent));
    }

出现一个TtlRunnable对象,看来Runnable 在给线程池前,也被封装了一层。

继续跟进源码,前面没啥看的,最后还是new了一个TtlRunnable对象。

@Nullable
public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, 
                              boolean idempotent) {
    if (null == runnable) return null;

    if (runnable instanceof TtlEnhanced) {
        // avoid redundant decoration, and ensure idempotency
        if (idempotent) return (TtlRunnable) runnable;
        else throw new IllegalStateException("Already TtlRunnable!");
    }
    return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}

这里关键的是capture()方法。

    private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
        this.capturedRef = new AtomicReference<>(capture());
        this.runnable = runnable;
        this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
    }

capture方法里面轮询了一个transmitteeSet集合,并且调用了transmittee.capture()方法,我们肯定想知道transmitteeSet是个啥。

        @NonNull
        public static Object capture() {
            final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = new HashMap<>(transmitteeSet.size());
            for (Transmittee<Object, Object> transmittee : transmitteeSet) {
                try {
                    transmittee2Value.put(transmittee, transmittee.capture());
                } catch (Throwable t) {
                    if (logger.isLoggable(Level.WARNING)) {
                        logger.log(Level.WARNING, "exception when Transmitter.capture for transmittee " + transmittee +
                                "(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
                    }
                }
            }
            return new Snapshot(transmittee2Value);
        }

transmitteeSet其实也是静态变量,在静态代码块的时候就初始化了,ttlTransmittee其实是个匿名内部类,定义了transmittee.capture()transmittee.replay()transmittee.restore()等方法的行为,其中transmittee.capture()就是在轮询之前提到的 静态变量 holder,把当前线程操作过的TTL对象取出来,并把上下文信息生成Snapshot对象返回。换句话说,capture 的行为就是在生成父线程的上下文快照,给子线程使用。

        static {
            registerTransmittee(ttlTransmittee);
        }
  
          public static <C, B> boolean registerTransmittee(@NonNull Transmittee<C, B> transmittee) {
            return transmitteeSet.add((Transmittee<Object, Object>) transmittee);
        }
        
 private static final Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>> ttlTransmittee =
                new Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>>() {
                    @NonNull
                    @Override
                    public HashMap<TransmittableThreadLocal<Object>, Object> capture() {
                        final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<>(holder.get().size());
                        for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
                            ttl2Value.put(threadLocal, threadLocal.copyValue());
                        }
                        return ttl2Value;
                    }

                    @NonNull
                    @Override
                    public HashMap<TransmittableThreadLocal<Object>, Object> replay(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
                        final HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap<>(holder.get().size());

                        for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                            TransmittableThreadLocal<Object> threadLocal = iterator.next();

                            // backup
                            backup.put(threadLocal, threadLocal.get());

                            // clear the TTL values that is not in captured
                            // avoid the extra TTL values after replay when run task
                            if (!captured.containsKey(threadLocal)) {
                                iterator.remove();
                                threadLocal.superRemove();
                            }
                        }

                        // set TTL values to captured
                        setTtlValuesTo(captured);

                        // call beforeExecute callback
                        doExecuteCallback(true);

                        return backup;
                    }

                    @NonNull
                    @Override
                    public HashMap<TransmittableThreadLocal<Object>, Object> clear() {
                        return replay(new HashMap<>(0));
                    }

                    @Override
                    public void restore(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> backup) {
                        // call afterExecute callback
                        doExecuteCallback(false);

                        for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                            TransmittableThreadLocal<Object> threadLocal = iterator.next();

                            // clear the TTL values that is not in backup
                            // avoid the extra TTL values after restore
                            if (!backup.containsKey(threadLocal)) {
                                iterator.remove();
                                threadLocal.superRemove();
                            }
                        }

                        // restore TTL values
                        setTtlValuesTo(backup);
                    }
                };

至此,TtlRunnablecapturedRef变量就拥有了父线程上下文的快照信息,然后通过 executeTtlRunnable交给了线程池,父线程已经完成了自己使命。

executor.execute(TtlRunnable.get(command, false, idempotent));
replay阶段

在线程池中,子线程会调用TtlRunnablerun方法,captured 就是父线程的上下文快照,这里面第一个重要的步骤是replay(captured)

    @Override
    public void run() {
        final Object captured = capturedRef.get();
        if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("TTL value reference is released after run!");
        }

        final Object backup = replay(captured);
        try {
            runnable.run();
        } finally {
            restore(backup);
        }
    }

父线程传过来的快照会被再次轮询并回放到当前线程,关键方法是transmittee.replay(transmitteeCaptured)

        @NonNull
        public static Object replay(@NonNull Object captured) {
            final Snapshot capturedSnapshot = (Snapshot) captured;

            final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = new HashMap<>(capturedSnapshot.transmittee2Value.size());
            for (Map.Entry<Transmittee<Object, Object>, Object> entry : capturedSnapshot.transmittee2Value.entrySet()) {
                Transmittee<Object, Object> transmittee = entry.getKey();
                try {
                    Object transmitteeCaptured = entry.getValue();
                    transmittee2Value.put(transmittee, transmittee.replay(transmitteeCaptured));
                } catch (Throwable t) {
                    if (logger.isLoggable(Level.WARNING)) {
                        logger.log(Level.WARNING, "exception when Transmitter.replay for transmittee " + transmittee +
                                "(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
                    }
                }
            }
            return new Snapshot(transmittee2Value);
        }

transmittee.replay(transmitteeCaptured)其实还是调用之前的ttlTransmitteereplay()方法,代码再粘贴一下。这里面主要做了3个步骤:

  1. 当前线程的上下文先做一个backup
  2. captured中不存在,holder中存在的上下文 remove处理,
  3. setTtlValuesTo 很明了,会把父线程的上下文快照赋值给当前线程。

其中第二点这么做是为什么?因为不这么做,就没有正确回放值。
还不明白的话举个例子: 比如 原来父线程的快照是 a=1,b=2,子线程的上下文原来就有值z=100
如果不删除z的话,回放运行时,就有值a=1,b=2,z=100,不符合正确回放的预期。

                    @Override
                    public HashMap<TransmittableThreadLocal<Object>, Object> replay(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
                        final HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap<>(holder.get().size());

                        for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                            TransmittableThreadLocal<Object> threadLocal = iterator.next();

                            // backup
                            backup.put(threadLocal, threadLocal.get());

                            // clear the TTL values that is not in captured
                            // avoid the extra TTL values after replay when run task
                            if (!captured.containsKey(threadLocal)) {
                                iterator.remove();
                                threadLocal.superRemove();
                            }
                        }

                        // set TTL values to captured
                        setTtlValuesTo(captured);

                        // call beforeExecute callback
                        doExecuteCallback(true);

                        return backup;
                    }

至此,回放(replay)部分也搞定了,其实到这主功能已经实现了,但是子线程运行完之后,还做了一步操作,恢复(restore)操作。

restore阶段

runnable.run() 就是在执行业务逻辑,执行完后,finally 块里又执行了 restore 方法。

    @Override
    public void run() {
        final Object captured = capturedRef.get();
        if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("TTL value reference is released after run!");
        }

        final Object backup = replay(captured);
        try {
            runnable.run();
        } finally {
            restore(backup);
        }
    }

到了最后一个阶段,也就是子线程已经run完成,replay前已经backup了当前线程的上下文,现在需要做恢复,就像大家去看一场演唱会,演出结束了,请带走所有的垃圾,恢复如初。

        public static void restore(@NonNull Object backup) {
            final Snapshot backupSnapshot = (Snapshot) backup;
            restoreTtlValues(backupSnapshot.ttl2Value);
            restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
        }

这里虽然原理很简单,但是有一个疑问,这里为什么要做重放的操作?

原因可以从一个特殊例子说起,当线程池满负载运行(等待队列也满了)之后,且线程池的拒绝策略采用的是CallerRunsPolicy的情况下,主线程就会执行子线程的任务,也就是说,已经没有主子线程之分,如果不采用restore(恢复)机制,中途对主线程的上下文修改就是永久性的,主线程的上下文就被污染了,使用就会出现bug,无法保证回放恢复的正确性。

结语

感谢 项目gitlab 里一些issue来解答我的疑惑,才有这篇文章。

修订

2022.10.31 重新完善了这篇文章,原理讲的更加清楚

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值