源码
package java.lang;
import jdk.internal.misc.TerminatingThreadLocal;
import java.lang.ref.*;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
public class ThreadLocal<T> {
// 得到哈希值
private final int threadLocalHashCode = nextHashCode();
// 定义一个原子操作的integer类nextHashCode
private static AtomicInteger nextHashCode = new AtomicInteger();
private static final int HASH_INCREMENT = 0x61c88647;
private static int nextHashCode() {
// 原子类nextHashCode自增指定值, 保证hash值均匀分布在2的次方的位置
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
// 初始值null
protected T initialValue() {
return null;
}
public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new SuppliedThreadLocal<>(supplier);
}
public ThreadLocal() {
}
// 获取当前线程绑定的局部变量
public T get() {
// 获取当前线程
Thread t = Thread.currentThread();
// 获取当前线程维护的ThreadLocalMap
ThreadLocalMap map = getMap(t);
if (map != null) {
// 以当前ThreadLocal为key获取entry
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked") //压制警告
// 获取当前线程的局部变量并返回
T result = (T)e.value;
return result;
}
}
// 否则返回初始值
return setInitialValue();
}
boolean isPresent() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
return map != null && map.getEntry(this) != null;
}
// 设置初始值
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
map.set(this, value);
} else {
createMap(t, value);
}
if (this instanceof TerminatingThreadLocal) {
TerminatingThreadLocal.register((TerminatingThreadLocal<?>) this);
}
return value;
}
// 绑定当前线程的局部变量
public void set(T value) {
//获取当前西线程
Thread t = Thread.currentThread();
//获取该线程维护的ThreadLoalMap对象
ThreadLocalMap map = getMap(t);
if (map != null) {
map.set(this, value);
} else {
// 为该线程创建ThreadLocal对象
createMap(t, value);
}
}
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null) {
m.remove(this);
}
}
// 返回线程的threadLocals变量
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
// 当threadLocals不存在时,创建并初始化一个ThreadLocalMap类的实例赋给threadLocals
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
return new ThreadLocalMap(parentMap);
}
// ThreadLocal不支持继承性,子线程禁止方法父线程的变量, InheritableThreadLocal继承ThreadLocal并重写了该方法,使子线程可以访问父线程的变量
T childValue(T parentValue) {
throw new UnsupportedOperationException();
}
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {
private final Supplier<? extends T> supplier;
SuppliedThreadLocal(Supplier<? extends T> supplier) {
this.supplier = Objects.requireNonNull(supplier);
}
@Override
protected T initialValue() {
return supplier.get();
}
}
// 每个Thread维护一个ThreadLocalMap,ThreadLocalMap中存储的是一个Entry[] table数组
static class ThreadLocalMap {
// Entry的key是弱引用,存储ThreadLocal对象,value是强引用,存储ThreadLocal中设置的值,保存每个Thread独立的数据副本
static class Entry extends WeakReference<ThreadLocal<?>> {
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
// Entry[] table数组的初始容量,必须是2的次方
private static final int INITIAL_CAPACITY = 16;
private Entry[] table;
private int size = 0;
// 负载因子,默认75%(len*2/3)
private int threshold;
private void setThreshold(int len) {
threshold = len * 2 / 3;
}
// 找到下一个索引
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}
// 找到上一个索引
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}
// 使用ThreadLocalMap实例来初始化一个新的ThreadLocalMap对象
private ThreadLocalMap(ThreadLocalMap parentMap) {
Entry[] parentTable = parentMap.table;
int len = parentTable.length;
setThreshold(len);
table = new Entry[len];
for (Entry e : parentTable) {
if (e != null) {
@SuppressWarnings("unchecked")
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
//
Object value = key.childValue(e.value);
Entry c = new Entry(key, value);
int h = key.threadLocalHashCode & (len - 1);
while (table[h] != null)
h = nextIndex(h, len);
table[h] = c;
size++;
}
}
}
}
// 获取key的索引位置i对应的ThreadLocalMap中的值
private Entry getEntry(ThreadLocal<?> key) {
// 索引i的计算方法: key的threadLocalHashCode & 数组长度-1
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
// 如果找到了,直接返回这个键值对
return e;
else
// 没有找到,从i开始往下继续找
return getEntryAfterMiss(key, i, e);
}
// 从i开始往下遍历,继续找key
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
// 如果找到了,直接返回这个键值对
return e;
if (k == null)
// 如果发现空key, 去除对相应value的引用,并在table中清除这个键值对
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
}
// set 键值对
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
// 先遍历一下key是否存在
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
// 如果存在,更新key的value
if (k == key) {
e.value = value;
return;
}
//如果遇到空key,该key-value替换这个键值对
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// 如果key不存在, 新建entry,并插入到table中
tab[i] = new Entry(key, value);
int sz = ++size;
// 判断是否需要扩容
if (!cleanSomeSlots(i, sz) && sz >= threshold)
// 进行全表的废弃数据的清除,并判断是否扩容
rehash();
}
// 删除key和对应的value
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}
// 替换key的旧数据为value,从指定索引staleSlot搜索key
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
// 寻找前面第一个被回收key的entry的索引,认为是废弃数据的索引
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;
//
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
// 如果k=key,直接更新key的value
ThreadLocal<?> k = e.get();
if (k == key) {
e.value = value;
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
if (slotToExpunge == staleSlot)
slotToExpunge = i;
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
// 如果没有找到key, 在staleSlot之前也没有废弃数据,就将staleSlot置为从staleSlot开始找到的第一个空key的entry的索引
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
// 先解除废弃数据对value的引用,使GC能回收它
tab[staleSlot].value = null;
// 更新
tab[staleSlot] = new Entry(key, value);
// 如果有其他的废弃数据, 清除
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
// 清除staleSlot索引的entry
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;
Entry e;
int i;
// 继续清除后面的废弃数据
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
if (e != null && e.get() == null) {
n = len;
removed = true;
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}
//先清理废弃数据,再重新判断是否需要扩容
private void rehash() {
expungeStaleEntries();
if (size >= threshold - threshold / 4)
resize();
}
// table 2倍扩容,
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;
for (Entry e : oldTab) {
if (e != null) {
ThreadLocal<?> k = e.get();
// 扩容的时候再判断时候又被回收key的废弃数据
if (k == null) {
e.value = null;
} else {
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}
setThreshold(newLen);
size = count;
table = newTab;
}
// 循环清除废弃数据
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}
}
}
1. ThreadLocal是除了加锁之外的保证线程安全的方法,对于共享对量,ThreadLocal为线程提供局部变量,每个线程对共享变量访问的时候都是访问线程自己的局部变量,不同的线程互不干扰,起到线程隔离的作用。并且变量在当前线程内起作用,避免同一个线程在不同组件之间一些公共资源传递的麻烦;
2. ThreadLocal维护一个ThreadLocalMap, 底层是一个Entry[] 数组, Entry继承弱引用,key弱引用指向为当前线程维护的ThreadLocal对象,value为强引用指向ThreadLocal对象绑定的数据;
3. key对ThreadLocal对象弱引用保证当线程使用完ThreadLocal对象,heap上的ThreadLocal对象能够被GC及时回收,因为key被线程的map引用,声明周期和线程一样长,如果key对ThreadLocal对象是强引用,那么除非使用完ThreadLocal对象手动调用remove方法清除key对key对ThreadLocal对象的引用,否则ThreadLocal对象要等到线程结束才会被回收;
4. Entry[] 数组初始容量为16,根据哈希冲突的解决方法,最好设置为2的次方。 负载因子=0.75, 当数组大小超过容量75%,则进行一次全表的null key清除,若还超过容量75%,则进行2倍扩容;
5. set, get,remove在搜索key时都会清除key为null的entry, 释放已经被回收了的ThreadLocal对象绑定的value;
6. ThreadLocal为当前线程调用,不支持继承性,父线程绑定的局部变量,子线程不能访问。InheritableThreadLocal类可以让子线程获取到父线程的局部变量。
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
// 重写ThreadLocal的childValue,返回父线程的数据
protected T childValue(T parentValue) {
return parentValue;
}
// 重写getMap方法,不再返回线程的threadLocals变量,而是返回inheritableThreadLocals变量
ThreadLocalMap getMap(Thread t) {
return t.inheritableThreadLocals;
}
// 同样要重写createMap,当map不存在时,给线程创建并初始化inheritableThreadLocals变量
void createMap(Thread t, T firstValue) {
t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
}
}
demo:
public class TestThreadLocal {
static String logstr;
static int count;
public static void main(String[] args) throws InterruptedException {
// InheritableThreadLocal可以获取当前线程的父线程的局部变量
// ThreadLocal<String> stringThreadLocal=new ThreadLocal<>();
// ThreadLocal<Integer> intThreadLocal=new ThreadLocal<>();
InheritableThreadLocal<String> stringInheritableThreadLocal = new InheritableThreadLocal<>();
InheritableThreadLocal<Integer> intInheritableThreadLocal = new InheritableThreadLocal<>();
for(int i=0;i<10;i++){
new Thread(() -> {
logstr=Thread.currentThread().getName();
count++;
// stringThreadLocal.set(logstr);
// intThreadLocal.set(count);
stringInheritableThreadLocal.set(logstr);
intInheritableThreadLocal.set(count);
// 模拟其他工作
try {
Thread.sleep(100);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println("logstr:"+stringInheritableThreadLocal.get()+", count:"+intInheritableThreadLocal.get());
// 子线程
new Thread(()->{
// System.out.println("子线程"+Thread.currentThread().getName()+"的父线程数据 -> logstr:"+stringThreadLocal.get()+", count:"+intThreadLocal.get());
System.out.println("子线程"+Thread.currentThread().getName()+"的父线程数据 -> logstr:"+stringInheritableThreadLocal.get()+", count:"+intInheritableThreadLocal.get());
}).start();
}).start() ;
}
}
}