目录
概述
在多线程编程中,我们常常需要在线程之间传递上下文信息。Java 提供了 ThreadLocal
和 InheritableThreadLocal
来帮助管理线程局部变量,但在某些场景下,如线程池和异步执行中,这些工具存在一些局限性。让我们详细探讨这些问题的发展历程,并介绍最终的解决方案。
1. ThreadLocal
基本原理
ThreadLocal
提供了一种机制,使每个线程都可以有自己独立的变量副本,从而避免线程之间的变量共享和竞争。每个线程都有自己的 ThreadLocalMap
,ThreadLocal
的变量存储在其中。
使用示例
public class ThreadLocalExample {
private static final ThreadLocal<String> threadLocal = new ThreadLocal<>();
public static void main(String[] args) {
ExecutorService executor = Executors.newFixedThreadPool(2);
threadLocal.set("ValueA");
System.out.println("主线程设置值为: ValueA");
executor.submit(() -> System.out.println("任务1 ThreadLocal 值: " + threadLocal.get()));
threadLocal.set("ValueB");
System.out.println("主线程设置值为: ValueB");
executor.submit(() -> System.out.println("任务2 ThreadLocal 值: " + threadLocal.get()));
executor.shutdown();
}
}
在这个示例中,主线程设置了 ThreadLocal
的值为 "ValueA",然后提交了一个任务给线程池。在任务提交后,主线程又将 ThreadLocal
的值设置为 "ValueB" 并提交了第二个任务。由于线程池中的线程会复用,两个任务可能会输出相同的值 "ValueB"。
局限性
- 线程池复用问题:在线程池中,线程会被重复使用。如果一个线程在一次任务中设置了
ThreadLocal
的值,那么该值可能会在后续任务中被误用,从而导致数据污染。 - 上下文丢失:
ThreadLocal
只在线程内有效,不能自动在父子线程之间传递数据。
2. InheritableThreadLocal
为了克服 ThreadLocal
不能在父子线程之间传递数据的问题,Java 引入了 InheritableThreadLocal
。
基本原理
InheritableThreadLocal
是 ThreadLocal
的一个子类,允许父线程的值自动传递给子线程。当创建一个新的子线程时,InheritableThreadLocal
会将父线程的值拷贝到子线程中。
使用示例
public class InheritableThreadLocalExample {
private static final InheritableThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();
public static void main(String[] args) {
ExecutorService executor = Executors.newFixedThreadPool(2);
inheritableThreadLocal.set("ValueA");
System.out.println("主线程设置值为: ValueA");
executor.submit(() -> System.out.println("任务1 InheritableThreadLocal 值: " + inheritableThreadLocal.get()));
inheritableThreadLocal.set("ValueB");
System.out.println("主线程设置值为: ValueB");
executor.submit(() -> System.out.println("任务2 InheritableThreadLocal 值: " + inheritableThreadLocal.get()));
executor.shutdown();
}
}
在这个示例中,主线程设置 InheritableThreadLocal
的值为 "ValueA" 并提交第一个任务。然后,主线程将 InheritableThreadLocal
的值改为 "ValueB" 并提交第二个任务。然而,第二个任务可能仍会打印 "ValueA" 的值,因为线程池中的线程复用了之前的线程上下文。
局限性
- 线程池复用问题:与
ThreadLocal
相同,在线程池中,InheritableThreadLocal
也存在数据污染的问题。子线程不会继承父线程的最新值,而是第一次创建线程时的值。 - 上下文更新问题:如果父线程更新了
InheritableThreadLocal
的值,已经存在的子线程不会反映这些变化。
3. TransmittableThreadLocal
随着应用程序复杂度的增加,尤其是在使用线程池和异步编程时,简单的 ThreadLocal
和 InheritableThreadLocal
已经不能满足需求。TransmittableThreadLocal
(TTL) 由阿里巴巴开源,旨在解决这些问题。
基本原理
TransmittableThreadLocal
通过捕获和恢复上下文信息,并包装线程池和任务,确保在线程执行任务前后进行上下文传递和清理。
使用示例
import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.threadpool.TtlExecutors;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class TransmittableThreadLocalExample {
private static final TransmittableThreadLocal<String> transmittableThreadLocal = new TransmittableThreadLocal<>();
public static void main(String[] args) {
ExecutorService executor = TtlExecutors.getTtlExecutorService(Executors.newFixedThreadPool(2));
transmittableThreadLocal.set("值A");
System.out.println("主线程设置值为: 值A");
executeTasks(executor, "任务组1");
transmittableThreadLocal.set("值B");
System.out.println("主线程设置值为: 值B");
executeTasks(executor, "任务组2");
executor.shutdown();
}
private static void executeTasks(ExecutorService executor, String taskGroup) {
Runnable task = () -> {
String value = transmittableThreadLocal.get();
System.out.println(taskGroup + " - TransmittableThreadLocal 值: " + value);
if (!"值B".equals(value) && "任务组2".equals(taskGroup)) {
System.out.println(taskGroup + " - 数据污染检测!预期值为: 值B,但实际值为: " + value);
}
};
for (int i = 0; i < 5; i++) {
executor.submit(task);
}
}
}
在这个示例中,TTL 确保了在线程池中每个任务执行时,能够正确获取到当前线程的上下文数据,而不会受到之前任务的影响。
核心机制
- 捕获上下文:在任务提交前,捕获当前线程的所有
TransmittableThreadLocal
数据。 - 恢复上下文:在任务执行时,恢复捕获的上下文数据,确保子线程能够继承父线程的上下文。
- 清理上下文:在任务执行完毕后,清理子线程的上下文数据,避免数据污染和内存泄漏。
TransmittableThreadLocal的源码分析
TTL 的核心实现主要在 TransmittableThreadLocal.Transmitter
类中进行,它负责捕获、恢复和清理上下文信息。
完整代码示例
// TransmittableThreadLocal 类继承自 InheritableThreadLocal,并实现了 TtlCopier 接口
public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {
private static final Logger logger = Logger.getLogger(TransmittableThreadLocal.class.getName());
private final boolean disableIgnoreNullValueSemantics;
// 一个 InheritableThreadLocal 变量,用于存储当前线程的所有 TransmittableThreadLocal 对象
private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
// 初始化值
protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
return new WeakHashMap<>();
}
// 复制父线程的值给子线程
protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
return new WeakHashMap<>(parentValue);
}
};
// 默认构造函数,初始化 disableIgnoreNullValueSemantics 为 false
public TransmittableThreadLocal() {
this(false);
}
// 带参构造函数,可以设置 disableIgnoreNullValueSemantics 的值
public TransmittableThreadLocal(boolean disableIgnoreNullValueSemantics) {
this.disableIgnoreNullValueSemantics = disableIgnoreNullValueSemantics;
}
// 创建一个带初始值的 TransmittableThreadLocal 实例
@NonNull
public static <S> TransmittableThreadLocal<S> withInitial(@NonNull Supplier<? extends S> supplier) {
if (supplier == null) {
throw new NullPointerException("supplier is null");
} else {
return new SuppliedTransmittableThreadLocal<>(supplier, null, null);
}
}
// 创建一个带初始值和复制器的 TransmittableThreadLocal 实例
@ParametersAreNonnullByDefault
@NonNull
public static <S> TransmittableThreadLocal<S> withInitialAndCopier(Supplier<? extends S> supplier, TtlCopier<S> copierForChildValueAndCopy) {
if (supplier == null) {
throw new NullPointerException("supplier is null");
} else if (copierForChildValueAndCopy == null) {
throw new NullPointerException("ttl copier is null");
} else {
return new SuppliedTransmittableThreadLocal<>(supplier, copierForChildValueAndCopy, copierForChildValueAndCopy);
}
}
// 创建一个带初始值和不同复制器的 TransmittableThreadLocal 实例
@ParametersAreNonnullByDefault
@NonNull
public static <S> TransmittableThreadLocal<S> withInitialAndCopier(Supplier<? extends S> supplier, TtlCopier<S> copierForChildValue, TtlCopier<S> copierForCopy) {
if (supplier == null) {
throw new NullPointerException("supplier is null");
} else if (copierForChildValue == null) {
throw new NullPointerException("ttl copier for child value is null");
} else if (copierForCopy == null) {
throw new NullPointerException("ttl copier for copy value is null");
} else {
return new SuppliedTransmittableThreadLocal<>(supplier, copierForChildValue, copierForCopy);
}
}
// 复制父值
public T copy(T parentValue) {
return parentValue;
}
// 任务执行前的钩子方法,子类可重写
protected void beforeExecute() {
}
// 任务执行后的钩子方法,子类可重写
protected void afterExecute() {
}
// 获取值,必要时添加到 holder
public final T get() {
T value = super.get();
if (this.disableIgnoreNullValueSemantics || null != value) {
this.addThisToHolder();
}
return value;
}
// 设置值,必要时添加到 holder
public final void set(T value) {
if (!this.disableIgnoreNullValueSemantics && null == value) {
this.remove();
} else {
super.set(value);
this.addThisToHolder();
}
}
// 移除值,同时从 holder 中移除
public final void remove() {
this.removeThisFromHolder();
super.remove();
}
private void superRemove() {
super.remove();
}
// 复制当前值
private T copyValue() {
return this.copy(this.get());
}
// 将当前对象添加到 holder
private void addThisToHolder() {
if (!((WeakHashMap)holder.get()).containsKey(this)) {
((WeakHashMap)holder.get()).put(this, null);
}
}
// 将当前对象从 holder 中移除
private void removeThisFromHolder() {
((WeakHashMap)holder.get()).remove(this);
}
// 执行回调方法
private static void doExecuteCallback(boolean isBefore) {
WeakHashMap<TransmittableThreadLocal<Object>, ?> ttlInstances = new WeakHashMap<>((Map)holder.get());
for (TransmittableThreadLocal<Object> threadLocal : ttlInstances.keySet()) {
try {
if (isBefore) {
threadLocal.beforeExecute();
} else {
threadLocal.afterExecute();
}
} catch (Throwable t) {
if (logger.isLoggable(Level.WARNING)) {
logger.log(Level.WARNING, "TTL exception when " + (isBefore ? "beforeExecute" : "afterExecute") + ", cause: " + t, t);
}
}
}
}
// 打印调试信息
static void dump(@Nullable String title) {
if (title != null && title.length() > 0) {
System.out.printf("Start TransmittableThreadLocal[%s] Dump...%n", title);
} else {
System.out.println("Start TransmittableThreadLocal Dump...");
}
for (TransmittableThreadLocal<Object> threadLocal : ((WeakHashMap<TransmittableThreadLocal<Object>, ?>) holder.get()).keySet()) {
System.out.println(threadLocal.get());
}
System.out.println("TransmittableThreadLocal Dump end!");
}
static void dump() {
dump(null);
}
// Transmitter 类,负责捕获和恢复上下文
public static class Transmitter {
private static volatile WeakHashMap<ThreadLocal<Object>, TtlCopier<Object>> threadLocalHolder = new WeakHashMap<>();
private static final Object threadLocalHolderUpdateLock = new Object();
private static final Object threadLocalClearMark = new Object();
private static final TtlCopier<Object> shadowCopier = parentValue -> parentValue;
// 捕获当前线程的上下文信息
@NonNull
public static Object capture() {
return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}
// 捕获 TTL 值
private static HashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<>();
for (TransmittableThreadLocal<Object> threadLocal : ((WeakHashMap<TransmittableThreadLocal<Object>, ?>) holder.get()).keySet()) {
ttl2Value.put(threadLocal, threadLocal.copyValue());
}
return ttl2Value;
}
// 捕获 ThreadLocal 值
private static HashMap<ThreadLocal<Object>, Object> captureThreadLocalValues() {
HashMap<ThreadLocal<Object>, Object> threadLocal2Value = new HashMap<>();
for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
ThreadLocal<Object> threadLocal = entry.getKey();
TtlCopier<Object> copier = entry.getValue();
threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));
}
return threadLocal2Value;
}
// 重新设置捕获的上下文信息
@NonNull
public static Object replay(@NonNull Object captured) {
Snapshot capturedSnapshot = (Snapshot) captured;
return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
}
// 重新设置 TTL 值
@NonNull
private static HashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap<>();
Iterator<TransmittableThreadLocal<Object>> iterator = ((WeakHashMap<TransmittableThreadLocal<Object>, ?>) holder.get()).keySet().iterator();
while (iterator.hasNext()) {
TransmittableThreadLocal<Object> threadLocal = iterator.next();
backup.put(threadLocal, threadLocal.get());
if (!captured.containsKey(threadLocal)) {
iterator.remove();
threadLocal.superRemove();
}
}
setTtlValuesTo(captured);
TransmittableThreadLocal.doExecuteCallback(true);
return backup;
}
// 重新设置 ThreadLocal 值
private static HashMap<ThreadLocal<Object>, Object> replayThreadLocalValues(@NonNull HashMap<ThreadLocal<Object>, Object> captured) {
HashMap<ThreadLocal<Object>, Object> backup = new HashMap<>();
for (Map.Entry<ThreadLocal<Object>, Object> entry : captured.entrySet()) {
ThreadLocal<Object> threadLocal = entry.getKey();
backup.put(threadLocal, threadLocal.get());
Object value = entry.getValue();
if (value == threadLocalClearMark) {
threadLocal.remove();
} else {
threadLocal.set(value);
}
}
return backup;
}
// 清除上下文信息
@NonNull
public static Object clear() {
HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<>();
HashMap<ThreadLocal<Object>, Object> threadLocal2Value = new HashMap<>();
for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
ThreadLocal<Object> threadLocal = entry.getKey();
threadLocal2Value.put(threadLocal, threadLocalClearMark);
}
return replay(new Snapshot(ttl2Value, threadLocal2Value));
}
// 恢复上下文信息
public static void restore(@NonNull Object backup) {
Snapshot backupSnapshot = (Snapshot) backup;
restoreTtlValues(backupSnapshot.ttl2Value);
restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
}
// 恢复 TTL 值
private static void restoreTtlValues(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> backup) {
TransmittableThreadLocal.doExecuteCallback(false);
Iterator<TransmittableThreadLocal<Object>> iterator = ((WeakHashMap<TransmittableThreadLocal<Object>, ?>) holder.get()).keySet().iterator();
while (iterator.hasNext()) {
TransmittableThreadLocal<Object> threadLocal = iterator.next();
if (!backup.containsKey(threadLocal)) {
iterator.remove();
threadLocal.superRemove();
}
}
setTtlValuesTo(backup);
}
// 设置 TTL 值
private static void setTtlValuesTo(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
TransmittableThreadLocal<Object> threadLocal = entry.getKey();
threadLocal.set(entry.getValue());
}
}
// 恢复 ThreadLocal 值
private static void restoreThreadLocalValues(@NonNull HashMap<ThreadLocal<Object>, Object> backup) {
for (Map.Entry<ThreadLocal<Object>, Object> entry : backup.entrySet()) {
ThreadLocal<Object> threadLocal = entry.getKey();
threadLocal.set(entry.getValue());
}
}
// 使用捕获的上下文信息执行 Supplier
public static <R> R runSupplierWithCaptured(@NonNull Object captured, @NonNull Supplier<R> bizLogic) {
Object backup = replay(captured);
try {
return bizLogic.get();
} finally {
restore(backup);
}
}
// 清除上下文信息后执行 Supplier
public static <R> R runSupplierWithClear(@NonNull Supplier<R> bizLogic) {
Object backup = clear();
try {
return bizLogic.get();
} finally {
restore(backup);
}
}
// 使用捕获的上下文信息执行 Callable
public static <R> R runCallableWithCaptured(@NonNull Object captured, @NonNull Callable<R> bizLogic) throws Exception {
Object backup = replay(captured);
try {
return bizLogic.call();
} finally {
restore(backup);
}
}
// 清除上下文信息后执行 Callable
public static <R> R runCallableWithClear(@NonNull Callable<R> bizLogic) throws Exception {
Object backup = clear();
try {
return bizLogic.call();
} finally {
restore(backup);
}
}
// 注册 ThreadLocal
public static <T> boolean registerThreadLocal(@NonNull ThreadLocal<T> threadLocal, @NonNull TtlCopier<T> copier) {
return registerThreadLocal(threadLocal, copier, false);
}
// 注册带有 ShadowCopier 的 ThreadLocal
public static <T> boolean registerThreadLocalWithShadowCopier(@NonNull ThreadLocal<T> threadLocal) {
return registerThreadLocal(threadLocal, shadowCopier, false);
}
// 注册 ThreadLocal,带有复制器和是否强制注册的选项
public static <T> boolean registerThreadLocal(@NonNull ThreadLocal<T> threadLocal, @NonNull TtlCopier<T> copier, boolean force) {
if (threadLocal instanceof TransmittableThreadLocal) {
TransmittableThreadLocal.logger.warning("register a TransmittableThreadLocal instance, this is unnecessary!");
return true;
} else {
synchronized (threadLocalHolderUpdateLock) {
if (!force && threadLocalHolder.containsKey(threadLocal)) {
return false;
} else {
WeakHashMap<ThreadLocal<Object>, TtlCopier<Object>> newHolder = new WeakHashMap<>(threadLocalHolder);
newHolder.put(threadLocal, copier);
threadLocalHolder = newHolder;
return true;
}
}
}
}
// 注册带有 ShadowCopier 的 ThreadLocal,并带有是否强制注册的选项
public static <T> boolean registerThreadLocalWithShadowCopier(@NonNull ThreadLocal<T> threadLocal, boolean force) {
return registerThreadLocal(threadLocal, shadowCopier, force);
}
// 取消注册 ThreadLocal
public static <T> boolean unregisterThreadLocal(@NonNull ThreadLocal<T> threadLocal) {
if (threadLocal instanceof TransmittableThreadLocal) {
TransmittableThreadLocal.logger.warning("unregister a TransmittableThreadLocal instance, this is unnecessary!");
return true;
} else {
synchronized (threadLocalHolderUpdateLock) {
if (!threadLocalHolder.containsKey(threadLocal)) {
return false;
} else {
WeakHashMap<ThreadLocal<Object>, TtlCopier<Object>> newHolder = new WeakHashMap<>(threadLocalHolder);
newHolder.remove(threadLocal);
threadLocalHolder = newHolder;
return true;
}
}
}
}
private Transmitter() {
throw new InstantiationError("Must not instantiate this class");
}
// Snapshot 类,用于存储捕获的上下文信息
private static class Snapshot {
final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value;
final HashMap<ThreadLocal<Object>, Object> threadLocal2Value;
private Snapshot(HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value, HashMap<ThreadLocal<Object>, Object> threadLocal2Value) {
this.ttl2Value = ttl2Value;
this.threadLocal2Value = threadLocal2Value;
}
}
}
// SuppliedTransmittableThreadLocal 类,带有初始值和复制器的 TransmittableThreadLocal 实现
private static final class SuppliedTransmittableThreadLocal<T> extends TransmittableThreadLocal<T> {
private final Supplier<? extends T> supplier;
private final TtlCopier<T> copierForChildValue;
private final TtlCopier<T> copierForCopy;
SuppliedTransmittableThreadLocal(Supplier<? extends T> supplier, TtlCopier<T> copierForChildValue, TtlCopier<T> copierForCopy) {
if (supplier == null) {
throw new NullPointerException("supplier is null");
} else {
this.supplier = supplier;
this.copierForChildValue = copierForChildValue;
this.copierForCopy = copierForCopy;
}
}
protected T initialValue() {
return this.supplier.get();
}
protected T childValue(T parentValue) {
return this.copierForChildValue != null ? this.copierForChildValue.copy(parentValue) : super.childValue(parentValue);
}
public T copy(T parentValue) {
return this.copierForCopy != null ? this.copierForCopy.copy(parentValue) : super.copy(parentValue);
}
}
}
4. 使用框架提供的上下文传递功能
许多现代框架提供了对线程局部变量和上下文传递的支持。例如,Spring 框架提供了 @Async
注解,可以在异步方法中自动传递上下文信息。
示例(Spring @Async
)
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
@Service
public class AsyncService {
@Async
public void asyncMethod(String value) {
System.out.println("异步方法执行,传递的值为: " + value);
}
}
配置类:
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
@Configuration
@EnableAsync
public class AsyncConfig {
}
调用方法:
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
@Component
public class AsyncCaller {
@Autowired
private AsyncService asyncService;
public void callAsyncMethod() {
asyncService.asyncMethod("测试值");
}
}
总结
在多线程编程中,ThreadLocal
和 `InheritableThreadLocal 解决了线程局部变量的问题,但在复杂的线程池和异步执行场景下,这些工具存在局限性。
TransmittableThreadLocal通过捕获和恢复上下文信息,并包装线程池和任务,确保上下文的正确传递和清理,是一种有效的解决方案。此外,现代框架提供的上下文传递功能(如 Spring 的
@Async`)也是解决上下文传递问题的有效方式。
选择适合的工具和方法,可以更好地管理上下文数据,确保系统的稳定性和可靠性。通过这些方式,我们可以在复杂的多线程环境中有效地传递和管理上下文信息,避免数据污染和内存泄漏问题。