ThreadLocal源码

一、ThreadLocal 数据结构

在日常的开发中,我们使用 ThreadLocal 的方式如下:

public class ThreadLocalTest {
	// ThreadLocal variable to store each thread's ID
	private static final ThreadLocal<Integer> threadId = ThreadLocal.withInitial(() -> (int) (Math.random() * 1000));

	public static void main(String[] args) {
		// Create multiple threads
		Thread thread1 = new Thread(new Task(), "Thread-1");
		Thread thread2 = new Thread(new Task(), "Thread-2");
		Thread thread3 = new Thread(new Task(), "Thread-3");

		// Start the threads
		thread1.start();
		thread2.start();
		thread3.start();
	}

	// Task to be executed by each thread
	static class Task implements Runnable {
		@Override
		public void run() {
			// Get the thread's unique ID
			Integer id = threadId.get();
			System.out.println(Thread.currentThread().getName() + " has ID: " + id);

			// Do some work
			for (int i = 0; i < 5; i++) {
				System.out.println(Thread.currentThread().getName() + " working with ID: " + id);
				try {
					Thread.sleep(100);
				} catch (InterruptedException e) {
					e.printStackTrace();
				}
			}

			// Remove the thread-local variable
			threadId.remove();
		}
	}
}

输出结果:

Thread-1 has ID: 911
Thread-1 working with ID: 911
Thread-3 has ID: 52
Thread-3 working with ID: 52
Thread-2 has ID: 687
Thread-2 working with ID: 687
Thread-3 working with ID: 52
Thread-1 working with ID: 911
Thread-2 working with ID: 687
Thread-3 working with ID: 52
Thread-1 working with ID: 911
Thread-2 working with ID: 687
Thread-2 working with ID: 687
Thread-1 working with ID: 911
Thread-3 working with ID: 52
Thread-2 working with ID: 687
Thread-3 working with ID: 52
Thread-1 working with ID: 911

Process finished with exit code 0

这一段代码含义:

  1. 定义了一个 ThreadLocal 变量为 threadId,对应的初始值为 0 到 1000 之间的随机数;
  2. 创建线程并启动,每个线程执行 Task 任务
  3. 在每个线程中使用 ThreadLocal 变量,每个线程从 ThreadLocal 的获取其唯一 ID 并打印出来;
  4. 模拟工作,例如操作数据库等;
  5. 移除 ThreadLocal 变量。

它们各个线程之间的数据不会出现错乱,那其底层是怎么实现的呢?其数据结构如下:
image.png

每一个 Thread 对象都有一个名为 threadLocals 的成员变量,对应的类型为 ThreadLocal.ThreadLocalMap。在 ThreadLocal.ThreadLocalMap 对象的内部有一个 Entry 数组,其中存储的 Entry 对象的 key 为 ThreadLocal,value 就是我们绑定到线程上的值。ThreadLocal 之所以可以做到线程间数据隔离,那是基于它每个线程内部都拥有一个独立的 ThreadLocalMap 对象,每个线程对自己的 ThreadLocalMap 对象进行操作不会影响到其他线程的数据。

二、源码学习

withInitial 设置初始值

/**
 * Creates a thread local variable. The initial value of the variable is
 * determined by invoking the {@code get} method on the {@code Supplier}.
 *
 * <p>创建线程局部变量。变量的初始值是通过调用Supplier上的get方法来确定的</p>
 *
 * @param <S> the type of the thread local's value 线程本地值的类型
 * @param supplier the supplier to be used to determine the initial value
 *                 <br>用于确认初始值的供应商
 * @return a new thread local variable 一个新的线程局部变量
 * @throws NullPointerException if the specified supplier is null
 * @since 1.8
 */
public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
    return new SuppliedThreadLocal<>(supplier);
}

在上面的示例代码中就使用了 withInitial 进行设置 ThreadLocal 返回的初始值为 0 到 1000 之间的随机数。需要注意的是:如果该线程没有调用 set() 方法,那么第一次调用 get() 方法返回的就是这个设置的初始值,后续如果调用了 remove() 方法之后,再次调用 get() 方法也会返回这里设置的默认值(当然如果你设置的是随机数,那么返回的肯定不是同一个值)。

这里还有一个细节点就是使用了 Supplier,它有什么好处?先看看类的定义:

/**
 * Represents a supplier of results.
 * 代表结果供应商。

 *
 * <p>There is no requirement that a new or distinct result be returned each
 * time the supplier is invoked.
 * 不要求每次调用时都返回新的或独特的结果。这是一个功能接口,其功能方法是 get()。
 *
 * <p>This is a <a href="package-summary.html">functional interface</a>
 * whose functional method is {@link #get()}.
 *
 * @param <T> the type of results supplied by this supplier
 * 该供应商提供的结果类型代表结果供应商。
 *
 * @since 1.8
 */
@FunctionalInterface
public interface Supplier<T> {

    /**
     * Gets a result.
     *
     * @return a result
     */
    T get();
}

这里能看到它只有一个 get() 方法用于获取结果。

set(T Value) 向 ThreadLocal 设置值

源码:

/**
 * Sets the current thread's copy of this thread-local variable
 * to the specified value.  Most subclasses will have no need to
 * override this method, relying solely on the {@link #initialValue}
 * method to set the values of thread-locals.
 *
 * <p>将此线程本地变量的当前线程副本设置为指定值。大多数子类无需重载此方法,只需依赖 initialValue 方法来设置线程本地变量的值。
 *
 * @param value the value to be stored in the current thread's copy of
 *        this thread-local.
 * <p>当前线程的线程本地变量副本中要存储的值
 */
public void set(T value) {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取当前线程的ThreadLocal.ThreadLocalMap对象
    ThreadLocalMap map = getMap(t);
    // 如果ThreadLocalMap不为空,把key为自身ThreadLocal对象,value为要存放的值
    if (map != null) {
        map.set(this, value);
    } else {
        // 创建ThreadLocalMap,key为自身ThreadLocal对象,value为要存放的值
        createMap(t, value);
    }
}

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

这里的逻辑就比较简单了,先获取到当前线程,然后调用 getMap() 方法,内部就是返回当前线程的 threadLocals 对象也就是 ThreadLocal.ThreadLocalMap,如果获取到的 threadLocals 不为空,则把 key 为自身 ThreadLocal 对象,value 为要存放的值;如果获取到的 threadLocals 为空,则调用 createMap() 方法进行创建 ThreadLocalMap 对象并把 key 为自身 ThreadLocal 对象,value 为要存放的值。

ThreadLocalMap#Entry

在上面的 set() 方法可以看到数据要么是调用 map.set() 或者 createMap() --> new ThreadLocalMap(this, firstValue) 方法,但是无论调用那种方式,最终数据都是保存到 Entry 数组里。那就来看看 Entry 的定义:

static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

这里看到 Entry 继承了 WeakReference 对象,这里有必须要提一下 Java 的四种引用,毕竟与 GC 有关,GC 会根据可达性分析算法,沿着 GC Roots 一直找到对象没有被其他对象引用,那代表该对象可以被 GC,但是这里的引用又分为四种:

  • 强引用:普通的变量引用
public static User user = new User();
  • 软引用:将对象用 SoftReference 软引用类型的对象包裹,正常情况不会被回收,但是 GC 做完后发现释放不出空间存放新的对象,则会把这些软引用的对象回收掉。软引用可用来实现内存敏感的高速缓存。
public static SoftReference<User> user = new SoftReference<User>(new User());

软引用在实际中有重要的应用,例如浏览器的后退按钮🔘。按后退时,这个后退时显示的网页内容是重新进行请求还是从缓存中取出呢?这就要看具体的实现策略了。

  1. 如果一个网页在浏览结束时就进行内容的回收,则按后退查看前面浏览过的页面时,需要重新构建。
  2. 如果将浏览过的网页存储到内存中会造成内存的大量浪费,甚至会造成内存溢出。
  • 弱引用:将对象用 WeakReference 软引用类型的对象包裹,弱引用跟没引用查不到,GC 会直接回收掉,很少用。
public static WeakReference<User> user = new WeakReference<User>(new User());
  • 虚引用:虚引用也称为幽灵引用或者幻影引用,它是最弱的一种引用关系,几乎不用。

那为什么这里 Entry 保存 ThreadLocal 类型的 key 需要使用弱引用?其实也是为了防止 OOM,如果 ThreadLocal 作为 Key 不使用弱引用,那么当前线程并没有结束,可以通过当前线程关联到其 threadLocals 属性对应的 ThreadLocalMap,再关联到 Entry 中的 ThreadLocal 对象,这时候 ThreadLocal 就无法被回收。

我们可以启动一个线程执行死循环,然后在里面无限创建 ThreadLocal 并设置值,看看这一段代码会不会出现 OOM。

public static void test() {
    new Thread(() -> {
        while (true) {
            ThreadLocal<String> threadLocal = new ThreadLocal<>();
            threadLocal.set("1");
        }
    }).start();
} 

最终发现这一段代码并不会出现 OOM,这是因为 ThreadLocal 是弱引用指向,在发生 GC 的时候就会被回收。
image.png

细心的小伙伴会发现这里其实还有一个问题,虽然 ThreadLocal 被回收了,但是 Entry 数组一直在塞入 Entry,回收之后就相当于 Enrty 的 key 为 null,value 还是有值的,那么为什么不会出现 OOM 呢?原因就是万 ThreadLocalMap 塞入元素的时候,会删除过时的元素(也就是 Entry 中的 key 弱引用持有的 ThreadLocal 为 null 的元素)。

ThreadLocalMap 构造方法
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    // 初始化容量为16
    table = new Entry[INITIAL_CAPACITY];
    // hash散列获取下标
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    // 存入Entry数组中
    table[i] = new Entry(firstKey, firstValue);
    size = 1;
    // 设置下一次扩容阈值 threshold = INITIAL_CAPACITY * 2 / 3;
    setThreshold(INITIAL_CAPACITY);
}

这里就只需要留意两个点即可:

  • 初始化 Entry 数组容量为 16;
  • 设置的扩容阈值为 Entry 数组长度的 2/3。

ThreadLocalMap#set(ThreadLocal<?> key, Obejct value)

private void set(ThreadLocal<?> key, Object value) {

    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.

    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);

    for (Entry e = tab[i];
         e != null;
         // 寻址:+1往后寻找,环状的遍历数组
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        // 相同的key直接覆盖
        if (k == key) {
            e.value = value;
            return;
        }

        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

从上面来看 ThreadLocalMap#set 方法可以分为两步:

  1. 从开放寻址的方式找到合适的位置进行存储数据;
  2. 向数组中放入新的 Entry,有需要的话会进行扩容。
开放寻址的方式找到合适的位置进行存储数据

这一步的代码是在这个 for 循环里面:

Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);

for (Entry e = tab[i];
     e != null;
     // 寻址:+1往后寻找,环状的遍历数组
     e = tab[i = nextIndex(i, len)]) {
    ThreadLocal<?> k = e.get();

    // 相同的key直接覆盖
    if (k == key) {
        e.value = value;
        return;
    }

    if (k == null) {
        replaceStaleEntry(key, value, i);
        return;
    }
}

对应的 nextIndex 代码如下:

private static int nextIndex(int i, int len) {
    return ((i + 1 < len) ? i + 1 : 0);
}

这是啥意思呢?分析如下所示:
image.png

总结:

  1. ThreadLocalMap#set 方法通过 for 循环遍历,直到在 Entry 数组中找到一个为 null 的位置或者是两个 key 相同进行覆盖原来的值才会跳出 for 循环;
  2. 第一个 if 意味着相同的 ThreadLocal 会进行覆盖之前的旧值;
  3. 第二个 if 意味着原来霸占着 Entry 数组这个位置的 ThreadLocal 因为是弱引用类型从而被 GC 垃圾回收了,但是为了避免对应的 value 值没有被回收导致内存泄漏的问题,这里就通过 replaceStaleEntry 进行覆盖。

向数组中放入新的 Entry,有需要的话会进行扩容

在上面 for 循环中进行的条件是 e != null,e 是 Enrty 数组中的元素,那么 for 循环结束的时候,除了成功覆盖了原有的元素的情况,还有就是找到了一个可以使用的位置。此时逻辑如下:

// 插入的新的Enrty
tab[i] = new Entry(key, value);
// 数组大小++
int sz = ++size;
// 调用cleanSomeSlots尝试清理一些未在使用的位置,如果清理失败则调用rehash进行扩容
if (!cleanSomeSlots(i, sz) && sz >= threshold)
    rehash();

我们能看到进行 rehash 需要满足两个条件:

  1. cleanSomeSlots 方法进行清理一些未在使用的位置失败,也就代表着在 Entry 数组中没有因为 GC 而导致的 ThreadLocal 被清除的情况;
  2. 数组的大小大于扩容的阈值(在上面我们说过扩容的阈值为 Entry 数组容量的 2/3)

感兴趣的小伙伴可以自己去看看 cleanSomeSlots 的代码,这里就不分析了。下面就来看看 rehash 的代码:

private void rehash() {
    // 删除过时的条目
    expungeStaleEntries();

    // Use lower threshold for doubling to avoid hysteresis
    // 使用较低的加倍阈值以避免滞后
    if (size >= threshold - threshold / 4)
        resize();
}

/**
 * Double the capacity of the table.
 */
private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    // 扩容为原来的两倍
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            // 如果对应的ThreadLocal被回收了,那么对应的value也进行删除
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                // 开放式寻址找到合适的位置进行存放
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }

    // 修改扩容阈值
    setThreshold(newLen);
    // 设置size为有效的个数
    size = count;
    table = newTab;
}

总结:

  1. 先清除过时的条目;
  2. 进行扩容为原来的两倍,并把旧的数据从旧的 Entry 数组拷贝到新的 Enrty 数组中:
    1. 如果 key 为 null(也就是 ThreadLocal 被 GC 回收了),那么把对应的 value 值也设置为 null(从而没有地方引用了对应的 vaule),方便 GC 回收;
    2. 通过开放式寻址的方式找到合适的位置存放对应的元素。

ThreadLocal#get() 获取当前线程绑定到 ThreadLocal 上的值

public T get() {
    Thread t = Thread.currentThread();
    // 获取当前线程的ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        // 已当前ThreadLocal作为key找到对应存放的值
        ThreadLocalMap.Entry e = map.getEntry(this);
        // 如果不为空,则直接返回
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    // 如果没有找到,或者当前线程的ThreadLocalMap为null,调用setInitialValue返回默认值
    return setInitialValue();
}

这部分代码可以分成两部分:

  1. 获取 ThreadLocalMap 中的值;
  2. 如果没有设置 ThreadLocalMap,则调用 setInitialValue 进行返回默认值。
获取 ThreadLocalMap 中的值

我们从上面的代码能看到通过 ThreadLocal 为 key 进行获取 ThreadLocalMap 的值的代码在 getEntry 中:

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    // 如果获取的e不为空,并且对应的key为当前ThreadLocal,那直接返回对应值
    if (e != null && e.get() == key)
        return e;
    else
        // 如果没有值,或者ThreadLocal和当前的key不相等,那么调用开放性寻址的方式去寻找
        return getEntryAfterMiss(key, i, e);
}

// getEntryAfterMiss
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        // 获取到对应的ThreadLocal
        ThreadLocal<?> k = e.get();
        // 和当前的ThreadLocal相等则返回
        if (k == key)
            return e;
        // 如果ThreadLocal已经被GC回收,则清理被回收的元素
        if (k == null)
            expungeStaleEntry(i);
        else
            // 开放性寻址,寻址在一个位置
            i = nextIndex(i, len);
        // 重新获取值
        e = tab[i];
    }
    // 如果还是没找到,那就代表之前没有调用set方法存放过值
    return null;
}

总结:

  1. 先获取到当前 ThreadLocal 在 Entry 数组中存放的位置 i;
  2. 通过下标获取到对应的 Entry,如果 Entry 不为空并且对应的 key 等于当前 ThreadLocal,则直接返回对应的 Entry;
  3. 如果没有值或者对应的 Entry 的 key 不等于当前 ThreadLocal,那么调用 getEntryAfterMiss 方法进行开放性寻址的方式进行获取对应的值;
    1. 判断当前 Entry 的 key 是否等于当前 ThreadLocal,相等则直接返回;
    2. 如果 ThreadLocal 已经被 GC 回收,则清理被回收的元素;
    3. 还找不到,则继续通过开放性寻址的方式进行获取下一个下标,从而获取到对应的 Entry 进行走上面的逻辑;
    4. 最后还是找不到则返回 null,代表之前就没有调用 set 方法进行存储过值。

如果没有设置 ThreadLocalMap

如果没有设置 ThreadLocalMap,则调用 setInitialValue 进行返回默认值。

private T setInitialValue() {
    // 子类可以自定义实现这个类用于设置默认值,默认情况下可以使用SuppliedThreadLocal这个类
    // 也就是直接使用ThreadLocal#withInitial方法进行设置默认值
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        map.set(this, value);
    } else {
        createMap(t, value);
    }
    if (this instanceof TerminatingThreadLocal) {
        TerminatingThreadLocal.register((TerminatingThreadLocal<?>) this);
    }
    return value;
}

这里子类可以直接去实现 ThreadLocal 然后重写 initialValue 方法,用于返回默认值。这里有一个默认的实现类 SuppliedThreadLocal:

static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

    private final Supplier<? extends T> supplier;

    SuppliedThreadLocal(Supplier<? extends T> supplier) {
        this.supplier = Objects.requireNonNull(supplier);
    }

    @Override
    protected T initialValue() {
        return supplier.get();
    }
}

看到构造方法入参是一个 Supplier,也就是可以按照下面的方式进行设置默认值:

private static final ThreadLocal<Integer> threadId = ThreadLocal.withInitial(() -> 100);

这里示例代码是设置每个 ThreadLocal 的默认值为 100,当然你可以根据实际业务情况设置。

  • 18
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值