一、背景
使用线程池时，TraceId、UserContext、RequestContext 等线程上下文需要从提交任务的线程传递到执行任务的工作线程，并在任务结束后及时清理，否则工作线程被复用时会读到上一个任务残留的数据。
二、实现
- 统一的ThreadLocal清理（调用ThreadLocal.remove()）
public class ThreadLocalCleaner {
private static final List<ThreadLocal<?>> CONTAINER = Lists.newArrayList();
public static void register(ThreadLocal<?>... threadLocal) {
if(ArrayUtils.isNotEmpty(threadLocal)) {
CONTAINER.addAll(Arrays.asList(threadLocal));
}
}
public static void register(List<ThreadLocal<?>> threadLocal) {
if(!CollectionUtils.isEmpty(threadLocal)) {
CONTAINER.addAll(threadLocal);
}
}
public static void clear() {
CONTAINER.forEach(ThreadLocal::remove);
}
}
- 统一ThreadLocal存储
/**
 * Per-thread key/value context store, plus a registry of standalone {@link ThreadLocal}s whose
 * current values can be snapshotted for hand-over to thread-pool tasks
 * (see {@link #getCopyThreadLocalMap()}).
 *
 * <p>All key/value pairs of a thread live in one {@link HashMap} held by a single ThreadLocal.
 * That ThreadLocal is registered with {@link ThreadLocalCleaner} when this class is initialized.
 */
public final class ThreadLocalManager {
    private ThreadLocalManager() { }

    /**
     * ThreadLocals registered through {@link #register(ThreadLocal)}, without duplicates.
     * This is copy-on-write because registration (rare, often from static initializers) can race
     * with {@link #getCopyThreadLocalMap()}, which runs on every task submission. Iterating a
     * plain ArrayList there could throw ConcurrentModificationException.
     */
    private static final java.util.concurrent.CopyOnWriteArrayList<ThreadLocal<?>> THREAD_LOCALS =
            new java.util.concurrent.CopyOnWriteArrayList<>();

    /** Per-thread key/value store, created as an empty HashMap on first access. */
    private static final ThreadLocal<Map<String, Object>> CONTAINER = ThreadLocal.withInitial(HashMap::new);

    static {
        ThreadLocalCleaner.register(CONTAINER);
        // Entries of THREAD_LOCALS are handed to ThreadLocalCleaner individually in register(...).
        // Registering the list itself here would only copy its contents at class-init time
        // (i.e. nothing), so ThreadLocals registered later would never be cleaned.
    }

    /**
     * Returns a snapshot of the current thread's key/value context.
     *
     * @return a mutable copy of all key/value pairs of the current thread; changing the copy
     *         does not affect the thread's context
     */
    public static Map<String, Object> getAll() {
        return new HashMap<>(CONTAINER.get());
    }

    /**
     * Puts a ThreadLocal under this manager's control. Its value is included in
     * {@link #getCopyThreadLocalMap()}, and the ThreadLocal is also registered with
     * {@link ThreadLocalCleaner} so that {@code ThreadLocalCleaner.clear()} removes it.
     * Registering {@code null} or an already registered instance is a no-op.
     *
     * @param threadLocalContext the ThreadLocal to manage
     */
    public static void register(ThreadLocal<?> threadLocalContext) {
        if (threadLocalContext == null) {
            return;
        }
        if (THREAD_LOCALS.addIfAbsent(threadLocalContext)) {
            ThreadLocalCleaner.register(threadLocalContext);
        }
    }

    /**
     * Snapshots the current thread's value of every registered ThreadLocal.
     *
     * @return map from each registered ThreadLocal to its current value; the value is
     *         {@code null} when the ThreadLocal is unset in this thread
     */
    public static Map<ThreadLocal<?>, ?> getCopyThreadLocalMap() {
        Map<ThreadLocal<?>, Object> snapshot = new HashMap<>();
        for (ThreadLocal<?> threadLocal : THREAD_LOCALS) {
            // Collectors.toMap rejects null values, so a single unset ThreadLocal used to make this
            // throw NPE. HashMap keeps the null, letting the receiving thread see "unset" as well.
            snapshot.put(threadLocal, threadLocal.get());
        }
        return snapshot;
    }

    /**
     * Stores a value in the current thread's context.
     * Note: prefix keys with a business-specific namespace so they do not collide.
     *
     * @param key   key
     * @param value value
     * @param <T>   value type
     * @return the value that was stored
     * @see Map#put(Object, Object)
     */
    public static <T> T put(String key, T value) {
        CONTAINER.get().put(key, value);
        return value;
    }

    /**
     * Stores all given pairs in the current thread's context.
     * Note: prefix keys with a business-specific namespace so they do not collide.
     *
     * @param kv key/value pairs; {@code null} or empty is a no-op
     */
    public static void putAll(Map<String, Object> kv) {
        if (kv == null || kv.isEmpty()) {
            return;
        }
        CONTAINER.get().putAll(kv);
    }

    /**
     * Removes the value stored under the given key.
     *
     * @param key key to remove
     * @see Map#remove(Object)
     */
    public static void remove(String key) {
        CONTAINER.get().remove(key);
    }

    /**
     * Discards the current thread's whole key/value context by removing the backing ThreadLocal.
     * The next access starts with an empty map.
     */
    public static void clear() {
        CONTAINER.remove();
    }

    /**
     * Reads a value from the current thread's context.
     *
     * @param key key
     * @param <T> expected value type
     * @return the value, or {@code null} if absent; a {@link ClassCastException} may be thrown at
     *         the call site if the stored value is not a {@code T}
     * @see Map#get(Object)
     */
    @SuppressWarnings("unchecked")
    public static <T> T get(String key) {
        return (T) CONTAINER.get().get(key);
    }

    /**
     * Reads a value from the current thread's context. When the key is absent, the supplier's
     * result is stored and returned. A {@code null} result from the supplier is not stored.
     *
     * @param key            key
     * @param supplierOnNull provider of the value when the key is absent
     * @param <T>            expected value type
     * @return the existing or newly supplied value
     * @see Map#computeIfAbsent(Object, java.util.function.Function)
     */
    @SuppressWarnings("unchecked")
    public static <T> T get(String key, Supplier<T> supplierOnNull) {
        return (T) CONTAINER.get().computeIfAbsent(key, k -> supplierOnNull.get());
    }

    /**
     * Removes the value stored under the key and returns it.
     *
     * @param key key
     * @param <T> expected value type
     * @return the removed value, or {@code null} if absent
     * @see #get(String)
     * @see #remove(String)
     */
    @SuppressWarnings("unchecked")
    public static <T> T getAndRemove(String key) {
        return (T) CONTAINER.get().remove(key);
    }
}
- 扩展Runnable和Callable
/**
 * {@link Callable} wrapper that installs a context snapshot on the thread that executes it:
 * the MDC, the {@link ThreadLocalManager} key/value context and the values of registered
 * {@link ThreadLocal}s.
 *
 * <p>The executing thread's own context is backed up first and put back afterwards, rather than
 * simply cleared. This matters when the task runs on the submitting thread itself, for example
 * under {@code ThreadPoolExecutor.CallerRunsPolicy}. Clearing would then wipe the caller's
 * context in the middle of its request.
 */
public class TransferCallable<T> implements Callable<T> {
    private final Callable<T> callable;
    private final Map<String, String> mdcContext;
    private final Map<String, Object> threadLocalContext;
    private final Map<ThreadLocal<?>, ?> threadLocalMap;

    /**
     * @param callable           the task to run
     * @param mdcContext         MDC to install; {@code null} installs an empty MDC
     * @param threadLocalContext ThreadLocalManager key/values to install; {@code null} means none
     * @param threadLocalMap     values to install into registered ThreadLocals ({@code null} values
     *                           mean "unset"); {@code null} means none
     */
    public TransferCallable(Callable<T> callable, Map<String, String> mdcContext,
                            Map<String, Object> threadLocalContext, Map<ThreadLocal<?>, ?> threadLocalMap) {
        this.callable = callable;
        this.mdcContext = mdcContext;
        this.threadLocalContext = threadLocalContext;
        this.threadLocalMap = threadLocalMap;
    }

    @Override
    public T call() throws Exception {
        // Back up the executing thread's own context before touching anything.
        Map<String, String> previousMdc = MDC.getCopyOfContextMap();
        Map<String, Object> previousContext = ThreadLocalManager.getAll();
        Map<ThreadLocal<?>, Object> previousValues = snapshot(threadLocalMap);
        try {
            applyMdc(mdcContext);
            // Start from an empty context so the task sees exactly the submitter's snapshot.
            ThreadLocalManager.clear();
            if (threadLocalContext != null) {
                ThreadLocalManager.putAll(threadLocalContext);
            }
            if (threadLocalMap != null) {
                for (Map.Entry<ThreadLocal<?>, ?> entry : threadLocalMap.entrySet()) {
                    assign(entry.getKey(), entry.getValue());
                }
            }
            return callable.call();
        } finally {
            // Leave the thread as it was found. On a clean pool thread this amounts to clearing.
            applyMdc(previousMdc);
            ThreadLocalManager.clear();
            if (!previousContext.isEmpty()) {
                ThreadLocalManager.putAll(previousContext);
            }
            for (Map.Entry<ThreadLocal<?>, Object> entry : previousValues.entrySet()) {
                assign(entry.getKey(), entry.getValue());
            }
        }
    }

    /** Reads the current thread's value of every key ThreadLocal in {@code source}. */
    private static Map<ThreadLocal<?>, Object> snapshot(Map<ThreadLocal<?>, ?> source) {
        Map<ThreadLocal<?>, Object> values = new java.util.HashMap<>();
        if (source != null) {
            for (ThreadLocal<?> key : source.keySet()) {
                values.put(key, key.get());
            }
        }
        return values;
    }

    /** Installs {@code context} as the MDC, or clears the MDC when it is {@code null}. */
    private static void applyMdc(Map<String, String> context) {
        if (context == null) {
            MDC.clear();
        } else {
            MDC.setContextMap(context);
        }
    }

    /** Sets the ThreadLocal to {@code value}, or removes it when {@code value} is {@code null}. */
    @SuppressWarnings("unchecked")
    private static <V> void assign(ThreadLocal<V> threadLocal, Object value) {
        if (value == null) {
            threadLocal.remove();
        } else {
            // Safe by construction: values were read from this very ThreadLocal.
            threadLocal.set((V) value);
        }
    }
}
/**
 * {@link Runnable} wrapper that installs a context snapshot on the thread that executes it:
 * the MDC, the {@link ThreadLocalManager} key/value context and the values of registered
 * {@link ThreadLocal}s.
 *
 * <p>The executing thread's own context is backed up first and put back afterwards, rather than
 * simply cleared. This matters when the task runs on the submitting thread itself, for example
 * under {@code ThreadPoolExecutor.CallerRunsPolicy}. Clearing would then wipe the caller's
 * context in the middle of its request.
 */
public class TransferRunnable implements Runnable {
    private final Runnable runnable;
    private final Map<String, String> mdcContext;
    private final Map<String, Object> threadLocalContext;
    private final Map<ThreadLocal<?>, ?> threadLocalMap;

    /**
     * @param runnable           the task to run
     * @param mdcContext         MDC to install; {@code null} installs an empty MDC
     * @param threadLocalContext ThreadLocalManager key/values to install; {@code null} means none
     * @param threadLocalMap     values to install into registered ThreadLocals ({@code null} values
     *                           mean "unset"); {@code null} means none
     */
    public TransferRunnable(Runnable runnable, Map<String, String> mdcContext,
                            Map<String, Object> threadLocalContext, Map<ThreadLocal<?>, ?> threadLocalMap) {
        this.runnable = runnable;
        this.mdcContext = mdcContext;
        this.threadLocalContext = threadLocalContext;
        this.threadLocalMap = threadLocalMap;
    }

    @Override
    public void run() {
        // Back up the executing thread's own context before touching anything.
        Map<String, String> previousMdc = MDC.getCopyOfContextMap();
        Map<String, Object> previousContext = ThreadLocalManager.getAll();
        Map<ThreadLocal<?>, Object> previousValues = snapshot(threadLocalMap);
        try {
            applyMdc(mdcContext);
            // Start from an empty context so the task sees exactly the submitter's snapshot.
            ThreadLocalManager.clear();
            if (threadLocalContext != null) {
                ThreadLocalManager.putAll(threadLocalContext);
            }
            if (threadLocalMap != null) {
                for (Map.Entry<ThreadLocal<?>, ?> entry : threadLocalMap.entrySet()) {
                    assign(entry.getKey(), entry.getValue());
                }
            }
            runnable.run();
        } finally {
            // Leave the thread as it was found. On a clean pool thread this amounts to clearing.
            applyMdc(previousMdc);
            ThreadLocalManager.clear();
            if (!previousContext.isEmpty()) {
                ThreadLocalManager.putAll(previousContext);
            }
            for (Map.Entry<ThreadLocal<?>, Object> entry : previousValues.entrySet()) {
                assign(entry.getKey(), entry.getValue());
            }
        }
    }

    /** Reads the current thread's value of every key ThreadLocal in {@code source}. */
    private static Map<ThreadLocal<?>, Object> snapshot(Map<ThreadLocal<?>, ?> source) {
        Map<ThreadLocal<?>, Object> values = new java.util.HashMap<>();
        if (source != null) {
            for (ThreadLocal<?> key : source.keySet()) {
                values.put(key, key.get());
            }
        }
        return values;
    }

    /** Installs {@code context} as the MDC, or clears the MDC when it is {@code null}. */
    private static void applyMdc(Map<String, String> context) {
        if (context == null) {
            MDC.clear();
        } else {
            MDC.setContextMap(context);
        }
    }

    /** Sets the ThreadLocal to {@code value}, or removes it when {@code value} is {@code null}. */
    @SuppressWarnings("unchecked")
    private static <V> void assign(ThreadLocal<V> threadLocal, Object value) {
        if (value == null) {
            threadLocal.remove();
        } else {
            // Safe by construction: values were read from this very ThreadLocal.
            threadLocal.set((V) value);
        }
    }
}
- ThreadPoolTaskExecutor的task包装
/**
 * {@link TaskDecorator} for Spring's ThreadPoolTaskExecutor. On the submitting thread, it captures
 * the MDC, the ThreadLocalManager key/value context and the values of ThreadLocals registered
 * with ThreadLocalManager. It then hands them to a {@link TransferRunnable} that replays them
 * around the task.
 */
public class TransferTaskDecorator implements TaskDecorator {
    @Override
    public Runnable decorate(Runnable runnable) {
        // Snapshots are taken here, i.e. on the submitting thread, in the same order as before.
        java.util.Map<String, String> mdcSnapshot = MDC.getCopyOfContextMap();
        java.util.Map<String, Object> contextSnapshot = ThreadLocalManager.getAll();
        java.util.Map<ThreadLocal<?>, ?> threadLocalSnapshot = ThreadLocalManager.getCopyThreadLocalMap();
        return new TransferRunnable(runnable, mdcSnapshot, contextSnapshot, threadLocalSnapshot);
    }
}
- ThreadLocal自动清理
/**
 * Spring MVC interceptor that clears every ThreadLocal registered with {@link ThreadLocalCleaner}
 * when the request thread is done with a request. This keeps a pooled servlet thread from
 * carrying one request's context into the next.
 *
 * <p>Nothing is cleared in {@code preHandle}: a servlet filter may already have populated the
 * context (e.g. a trace id) before the interceptor chain runs.
 */
public class ThreadLocalCleanInterceptor extends HandlerInterceptorAdapter {
    /**
     * Called by Spring MVC after request processing has completed. Clears registered ThreadLocals.
     */
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        ThreadLocalCleaner.clear();
        super.afterCompletion(request, response, handler, ex);
    }

    /**
     * Covers async handlers (Callable, DeferredResult, ...). When concurrent handling starts,
     * Spring MVC invokes this method on the original request thread instead of
     * {@code afterCompletion}. Without clearing here, that thread would return to the container's
     * pool still holding this request's ThreadLocals.
     */
    @Override
    public void afterConcurrentHandlingStarted(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        ThreadLocalCleaner.clear();
        super.afterConcurrentHandlingStarted(request, response, handler);
    }
}
- 异步线程池配置
/**
 * Bounded pool for {@code @Async("asyncExecutor")} work that carries MDC/ThreadLocal context
 * into its tasks through {@link TransferTaskDecorator}.
 */
@Bean("asyncExecutor")
public Executor asyncExecutor() {
    ThreadPoolTaskExecutor pool = new ThreadPoolTaskExecutor();
    pool.setThreadNamePrefix("asyncExecutor-");
    // 4 core threads, up to 8 threads, with up to 200 queued tasks.
    pool.setCorePoolSize(4);
    pool.setMaxPoolSize(8);
    pool.setQueueCapacity(200);
    // Copy the submitter's MDC / ThreadLocal context into every task.
    pool.setTaskDecorator(new TransferTaskDecorator());
    // When saturated, run the task on the submitting thread instead of rejecting it.
    pool.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
    // On shutdown, let tasks finish, blocking for at most 30 seconds.
    pool.setWaitForTasksToCompleteOnShutdown(true);
    pool.setAwaitTerminationSeconds(30);
    return pool;
}
- 使用方式
/**
 * Static accessors for the current thread's {@code UserInfo}, stored in {@link ThreadLocalManager}
 * under a fixed key.
 */
public class UserContext {
    /** Key under which the user is stored in the ThreadLocalManager map. */
    private static final String KEY = "USR_ACCOUNT";

    /**
     * @return the current thread's user, or {@code null} if none is set; throws
     *         {@link ClassCastException} if something other than a UserInfo was stored under the key
     */
    public static UserInfo getUser() {
        return ThreadLocalManager.<UserInfo>get(KEY);
    }

    /**
     * Stores the user for the current thread.
     *
     * @param userInfo the user to store
     */
    public static void setUser(UserInfo userInfo) {
        ThreadLocalManager.<UserInfo>put(KEY, userInfo);
    }

    /** Removes the current thread's user. */
    public static void remove() {
        ThreadLocalManager.remove(KEY);
    }
}
三、注意事项
- 定时任务（@Scheduled）不经过Web拦截器ThreadLocalCleanInterceptor，因此不会自动清理ThreadLocal
- 可以为@Scheduled方法编写切面，在finally中调用ThreadLocalCleaner.clear()进行清理
- 更方便的方式是使用阿里巴巴开源的 transmittable-thread-local（TTL）库
- 也可以参考开源调用链路监控工具的做法，通过Java Agent对字节码进行增强