TransmittableThreadLocal详解,源码分析,一文带你掌握核心逻辑

TransmittableThreadLocal详解,源码分析,一文带你掌握核心逻辑

这篇文章主要想向大家介绍TransmittableThreadLocal的基本原理,具体应用其实就是在使用线程池的时候支持threadLocal的子线程传递

大家可能会说Jdk自带的InheritableThreadLocal不是已经实现了ThreadLocal的传递吗?为什么还需要TransmittableThreadLocal。
如果大家有看过InheritableThreadLocal的源代码就会发现,InheritableThreadLocal的传递是发生在新建线程的时候,但是在我们项目中基本都是用线程池来使用多线程,因为线程的创建销毁都会有资源消耗,当然也有可能发生oom,具体就不多说了。
TransmittableThreadLocal就解决了线程池维度下的ThreadLocal传递,大家可以浅浅思考一下怎么改变一下,就可以实现线程复用的情况下也可以传递ThreadLocal。
其实就是把传递ThreadLocal的逻辑放在run方法的前面就解决了,思想就是这么简单,但是却很巧妙,使用装饰者模式把Runnable包了一层

源码分析

TransmittableThreadLocal

首先很重要的是TransmittableThreadLocal里面保存了一个静态变量holder,大家得先理解好holder
holder 变量是一个InheritableThreadLocal, 他是一个map但是一直都是当作Set在用,value一直是空

The value of holder is type WeakHashMap<TransmittableThreadLocal, ?>, but it is used as Set (aka. do NOT use about value, always null).

每次使用holder变量都会带着.get() ,意味着每次获取到的WeakHashMap都是线程自己的,这是我之前一直不理解的点,记住,每次使用holder都会带上.get(),而不是真正全局使用一个WeakHashMap

接下来来分析一下set()方法

  1. 首先会将当前的value设置到Thread中的ThreadLocalMap

  2. 然后将当前的TransmittableThreadLocal放入当前线程的map中(给未来打快照使用)

public final void set(T value) {
  // 首先先使用ThreadLocal的特性,将当前的value设置到Thread中的ThreadLocalMap,保证ThreadLocal的特性被保留
  super.set(value);
  // may set null to remove value
  // 如果是value是null,就从map中删除
  if (null == value) removeValue();
  // 这一步本质就是将当前的TransmittableThreadLocal放入当前线程的map中,以保存父线程的使用过的TransmittableThreadLocal
  else addValue();
}

private void addValue() {
  if (!holder.get().containsKey(this)) {
    // 将当前的TransmittableThreadLocal放入当前线程的map中
    // 这里value一直是空
    holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
  }
}

get()方法没有什么特殊的大家看看就行

public final T get() {
  T value = super.get();
  if (null != value) addValue();
  return value;
}

TtlRunnable

重要!!!

我认为整个TransmittableThreadLocal的核心就是使用装饰模式,将整个Runnable 包装了一层,实现了当线程复用的情况也可以将父线程继承到子线程的能力

首先在使用TtlRunnable.get(runnable),会将Runnable包装一层,此时的调用方就是父线程,在方法中会调用capture(),获取当前父线程的快照

打快照
public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
  if (null == runnable) return null;

  if (runnable instanceof TtlEnhanced) {
    // avoid redundant decoration, and ensure idempotency
    if (idempotent) return (TtlRunnable) runnable;
    else throw new IllegalStateException("Already TtlRunnable!");
  }
  return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}

// 构造函数
private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
  // capture()此方法是核心
  this.capturedRef = new AtomicReference<Object>(capture());
  this.runnable = runnable;
  this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}

这里主要是将父线程所有的threadLocal打一个快照

第一个是TransmittableThreadLocal

第二个是threadLocalHolder,这个threadLocalHolder的作用是对于在项目中使用了ThreadLocal,但是却无法替换为TransmittableThreadLocal的情况,可以使用Transmitter提供的注册方法,将项目中的threadLocal注册到它的threadLocalHolder中,后面进行capture等操作时holder和threadLocalHolder都会进行处理使用

public static Object capture() {
  return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}

private static WeakHashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
  
  WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
  // 核心:获取之前保存的当前线程使用过的TransmittableThreadLocal,然后将其放在一个map中,key为TransmittableThreadLocal,value为具体的值
  // 这样就形成了在父线程调用TtlRunnable.get(runnable)父线程使用TransmittableThreadLocal的快照
  for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
    ttl2Value.put(threadLocal, threadLocal.copyValue());
  }
  return ttl2Value;
}
// 这个是处理threadLocalHolder
private static WeakHashMap<ThreadLocal<Object>, Object> captureThreadLocalValues() {
  final WeakHashMap<ThreadLocal<Object>, Object> threadLocal2Value = new WeakHashMap<ThreadLocal<Object>, Object>();
  for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
    final ThreadLocal<Object> threadLocal = entry.getKey();
    final TtlCopier<Object> copier = entry.getValue();

    threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));
  }
  return threadLocal2Value;
}
run方法包装

接下来就是将Run方法包装了一层,注意调用run方法的一定是子线程

  1. 获取当前之前创建的父线程ThreadLocal快照

  2. 重放快照到子线程的Thread中

  3. 执行run方法

  4. 恢复子线程的快照

@Override
public void run() {
  // 获取当前之前创建的父线程ThreadLocal快照
  // 这里快照不应该为空
  Object captured = capturedRef.get();
  if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
    throw new IllegalStateException("TTL value reference is released after run!");
  }

  Object backup = replay(captured);
  try {
    runnable.run();
  } finally {
    restore(backup);
  }
}
重放快照
// 将父线程的ThreadLocal回放到子线程中
@NonNull
public static Object replay(@NonNull Object captured) {
  final Snapshot capturedSnapshot = (Snapshot) captured;
  return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
}

// 回放快照
// 备份子线程threadLocal
@NonNull
private static WeakHashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> captured) {
  WeakHashMap<TransmittableThreadLocal<Object>, Object> backup = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
	// 将子线程自带的threadLocal给备份起来,其实也是打个快照而已
  for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
    TransmittableThreadLocal<Object> threadLocal = iterator.next();

    // backup
    // 将子线程自带的threadLocal给备份起来,其实也是打个快照而已
    backup.put(threadLocal, threadLocal.get());

    // 如果父线程快照中不存在当前ThreadLocal 就删掉这个threadLocal
    // 因为这一步就是为了把父线程的threadLocal放进子线程中
    // clear the TTL values that is not in captured
    // avoid the extra TTL values after replay when run task
    if (!captured.containsKey(threadLocal)) {
      iterator.remove();
      threadLocal.superRemove();
    }
  }

  // set TTL values to captured
  // 将父线程的threadLocal快照放到子线程中
  setTtlValuesTo(captured);

  // call beforeExecute callback
  doExecuteCallback(true);

  return backup;
}

private static void setTtlValuesTo(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
  // 将父线程的threadLocal快照放到子线程中
  for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
    TransmittableThreadLocal<Object> threadLocal = entry.getKey();
    threadLocal.set(entry.getValue());
  }
}

private static WeakHashMap<ThreadLocal<Object>, Object> replayThreadLocalValues(@NonNull WeakHashMap<ThreadLocal<Object>, Object> captured) {
  final WeakHashMap<ThreadLocal<Object>, Object> backup = new WeakHashMap<ThreadLocal<Object>, Object>();

  for (Map.Entry<ThreadLocal<Object>, Object> entry : captured.entrySet()) {
    final ThreadLocal<Object> threadLocal = entry.getKey();
    backup.put(threadLocal, threadLocal.get());

    final Object value = entry.getValue();
    if (value == threadLocalClearMark) threadLocal.remove();
    else threadLocal.set(value);
  }

  return backup;
}
restore

恢复子线程的threadLocal现场

仔细看就会发现是replay的反向操作

public static void restore(@NonNull Object backup) {
    final Snapshot backupSnapshot = (Snapshot) backup;
    restoreTtlValues(backupSnapshot.ttl2Value);
    restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
}
private static void restoreTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> backup) {
  // call afterExecute callback
  doExecuteCallback(false);
  // 查询子线程使用的TransmittableThreadLocal, 然后遍历它
  for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
    TransmittableThreadLocal<Object> threadLocal = iterator.next();

    // clear the TTL values that is not in backup
    // avoid the extra TTL values after restore
    // 原来不存在就删掉这个threadLocal
    if (!backup.containsKey(threadLocal)) {
      iterator.remove();
      threadLocal.superRemove();
    }
  }

  // restore TTL values
  setTtlValuesTo(backup);
}
private static void setTtlValuesTo(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
  for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
    // 遍历一下之前子线程的快照是存在的,把它恢复了
    TransmittableThreadLocal<Object> threadLocal = entry.getKey();
    threadLocal.set(entry.getValue());
  }
}
private static void restoreThreadLocalValues(@NonNull WeakHashMap<ThreadLocal<Object>, Object> backup) {
    for (Map.Entry<ThreadLocal<Object>, Object> entry : backup.entrySet()) {
        final ThreadLocal<Object> threadLocal = entry.getKey();
        threadLocal.set(entry.getValue());
    }
}

总结

到这里分析就结束了

整体我认为这个设计的还是很巧妙的,解决了InheritableThreadLocal在线程复用(线程池的情况无法使用的问题)

  1. 首先使用了holder这样一个ThreadLocal,记录了每一个线程使用了哪些threadLocal,到时候可以直接将这个线程所有的thread以及value遍历出来

  2. 只用TtlRunnable把Runnable包装了一层,在调用.get时就把父线程打了个快照

  3. 把Runnable的run方法包装了一层,让线程开始执行之前回放父线程的threadLocal,执行结束后恢复子线程原来就有的threadLocal

    如果大家喜欢这篇文章的话,可以点赞收藏一下,这是对我最大的支持

    联系方式: xianchaolin@126.com

  • 12
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值