TransmittableThreadLocal原理解析

时间:2022-11-26 01:18:03

 

TransmittableThreadLocal是什么?

TransmittableThreadLocal(简称TTL)是alibaba提供的一个工具包中的类,主要作用就是解决线程池场景下的变量传递问题。继承自InheritableThreadLocal
主要用途:

  1. 链路追踪,如日志链路追踪(比如traceId在线程池中传递)
  2. 用户会话信息
  3. 中间件线程池信息传递。如Hystrix(不过hystrix自己实现了类似的一套)

依赖如下:

<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>transmittable-thread-local</artifactId>
    <version>2.11.4</version>
</dependency>

时序图

TransmittableThreadLocal原理解析

ThreadLocal、InheritableThreadLocal、TransmittableThreadLocal区别

  1. ThreadLocal:父子线程不会传递threadLocal副本到子线程中
  2. InheritableThreadLocal:在子线程创建的时候,父线程会把threadLocal拷贝到子线中(但是线程池的子线程不会频繁创建,就不会传递信息)
  3. TransmittableThreadLocal:解决了2中线程池无法传递线程本地副本的问题,在构造类似Runnable接口对象时进行初始化。

示例代码:

public static void main(String[] args) throws Exception {
    // 1. threadLocal测试
    // -- 输出结果:线程1null
    //            线程2null
    threadLocalTest();
    System.out.println("============================");
    System.out.println();

    // 2. ITL测试
    // -- 输出结果:子线程1我是主线程
    itlTest();
    // -- 输出结果:子线程1我是主线程
    //            子线程2我是主线程
    // -- 结论:InheritableThreadLocal只会在线程初始化的时候将父线程的值拷贝到子线程(仅拷贝一次)
    itlTestThreadPoolTest();
    System.out.println("============================");
    System.out.println();

    // 3. TTL测试
    // 输出结果:我是线程1:我是主线程
    //         修改主线程
    //         我是线程2:修改主线程
    // -- 结论:TTL能在线程池中传递
    ttlTest();
}
// TTL测试
private static void ttlTest() throws InterruptedException {
    TransmittableThreadLocal<String> local = new TransmittableThreadLocal<>();
    local.set("我是主线程");
    //生成额外的代理
    ExecutorService executorService = Executors.newFixedThreadPool(1);
    //**核心装饰代码!!!!!!!!!**
    executorService = TtlExecutors.getTtlExecutorService(executorService);
    CountDownLatch c1 = new CountDownLatch(1);
    CountDownLatch c2 = new CountDownLatch(1);
    executorService.submit(() -> {
        System.out.println("我是线程1:" + local.get());
        c1.countDown();
    });
    c1.await();
    local.set("修改主线程");
    System.out.println(local.get());
    executorService.submit(() -> {
        System.out.println("我是线程2:" + local.get());
        c2.countDown();
    });
    c2.await();
}
// ITL测试
private static void itlTestThreadPoolTest() {
    ThreadLocal<String> local = new InheritableThreadLocal<>();
    try {
        local.set("我是主线程");
        ExecutorService executorService = Executors.newFixedThreadPool(1);
        CountDownLatch c1 = new CountDownLatch(1);
        CountDownLatch c2 = new CountDownLatch(1);
        //初始化init的时候,赋予了父线程的ThreadLocal的值
        executorService.execute(() -> {
            System.out.println("线程1" + local.get());
            c1.countDown();
        });
        c1.await();
        //主线程修改值
        local.set("修改主线程");
        //再次调用,查看效果
        executorService.execute(() -> {
            System.out.println("线程2" + local.get());
            c2.countDown();
        });
        c2.await();
        executorService.shutdownNow();
    } catch (InterruptedException e) {
        e.printStackTrace();
    } finally {
        //使用完毕,清除线程中ThreadLocalMap中的key。
        local.remove();
    }
}
private static void itlTest() throws InterruptedException {
    ThreadLocal<String> local = new InheritableThreadLocal<>();
    local.set("我是主线程");
    new Thread(() -> {
        System.out.println("子线程1" + local.get());
    }).start();
    Thread.sleep(2000);
}
// ThreadLocal测试
private static void threadLocalTest() {
    ThreadLocal<String> local = new ThreadLocal<>();
    try {
        local.set("我是主线程");
        ExecutorService executorService = Executors.newFixedThreadPool(1);
        CountDownLatch c1 = new CountDownLatch(1);
        CountDownLatch c2 = new CountDownLatch(1);
        executorService.execute(() -> {
            System.out.println("线程1" + local.get());
            c1.countDown();
        });
        c1.await();
        executorService.execute(() -> {
            System.out.println("线程2" + local.get());
            c2.countDown();
        });
        c2.await();
        executorService.shutdownNow();
    } catch (InterruptedException e) {
        e.printStackTrace();
    } finally {
        //使用完毕,清除线程中ThreadLocalMap中的key。
        local.remove();
    }
}

TransmittableThreadLocal(下面统一简称TTL)

  • 解决线程池中,父子线程本地线程副本传递的问题

TTL使用方式(两种)

  1. 用TTL代码包裹让他生效
// 用TransmittableThreadLocal
TransmittableThreadLocal<Integer> ttl = new TransmittableThreadLocal<>();
// 1. 用TtlCallable.get(callable)包裹callable 或 用TtlRunnable.get(callable)包裹runnable
        Runnable runnable = () -> System.out.println("ttl信息:" + ttl.get());
        TtlRunnable.get(runnable);

// 2. 用TtlExecutors.getTtlExecutorService包裹
        ExecutorService ttlExecutorService = TtlExecutors.getTtlExecutorService(Executors.newFixedThreadPool(1));
        ttlExecutorService.execute(() -> System.out.println("ttl信息:" + ttl.get()));
  1. 用JavaAgent自动修改字节码(详细原理见另一篇:《002 TransmittableThreadLocal使用JavaAgent动态代理机制分析》)
  • 启动jar的时候,附加上参数-javaagent:/xx/transmittable-thread-local.jar(目录自己放好)
/opt/soft/jdk/jdk1.8.0_191/bin/java
        -javaagent:/xx/transmittable-thread-local.jar
        -jar /opt/xxx.jar

TTL原理

  • 设计模式上采用装饰器模式去增强Runnable等任务

举例TtlRunnable源码分析

步骤:

  1. 装饰Runnable,将主线程的TTL传入到TtlRunnable的构造方法中
  2. 将子线程的TTL的值进行备份,将主线程的TTL设置到子线程中(value是对象引用,可能存在线程安全问题);
  3. 执行子线程逻辑
  4. 删除子线程新增的TTL,将备份还原重新设置到子线程的TTL中
@Override
public void run() {
        /**
         * capturedRef是主线程传递下来的ThreadLocal的值。
         */
        Object captured = capturedRef.get();
        if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
        throw new IllegalStateException("TTL value reference is released after run!");
        }
        /**
         * 1.  backup(备份)是子线程已经存在的ThreadLocal变量;
         * 2. 将captured的ThreadLocal值在子线程中set进去;
         */
        Object backup = replay(captured);
        try {
        /**
         * 待执行的线程方法;
         */
        runnable.run();
        } finally {
        /**
         *  在子线程任务中,ThreadLocal可能发生变化,该步骤的目的是
         *  回滚{@code runnable.run()}进入前的ThreadLocal的线程
         */
        restore(backup);
        }
        }
/**
 * 将快照重做到执行线程
 * @param captured 快照
 */
public static Object replay(Object captured) {
// 获取父线程ThreadLocal快照
final Snapshot capturedSnapshot = (Snapshot) captured;
        return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
        }

/*****************************************************
 * 重放TransmittableThreadLocal,并保存执行线程的原值
 ****************************************************/
private static WeakHashMap<TransmittableThreadLocalCode<Object>, Object> replayTtlValues(WeakHashMap<TransmittableThreadLocalCode<Object>, Object> captured) {
        WeakHashMap<TransmittableThreadLocalCode<Object>, Object> backup = new WeakHashMap<TransmittableThreadLocalCode<Object>, Object>();

        for (final Iterator<TransmittableThreadLocalCode<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocalCode<Object> threadLocal = iterator.next();

        // backup
        // 遍历 holder,从 父线程继承过来的,或者之前注册进来的
        backup.put(threadLocal, threadLocal.get());

        // clear the TTL values that is not in captured
        // avoid the extra TTL values after replay when run task
        // 清除本次没有传递过来的 ThreadLocal,和对应值
        //  -- 第一点:可能会有因为 InheritableThreadLocal 而传递并保留的值
        //  -- 第二点:保证主线程set过的ThreadLocal不被传递过来。明确其传递是由业务代码控制,就是明确 set 过值的
        if (!captured.containsKey(threadLocal)) {
        iterator.remove();
        threadLocal.superRemove();
        }
        }

        // set TTL values to captured
        // 将 map 中的值,设置到快照
        // 内部调用了 beforeExecute 和 afterExecute 方法。默认不做任何处理
        setTtlValuesTo(captured);

        // call beforeExecute callback
        // TransmittableThreadLocal 的回调方法,在任务执行前执行
        doExecuteCallback(true);

        return backup;
        }

private static WeakHashMap<ThreadLocal<Object>, Object> replayThreadLocalValues( WeakHashMap<ThreadLocal<Object>, Object> captured) {
final WeakHashMap<ThreadLocal<Object>, Object> backup = new WeakHashMap<ThreadLocal<Object>, Object>();

        for (Map.Entry<ThreadLocal<Object>, Object> entry : captured.entrySet()) {
final ThreadLocal<Object> threadLocal = entry.getKey();
        backup.put(threadLocal, threadLocal.get());

final Object value = entry.getValue();
        // 如果值是标记已删除,则清除
        if (value == threadLocalClearMark) threadLocal.remove();
        else threadLocal.set(value);
        }

        return backup;
        }
/*********************************************
 * 恢复备份的原快照
 *********************************************/
public static void restore( Object backup) {
// 将之前保存的TTL和threadLocal原来的数据覆盖回去
final Snapshot backupSnapshot = (Snapshot) backup;
        restoreTtlValues(backupSnapshot.ttl2Value);
        restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
        }

private static void restoreTtlValues( WeakHashMap<TransmittableThreadLocalCode<Object>, Object> backup) {
        // call afterExecute callback
        // 调用执行完后回调接口
        doExecuteCallback(false);

        // 移除子线程新增的TTL
        for (final Iterator<TransmittableThreadLocalCode<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocalCode<Object> threadLocal = iterator.next();
        // 恢复快照时,清除本次传递注册进来,但是原先不存在的 TransmittableThreadLocal
        // 移除掉所有不在备份里面的TTL数据,应该是为了避免内存泄漏吧
        // clear the TTL values that is not in backup
        // avoid the extra TTL values after restore
        if (!backup.containsKey(threadLocal)) {
        iterator.remove();
        threadLocal.superRemove();
        }
        }

        // 重置为原来的数据(就是恢复回备份前的值)
        // restore TTL values
        setTtlValuesTo(backup);
        }

private static void setTtlValuesTo( WeakHashMap<TransmittableThreadLocalCode<Object>, Object> ttlValues) {
        for (Map.Entry<TransmittableThreadLocalCode<Object>, Object> entry : ttlValues.entrySet()) {
        TransmittableThreadLocalCode<Object> threadLocal = entry.getKey();
        // set 的同时,也就将 TransmittableThreadLocal 注册到当前线程的注册表了
        threadLocal.set(entry.getValue());
        }
        }

从上面TtlRunnable可见,TtlRunnable肯定是维护了一个线程级别的的缓存,每次调用run()前后进行set和还原数据。

线程池级别缓存

线程池缓存holder如下

/**
 * holder - 线程级别缓存
 * 1. 用WeakHashMap弱引用,为了避免内存泄漏,内存不足时弱引用自动被回收
 * 2. 使用InheritableThreadLocal,作用跟ThreadLocal差不多(因为replay设置值,run(),最后还是会restore还原)
 */
private static InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
        new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
@Override
protected WeakHashMap<TransmittableThreadLocalCode<Object>, ?> initialValue() {
        // holder默认使用InheritableThreadLocal。初始化的时候会调用initialValue返回一个WeekHashMap
        return new WeakHashMap<TransmittableThreadLocalCode<Object>, Object>();
        }

@Override
protected WeakHashMap<TransmittableThreadLocalCode<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocalCode<Object>, ?> parentValue) {
        // 返回的是子线程在第一次get的时候的初始值,如果不重写,默认就是返回父线程的值
        return new WeakHashMap<TransmittableThreadLocalCode<Object>, Object>(parentValue);
        }
        };

知道了TTL线程级别缓存是用holder进行存储的,但是如何进行拷贝的?

如何拷贝

  • TTL有一个静态内部类 Transmitter ,专门用于操作TTL本地线程缓存的重放、恢复备份、清除等操作。下面以TtlRunnable作为一个入口进行分析

步骤分析

  1. 通过TtlRunnable.get(runnable)进行增强调用
// 调用执行线程的时候包裹(TtlRunnable.get
executor.execute(TtlRunnable.get(runnable));
  1. 会调用到TtlRunnable的构造方法,然后调用到capture()拷贝方法
public final class TtlRunnable implements Runnable, TtlWrapper<Runnable>, TtlEnhanced, TtlAttachments {
    private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
        // capturedRef:拷贝副本的引用
        this.capturedRef = new AtomicReference<Object>(capture());
        // runnable待执行逻辑对象
        this.runnable = runnable;
        this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
    }
}

    /**************************************************
     * capture():拷贝副本
     * - 分为TTL拷贝、ThreadLocal拷贝
     **************************************************/
    public static Object capture() {
        // 抓取快照
        return new Snapshot(captureTtlValues(), captureThreadLocalValues());
    }
    /** 抓取 TransmittableThreadLocal 的快照 **/
    private static WeakHashMap<TransmittableThreadLocalCode<Object>, Object> captureTtlValues() {
        WeakHashMap<TransmittableThreadLocalCode<Object>, Object> ttl2Value = new WeakHashMap<TransmittableThreadLocalCode<Object>, Object>();
        // 主线程和子线程其实都是共用一个holder的,所以主线程new一个TTL并做一个set操作之后,会搞一份数据put到holder中。
        // 这时候就可以进行一个副本的拷贝,遍历holder子线程的值,然后拷贝一份出来
        // eg:主线程用这个ttl.set("我是主线程");,这时候holder就会对应多了要给ttl,并且值是"我是主线程"
        for (TransmittableThreadLocalCode<Object> threadLocal : holder.get().keySet()) {
            // threadLocal.copyValue()默认还是拷贝引用
            ttl2Value.put(threadLocal, threadLocal.copyValue());
        }
        return ttl2Value;
    }
    /** 抓取 ThreadLocal 的快照 **/
    private static WeakHashMap<ThreadLocal<Object>, Object> captureThreadLocalValues() {
        final WeakHashMap<ThreadLocal<Object>, Object> threadLocal2Value = new WeakHashMap<ThreadLocal<Object>, Object>();
        // 从 threadLocalHolder 中,遍历注册的 ThreadLocal,将 ThreadLocal 和 TtlCopier 取出,将值复制到 Map 中
        for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
            final ThreadLocal<Object> threadLocal = entry.getKey();
            final TtlCopier<Object> copier = entry.getValue();
            // 默认拷贝的是引用
            threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));
        }
        return threadLocal2Value;
    }

其他注意点

  1. TTL为什么不直接继承ThreadLocal?
  • 因为有些业务需要用到ITL特性,如果直接继承ThreadLocal,就会丢失ITL的父拷贝到子线程数据的特性(子线程创建时拷贝)
  1. 为什么需要在run执行完之后调用restore()?
  • restore里面会主动调用remove()回收,避免内存泄露(会删除子线程新增的TTL)
  • 不调用restore()的话,就会覆盖之前backup备份部分子线程的数据,这样可能在业务上有隐患
  1. TTL存在线程安全问题?
  • 存在的,因为默认都是引用类型拷贝,如果子线程修改了数据,主线程是可以感知到的
  1. TTL是否存在内存泄露问题?
  • TTL维护的holder本身是一个static来的,使用的时候会调用restore(),然后里面显式调用remove()清楚子线程新增TTL,所以正确使用下是没有内存泄露问题的