TransmittableThreadLocal 源码分析

一. TransmittableThreadLocal 简介

传统的jdk的ThreadLocal只能解决线程或者父,子线程之间的数据传递。 而 TransmittableThreadLocal 是为了解决 线程池里面的线程之间传递ThreadLocal。

二. 实例带入

public static void main(String[] args) throws Exception {
        ThreadLocal<String> THREAD_LOCAL = new TransmittableThreadLocal<>();
        THREAD_LOCAL.set("123");
        ExecutorService THREAD_POOL = TtlExecutors.getTtlExecutorService(Executors.newSingleThreadExecutor());
        THREAD_POOL.execute(() -> System.err.println(THREAD_LOCAL.get()));
    }


// 打印结果
123

三. 自己动手来实现一个简单的TransmittableThreadLocal 了解其实现原理

为了简单搞清楚实现原理,我们先手写一个简单实现(注意是简单实现)

一定要动手, 后面在分析源代码。主要是三个类:

1. TransmittableThreadLocal

public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> {
    // 全局静态存储 , 保存多个客户端TransmittableThreadLocal ,  key:TransmittableThreadLocal对象的this引用, value: null
    public static final WeakHashMap<TransmittableThreadLocal<Object>, ?> mapper = new WeakHashMap<>();

    @SuppressWarnings("unchecked")
    @Override
    public void set(T value) {
        // 子父线程ThreadLocal 还是交给jdk 的实现
        super.set(value);
        // 将当前 this 放入mapper存起来
        mapper.put((TransmittableThreadLocal<Object>) this, null);
    }
}

继承jdk的InheritableThreadLocal, 保留了子父线程支持传递ThreadLocal的特性。这里有一个全局静态Map,客户端构造几个TransmittableThreadLocal对象,Map里面就保存了key,注意value为null。

2. ExecutorServiceTtlWrapper

public class ExecutorServiceTtlWrapper implements ExecutorService {
   //原生的 jdk 线程池
    private final ExecutorService executorService;

    public ExecutorServiceTtlWrapper(ExecutorService executorService) {
        this.executorService = executorService;
    }

    @Override
    public void execute(Runnable command) {
      // 提交到原生线程池 , 但是这里对Runnable又进行了包装
        executorService.execute(TtlRunnable.get(command));
    }
// 下面还有很多方法需要重写.............这里不坐探讨 ,空实现就行
}

包装器模式,包装了jdk原生的线程池,然后又将Runnable 任务包装成TtlRunnable对象提交。

3. TtlRunnable

public class TtlRunnable implements Runnable {
    private Runnable runnable;
    private HashMap<TransmittableThreadLocal<Object>, Object> copy = new HashMap<TransmittableThreadLocal<Object>, Object>();

    public static TtlRunnable get(Runnable runnable) {
        TtlRunnable r = new TtlRunnable();
        r.runnable = runnable;
        // 把值拷贝到 copy 里面
        for (TransmittableThreadLocal<Object> threadLocal : TransmittableThreadLocal.mapper.keySet()) {
            // 把值和threadLocal 拷贝到map
            r.copy.put(threadLocal, threadLocal.get());
        }
        return r;
    }
    @Override
    @SuppressWarnings("unchecked")
    public void run() {
        // 把值设置到 当前线程的 threadlocal里面
        for (TransmittableThreadLocal<Object> t : copy.keySet()) {
            // 获取 拷贝的值
            Object value = t.get();
            // 获取主线的ThreadLocal
            TransmittableThreadLocal<Object> threadLocal = (TransmittableThreadLocal<Object>) TransmittableThreadLocal.mapper
                    .get(t);
            // 重新把值设置 到 子线程里面
            threadLocal.set(value);
        }
        // 执行 runable
        runnable.run();
    }
}

几个重点

  • TtlRunnable为Runnable的包装类,可认为这就是委派模式。
  • TtlRunnable里面的copy 成员变量Map存储的是:key >
    客户端TransmittableThreadLocal对象的this引用 , value
    >
    客户端TransmittableThreadLocal 最后一次set的值的(也就是客户端线程的ThreadLocal的值)。
  • 当线程池提交任务后会执行上述run方法。我们先把copy的值拷贝到当前线程池执行线程的ThreadLocal里面,对应
    threadLocal.set(value); 这行关键代码。 然后执行runable真正的run。

是不是很精巧,在执行run之前将保存的客户端的ThreadLocal的值都set 到当前线程池执行的线程,这样客户端就可以通过ThreadLocal的get 来取得值。

相信上面代码应该你能看懂,其实源码大致也就这实现思想,我们还是来分析下。

四. TransmittableThreadLocal 源码

TransmittableThreadLocal内部也有个Map , 但是是放在ThreadLocal里面的,保证了每个线程都有一份,不像我们之前定义的全局静态变量有线程安全问题。

 private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
            new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
                }
                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
                }
            };

当我们调set值的时候会向holder设置当前的TransmittableThreadLocal对象this引用。

com.alibaba.ttl.TransmittableThreadLocal#set 方法源码

 public final void set(T value) {
        if (!disableIgnoreNullValueSemantics && null == value) {
            // may set null to remove value
            remove();
        } else {
            super.set(value);
            addThisToHolder();
        }
    }
   private void addThisToHolder() {
        WeakHashMap<TransmittableThreadLocal<Object>, ?> w =  holder.get() ;
        if (!w.containsKey(this)) {
            w.put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
        }
    }

把WeakHashMap当成set用,key存储客户端的多个TransmittableThreadLocal ,值 存 null ,不用理会。

com.alibaba.ttl.threadpool.TtlExecutors#getTtlExecutorService 方法

  public static ExecutorService getTtlExecutorService( ExecutorService executorService) {
        if (TtlAgent.isTtlAgentLoaded() || executorService == null || executorService instanceof TtlEnhanced) {
            return executorService;
        }
        return new ExecutorServiceTtlWrapper(executorService, true);
    }

这个TtlExecutors就是个工厂类,构造了jdk线程池包装器: ExecutorServiceTtlWrapper

ExecutorServiceTtlWrapper构造很简单

class ExecutorServiceTtlWrapper extends ExecutorTtlWrapper implements ExecutorService, TtlEnhanced {
    // jdk线程池引用   
    private final ExecutorService executorService;

    ExecutorServiceTtlWrapper( ExecutorService executorService, boolean idempotent) {
        super(executorService, idempotent);
        this.executorService = executorService;
    }
}

ExecutorServiceTtlWrapper持有jdk线程池ExecutorService的引用,典型的包装,代理。

super(executorService, idempotent); 这行代码就是构造下面这个父类(ExecutorTtlWrapper )

在到我们的提交任务execute,ExecutorServiceTtlWrapper 没有execute的实现,而是在它的父类ExecutorTtlWrapper 里面

class ExecutorTtlWrapper implements Executor, TtlWrapper<Executor>, TtlEnhanced {
    // jdk线程池引用
    private final Executor executor;
    protected final boolean idempotent;

    ExecutorTtlWrapper( Executor executor, boolean idempotent) {
        this.executor = executor;
        this.idempotent = idempotent;
    }
    @Override
    public void execute( Runnable command) {
       // 提交的是 Runnable 的包装 TtlRunnable类
        executor.execute(TtlRunnable.get(command, false, idempotent));
    }
}

又对Runnable包装了一层,TtlRunnable

public final class TtlRunnable implements Runnable, TtlWrapper<Runnable>, TtlEnhanced, TtlAttachments {
    private final AtomicReference<Object> capturedRef;
    private final Runnable runnable;
    private final boolean releaseTtlValueReferenceAfterRun;

    private TtlRunnable( Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
        this.capturedRef = new AtomicReference<Object>(capture());
        this.runnable = runnable;
        this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
    }
   public static TtlRunnable get(  Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
        return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
    }
}

这里有个很重要的AtomicReference原子性对象。 AtomicReference 类似 AtomicInteger ,都是利用CAS实现的原子安全操作。只是前者是针对对象,后者是针对Integer。

capture() ,在TransmittableThreadLocal类里面的Transmitter工具下

public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {

     public static class Transmitter {
        public static Object capture() {
            return new Snapshot(captureTtlValues(), captureThreadLocalValues());
        }

        private static class Snapshot {
            final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value;
            final HashMap<ThreadLocal<Object>, Object> threadLocal2Value;

            private Snapshot(HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value, HashMap<ThreadLocal<Object>, Object> threadLocal2Value) {
                this.ttl2Value = ttl2Value;
                this.threadLocal2Value = threadLocal2Value;
            }
        }
    }
}

我们发现AtomicReference里面其实就是Snapshot对象。Snapshot里面有两个HashMap , 这里我只讲一个ttl2Value ,这个是不是就是我们自己实现的 TtlRunnable 类里面的copy 一样。

captureTtlValues方法(captureThreadLocalValues方法我不分析)

 private static HashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
            HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<TransmittableThreadLocal<Object>, Object>();
            WeakHashMap<TransmittableThreadLocal<Object>, ?> w =  holder.get() ;
            for (TransmittableThreadLocal<Object> threadLocal : w.keySet()) {
                ttl2Value.put(threadLocal, threadLocal.copyValue());
            }
            return ttl2Value;
        }

holder为TransmittableThreadLocal的holder,开先我们讲了存的是多个客户端的TransmittableThreadLocal对象this , threadLocal.copyValue() 就是取到客户端线程ThreadLocal的值。

假如你客户端是这样设置:
 ThreadLocal<String> THREAD_LOCAL1 = new TransmittableThreadLocal<>(); //  com.alibaba.ttl.TransmittableThreadLocal@33909752
  THREAD_LOCAL1.set("海尔兄弟");
 ThreadLocal<String> THREAD_LOCAL2 = new TransmittableThreadLocal<>(); //  com.alibaba.ttl.TransmittableThreadLocal@44909754
  THREAD_LOCAL2.set("海贼王");
则HashMap存的是 
{com.alibaba.ttl.TransmittableThreadLocal@33909752:"海尔兄弟",com.alibaba.ttl.TransmittableThreadLocal@44909754:"海贼王"}

至此我们的TtlRunnable已经构造完,并提交到了线程池,等待线程池调度运行TtlRunnable的run方法。

public final class TtlRunnable implements Runnable, TtlWrapper<Runnable>, TtlEnhanced, TtlAttachments {
    private final AtomicReference<Object> capturedRef;
    @Override
    public void run() {
        final Object captured = capturedRef.get();
        final Object backup = replay(captured);
        try {
            runnable.run();
        } finally {
            restore(backup);
        }
    }
}

replay 为前面Transmitter工具里面的方法

public static Object replay( Object captured) {
            final Snapshot capturedSnapshot = (Snapshot) captured;
            return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
        }
        private static HashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues( HashMap<TransmittableThreadLocal<Object>, Object> captured) {
            HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap<TransmittableThreadLocal<Object>, Object>();
            WeakHashMap<TransmittableThreadLocal<Object>, ?> w =  holder.get() ;
            for (final Iterator<TransmittableThreadLocal<Object>> iterator = w.keySet().iterator(); iterator.hasNext(); ) {
               // 为客户端设置的TransmittableThreadLocal
                TransmittableThreadLocal<Object> threadLocal = iterator.next();
                // 将threadLocal对象this和对象的threadLocal的值放到复制返回的HashMap里面
                backup.put(threadLocal, threadLocal.get());
            }
            // 将客户端线程的ThreadLocal拷贝到当前线程池线程的ThreadLocal里
            setTtlValuesTo(captured);

            // call beforeExecute callback
            doExecuteCallback(true);

            return backup;
        }
       private static void setTtlValuesTo( HashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
            for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
                TransmittableThreadLocal<Object> threadLocal = entry.getKey();
                // 这一行就是拷贝到当前线程池线程的ThreadLocal
                threadLocal.set(entry.getValue());
            }
        }

到此我们的线程池的执行线程的ThreadLocal 里面就有了客户端线程的ThreadLocal了。接着来我们的run方法的 restore 。

  private static void restoreTtlValues( HashMap<TransmittableThreadLocal<Object>, Object> backup) {
            // call afterExecute callback
            doExecuteCallback(false);
            WeakHashMap<TransmittableThreadLocal<Object>, ?> w =  holder.get() ;
            for (final Iterator<TransmittableThreadLocal<Object>> iterator = w.keySet().iterator(); iterator.hasNext(); ) {
                TransmittableThreadLocal<Object> threadLocal = iterator.next();
                // clear the TTL values that is not in backup
                // avoid the extra TTL values after restore
                if (!backup.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }
            // restore TTL values
            setTtlValuesTo(backup);
        }

这是业务run执行完后执行的ThreadLocal的remove操作。不能常驻在线程池线程的ThreadLocal里面。

五. TransmittableThreadLocal 里面很有意思的Agent

Agent我后续会在博客详细介绍。作用就是在main方法启动前修改程序的class字节码,来达到不修改一行代码也能改变业务。

TransmittableThreadLocal如果用Agent 方式,就不会改业务使用TransmittableThreadLocal 。

启用Agent功能,需要在Java的启动参数添加:-javaagent:path/to/transmittable-thread-local-x.yzx.jar。

TransmittableThreadLocal Agent的入口类是TtlAgent类,你现在知道的是启动main方法前jvm会先调用下面的premain方法

public final class TtlAgent {
 public static void premain(final String agentArgs,  final Instrumentation inst) {
        kvs = splitCommaColonStringToKV(agentArgs);

        Logger.setLoggerImplType(getLogImplTypeFromAgentArgs(kvs));
        final Logger logger = Logger.getLogger(TtlAgent.class);

        try {
            logger.info("[TtlAgent.premain] begin, agentArgs: " + agentArgs + ", Instrumentation: " + inst);
            final boolean disableInheritableForThreadPool = isDisableInheritableForThreadPool();
            //要改字节码的Transformlet
            final List<JavassistTransformlet> transformletList = new ArrayList<JavassistTransformlet>();
            transformletList.add(new TtlExecutorTransformlet(disableInheritableForThreadPool));
            transformletList.add(new TtlForkJoinTransformlet(disableInheritableForThreadPool));
            if (isEnableTimerTask()) transformletList.add(new TtlTimerTaskTransformlet());
            // 下面2行代码就是改变了原来类的class字节码
            final ClassFileTransformer transformer = new TtlTransformer(transformletList);
            inst.addTransformer(transformer, true);
            logger.info("[TtlAgent.premain] addTransformer " + transformer.getClass() + " success");

            logger.info("[TtlAgent.premain] end");

            ttlAgentLoaded = true;
        } catch (Exception e) {
            String msg = "Fail to load TtlAgent , cause: " + e.toString();
            logger.log(Level.SEVERE, msg, e);
            throw new IllegalStateException(msg, e);
        }
    }
}

TtlExecutorTransformlet 就是怎么改线程池class内容的实现,我们来看下关键方法

public class TtlExecutorTransformlet implements JavassistTransformlet {
   @Override
    public void doTransform( final ClassInfo classInfo) throws IOException, NotFoundException, CannotCompileException {
        final CtClass clazz = classInfo.getCtClass();
        if (EXECUTOR_CLASS_NAMES.contains(classInfo.getClassName())) {
            for (CtMethod method : clazz.getDeclaredMethods()) {
                updateSubmitMethodsOfExecutorClass_decorateToTtlWrapperAndSetAutoWrapperAttachment(method);
            }

            if (disableInheritableForThreadPool) updateConstructorDisableInheritable(clazz);

            classInfo.setModified();
        } else {
            if (clazz.isPrimitive() || clazz.isArray() || clazz.isInterface() || clazz.isAnnotation()) {
                return;
            }
            if (!clazz.subclassOf(clazz.getClassPool().get(THREAD_POOL_EXECUTOR_CLASS_NAME))) return;

            logger.info("Transforming class " + classInfo.getClassName());

            final boolean modified = updateBeforeAndAfterExecuteMethodOfExecutorSubclass(clazz);
            if (modified) classInfo.setModified();
        }
    }
}

如果类是 java.util.concurrent.ScheduledThreadPoolExecutor , java.util.concurrent.ThreadPoolExecutor ,就处理修改字节码,否则不动。然后就是遍历每个方法修改方法的内容

private void updateSubmitMethodsOfExecutorClass_decorateToTtlWrapperAndSetAutoWrapperAttachment(@NonNull final CtMethod method) throws NotFoundException, CannotCompileException {
        final int modifiers = method.getModifiers();
        if (!Modifier.isPublic(modifiers) || Modifier.isStatic(modifiers)) return;
        // 这里主要在java.lang.Runnable构造时候调用com.alibaba.ttl.TtlRunnable#get()包装为com.alibaba.ttl.TtlRunnable
        // 在java.util.concurrent.Callable构造时候调用com.alibaba.ttl.TtlCallable#get()包装为com.alibaba.ttl.TtlCallable
        // 并且设置附件K-V为ttl.is.auto.wrapper=true
        CtClass[] parameterTypes = method.getParameterTypes();
        StringBuilder insertCode = new StringBuilder();
        for (int i = 0; i < parameterTypes.length; i++) {
            final String paramTypeName = parameterTypes[i].getName();
            if (PARAM_TYPE_NAME_TO_DECORATE_METHOD_CLASS.containsKey(paramTypeName)) {
                String code = String.format(
                        // decorate to TTL wrapper,
                        // and then set AutoWrapper attachment/Tag
                        "$%d = %s.get($%d, false, true);"
                                + "\ncom.alibaba.ttl.threadpool.agent.internal.transformlet.impl.Utils.setAutoWrapperAttachment($%<d);",
                        i + 1, PARAM_TYPE_NAME_TO_DECORATE_METHOD_CLASS.get(paramTypeName), i + 1);
                logger.info("insert code before method " + signatureOfMethod(method) + " of class " + method.getDeclaringClass().getName() + ": " + code);
                insertCode.append(code);
            }
        }
        if (insertCode.length() > 0) method.insertBefore(insertCode.toString());
    }

上面就是让java.util.concurrent.ThreadPoolExecutor和java.util.concurrent.ScheduledThreadPoolExecutor的字节码被增强,提交的java.lang.Runnable类型的任务会被包装为TtlRunnable,提交的java.util.concurrent.Callable类型的任务会被包装为TtlCallable,实现了无入侵无感知地嵌入TTL的功能。

感谢大家看到这里~

强烈推荐一套Java进阶博客,都是干货,走向架构师不是梦!

Java进阶全套博客

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值