ThreadLocal与InheritableThreadLocal
ThreadLocal在我们平时的开发中很常见,拥有线程级别的变量共享,但是现在的项目都是跨线程的调用,如果主线程创建了另一个线程(父子线程),另一个线程还能拿到主线程的数据吗?这时候ThreadLocal就力不从心了,还好jdk提供了InheritableThreadLocal类,我们稍微讲下InheritableThreadLocal在跨线程间变量传递的原理。
在Thread类里,除了threadLocals 变量,还有一个inheritableThreadLocals变量,两者类型一模一样。
/* ThreadLocal values pertaining to this thread. This map is maintained
* by the ThreadLocal class. */
ThreadLocal.ThreadLocalMap threadLocals = null;
/*
* InheritableThreadLocal values pertaining to this thread. This map is
* maintained by the InheritableThreadLocal class.
*/
ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
inheritableThreadLocals在我们使用ThreadLocal时是用不上的,但是在新建一个Thread的时候,我们可以看下Thread的构造函数,有一行很关键的代码:
先获取当前执行线程,也就是我们所说的父线程。然后判断父线程的inheritableThreadLocals变量是否为空,不为空?那就把父线程的inheritableThreadLocals变量拷贝一份给子线程
Thread parent = currentThread();
if (parent.inheritableThreadLocals != null)
this.inheritableThreadLocals = ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
经过这么一遭,当你使用InheritableThreadLocal这个类的时候,父子线程就能共享同一变量了。
但是现在又有另一个问题,现在的多线程编程很少自己去new 一个 Thread, 而是使用了线程池,线程池里的线程是多次复用的,InheritableThreadLocal是通过Thread的构造函数完成变量传递,显然线程池的情况无法满足功能。这时候,就需要TransmittableThreadLocal登场了。
TransmittableThreadLocal
TransmittableThreadLocal是阿里推出的工具库,专门解决线程复用情况下变量传递问题。
我随便写了个测试类,先看下效果
public class TransmittableThreadTest {
static ThreadPoolExecutor threadPoolExecutor;
public static void main(String[] args) {
threadPoolExecutor = new ThreadPoolExecutor(1,
1,
0,
TimeUnit.HOURS,
new ArrayBlockingQueue(100),
(x) -> {
return new Thread(x);
});
Thread.yield();
TransmittableThreadLocal<String> context = new TransmittableThreadLocal<>();
// =====================================================
threadPoolExecutor.submit(() -> {
System.out.println("线程池只有真正submit的时候才初始化线程池(懒加载),所以这里先初始化线程池");
});
Thread.yield();
// 在父线程中设置
context.set("value-set-in-parent-1");
Runnable task = () -> {
String value = context.get();
System.out.println("我在子线程中拿到了值: " + value);
};
// 额外的处理,生成修饰了的对象ttlRunnable
Runnable ttlRunnable = TtlRunnable.get(task);
threadPoolExecutor.submit(ttlRunnable);
// =====================================================
Thread.yield();
}
}
执行结果
线程池只有真正submit的时候才初始化线程池(懒加载),所以这里先初始化线程池
我在子线程中拿到了值: value-set-in-parent-1
在讲源码前,我们先介绍下框架重要的概念:
CRR(Capture/Replay/Restore)是一个面向上下文传递设计的流程,通过这个流程的分析可以保证/证明 正确性。
简单来说,CRR操作实现了线程池状态下,上下文的传递和逻辑结束完之后的还原现场(还原现场很重要,避免可能存在的bug),下面会对各个阶段逐一分析下原理。
capture方法:抓取线程(线程A)的所有TTL值。
replay方法:在另一个线程(线程B)中,回放在capture方法中抓取的TTL值,并返回 回放前TTL值的备份
restore方法:恢复线程B执行replay方法之前的TTL值(即备份)
线程池初始化阶段
接下来我们就是看源码是怎么执行的。先从包装一个线程池开始。
要对线程池有效,初始化线程池的时候肯定要调用TTL的方法TtlExecutors.getTtlExecutor()
包装一下。
@Bean("sendPushDeliveryTaskExecutor")
public Executor sendPushDeliveryTaskExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
// 设置核心线程数
executor.setCorePoolSize(8);
// 设置最大线程数
executor.setMaxPoolSize(8);
// 设置队列容量
executor.setQueueCapacity(1000);
// 设置默认线程名称
executor.setThreadNamePrefix("xxx-thread-");
// 设置拒绝策略
executor.setRejectedExecutionHandler(new ThreadPoolExecutor.AbortPolicy());
executor.initialize();
return TtlExecutors.getTtlExecutor(executor);
}
不过这个方法倒没有做什么事,只是用一个ExecutorTtlWrapper
包装了一下 ThreadPoolTaskExecutor
,初始化过程其他事没有做。注意,这个ExecutorTtlWrapper
实现了Executor
接口,也就是说往线程池丢任务都其实执行的是ExecutorTtlWrapper
的 execute
方法,给TTL实现跨线程池传递上下文功能提供了可行性。
class ExecutorTtlWrapper implements Executor, TtlWrapper<Executor>, TtlEnhanced {
private final Executor executor;
protected final boolean idempotent;
ExecutorTtlWrapper(@NonNull Executor executor, boolean idempotent) {
this.executor = executor;
this.idempotent = idempotent;
}
....
}
set阶段
TransmittableThreadLocal<Object> threadLocal = new TransmittableThreadLocal<>();
threadLocal.set("a");
正常使用来说父线程会set一个上下文value,里面关键的就来了
首先 super.set(value)
方法会因为TransmittableThreadLocal
父类是InheritableThreadLocal
,所以会将值存在当前线程的 inheritableThreadLocals
变量。
接着,addThisToHolder
方法会将TransmittableThreadLocal
对象加入到 TransmittableThreadLocal
类里的一个静态变量 holder
中。 这个 holder
专门存放 TransmittableThreadLocal
对象,为了方便,下面直接把TransmittableThreadLocal
称为 TTL
。 看到这里可能有点绕,这里的步骤其实就是做了两件事:
- 给当前线程的InheritableThreadLocal变量赋值。
- 将
TTL
放到一个静态变量holder
map中。
@Override
public final void set(T value) {
if (!disableIgnoreNullValueSemantics && null == value) {
// may set null to remove value
remove();
} else {
super.set(value);
addThisToHolder();
}
}
线程池execute阶段(capture阶段)
接下来父线程塞值进去后,开始往线程池提交Runnable任务,之前初始化阶段
我们知道线程池被包装成ExecutorTtlWrapper
对象,那执行的也是ExecutorTtlWrapper
的execute
方法
@Override
public void execute(@NonNull Runnable command) {
executor.execute(TtlRunnable.get(command, false, idempotent));
}
出现一个TtlRunnable
对象,看来Runnable
在给线程池前,也被封装了一层。
继续跟进源码,前面没啥看的,最后还是new
了一个TtlRunnable
对象。
@Nullable
public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun,
boolean idempotent) {
if (null == runnable) return null;
if (runnable instanceof TtlEnhanced) {
// avoid redundant decoration, and ensure idempotency
if (idempotent) return (TtlRunnable) runnable;
else throw new IllegalStateException("Already TtlRunnable!");
}
return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}
这里关键的是capture()
方法。
private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
this.capturedRef = new AtomicReference<>(capture());
this.runnable = runnable;
this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}
capture方法里面轮询了一个transmitteeSet
集合,并且调用了transmittee.capture()
方法,我们肯定想知道transmitteeSet
是个啥。
@NonNull
public static Object capture() {
final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = new HashMap<>(transmitteeSet.size());
for (Transmittee<Object, Object> transmittee : transmitteeSet) {
try {
transmittee2Value.put(transmittee, transmittee.capture());
} catch (Throwable t) {
if (logger.isLoggable(Level.WARNING)) {
logger.log(Level.WARNING, "exception when Transmitter.capture for transmittee " + transmittee +
"(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
}
}
}
return new Snapshot(transmittee2Value);
}
transmitteeSet
其实也是静态变量,在静态代码块的时候就初始化了,ttlTransmittee
其实是个匿名内部类,定义了transmittee.capture()
,transmittee.replay()
,transmittee.restore()
等方法的行为,其中transmittee.capture()
就是在轮询之前提到的 静态变量 holder
,把当前线程操作过的TTL对象
取出来,并把上下文信息生成Snapshot
对象返回。换句话说,capture
的行为就是在生成父线程的上下文快照,给子线程使用。
static {
registerTransmittee(ttlTransmittee);
}
public static <C, B> boolean registerTransmittee(@NonNull Transmittee<C, B> transmittee) {
return transmitteeSet.add((Transmittee<Object, Object>) transmittee);
}
private static final Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>> ttlTransmittee =
new Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>>() {
@NonNull
@Override
public HashMap<TransmittableThreadLocal<Object>, Object> capture() {
final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<>(holder.get().size());
for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
ttl2Value.put(threadLocal, threadLocal.copyValue());
}
return ttl2Value;
}
@NonNull
@Override
public HashMap<TransmittableThreadLocal<Object>, Object> replay(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
final HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap<>(holder.get().size());
for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
TransmittableThreadLocal<Object> threadLocal = iterator.next();
// backup
backup.put(threadLocal, threadLocal.get());
// clear the TTL values that is not in captured
// avoid the extra TTL values after replay when run task
if (!captured.containsKey(threadLocal)) {
iterator.remove();
threadLocal.superRemove();
}
}
// set TTL values to captured
setTtlValuesTo(captured);
// call beforeExecute callback
doExecuteCallback(true);
return backup;
}
@NonNull
@Override
public HashMap<TransmittableThreadLocal<Object>, Object> clear() {
return replay(new HashMap<>(0));
}
@Override
public void restore(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> backup) {
// call afterExecute callback
doExecuteCallback(false);
for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
TransmittableThreadLocal<Object> threadLocal = iterator.next();
// clear the TTL values that is not in backup
// avoid the extra TTL values after restore
if (!backup.containsKey(threadLocal)) {
iterator.remove();
threadLocal.superRemove();
}
}
// restore TTL values
setTtlValuesTo(backup);
}
};
至此,TtlRunnable
的capturedRef
变量就拥有了父线程上下文的快照信息,然后通过 execute
将TtlRunnable
交给了线程池,父线程已经完成了自己使命。
executor.execute(TtlRunnable.get(command, false, idempotent));
replay阶段
在线程池中,子线程会调用TtlRunnable
的run
方法,captured
就是父线程的上下文快照,这里面第一个重要的步骤是replay(captured)
。
@Override
public void run() {
final Object captured = capturedRef.get();
if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
throw new IllegalStateException("TTL value reference is released after run!");
}
final Object backup = replay(captured);
try {
runnable.run();
} finally {
restore(backup);
}
}
父线程传过来的快照会被再次轮询并回放到当前线程,关键方法是transmittee.replay(transmitteeCaptured)
。
@NonNull
public static Object replay(@NonNull Object captured) {
final Snapshot capturedSnapshot = (Snapshot) captured;
final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = new HashMap<>(capturedSnapshot.transmittee2Value.size());
for (Map.Entry<Transmittee<Object, Object>, Object> entry : capturedSnapshot.transmittee2Value.entrySet()) {
Transmittee<Object, Object> transmittee = entry.getKey();
try {
Object transmitteeCaptured = entry.getValue();
transmittee2Value.put(transmittee, transmittee.replay(transmitteeCaptured));
} catch (Throwable t) {
if (logger.isLoggable(Level.WARNING)) {
logger.log(Level.WARNING, "exception when Transmitter.replay for transmittee " + transmittee +
"(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
}
}
}
return new Snapshot(transmittee2Value);
}
transmittee.replay(transmitteeCaptured)
其实还是调用之前的ttlTransmittee
的replay()
方法,代码再粘贴一下。这里面主要做了3个步骤:
- 当前线程的上下文先做一个
backup
。 - 将
captured
中不存在,holder
中存在的上下文remove
处理, setTtlValuesTo
很明了,会把父线程的上下文快照赋值给当前线程。
其中第二点这么做是为什么?因为不这么做,就没有正确回放值。
还不明白的话举个例子: 比如 原来父线程的快照是 a=1,b=2
,子线程的上下文原来就有值z=100
。
如果不删除z的话,回放运行时,就有值a=1,b=2,z=100
,不符合正确回放的预期。
@Override
public HashMap<TransmittableThreadLocal<Object>, Object> replay(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
final HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap<>(holder.get().size());
for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
TransmittableThreadLocal<Object> threadLocal = iterator.next();
// backup
backup.put(threadLocal, threadLocal.get());
// clear the TTL values that is not in captured
// avoid the extra TTL values after replay when run task
if (!captured.containsKey(threadLocal)) {
iterator.remove();
threadLocal.superRemove();
}
}
// set TTL values to captured
setTtlValuesTo(captured);
// call beforeExecute callback
doExecuteCallback(true);
return backup;
}
至此,回放(replay)
部分也搞定了,其实到这主功能已经实现了,但是子线程运行完之后,还做了一步操作,恢复(restore)
操作。
restore阶段
runnable.run()
就是在执行业务逻辑,执行完后,finally
块里又执行了 restore
方法。
@Override
public void run() {
final Object captured = capturedRef.get();
if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
throw new IllegalStateException("TTL value reference is released after run!");
}
final Object backup = replay(captured);
try {
runnable.run();
} finally {
restore(backup);
}
}
到了最后一个阶段,也就是子线程已经run
完成,replay
前已经backup
了当前线程的上下文,现在需要做恢复,就像大家去看一场演唱会,演出结束了,请带走所有的垃圾,恢复如初。
public static void restore(@NonNull Object backup) {
final Snapshot backupSnapshot = (Snapshot) backup;
restoreTtlValues(backupSnapshot.ttl2Value);
restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
}
这里虽然原理很简单,但是有一个疑问,这里为什么要做重放的操作?
原因可以从一个特殊例子说起,当线程池满负载运行(等待队列也满了)之后,且线程池的拒绝策略采用的是CallerRunsPolicy的情况下,主线程就会执行子线程的任务,也就是说,已经没有主子线程之分,如果不采用restore(恢复)机制,中途对主线程的上下文修改就是永久性的,主线程的上下文就被污染了,使用就会出现bug,无法保证回放恢复的正确性。
结语
感谢 项目gitlab 里一些issue来解答我的疑惑,才有这篇文章。
修订
2022.10.31 重新完善了这篇文章,原理讲的更加清楚