在平时java开发中,如果想使用共享变量,往往使用public static 的方式修饰一个类的成员变量,这样就能实现变量共享了。不过,这样的变量是所有线程都共享的,有没有什么方式可以让这个变量只在某个线程中共享使用呢,答案是有的,可以使用ThreadLocal来解决这个问题。
一、ThreadLocal简介
ThreadLocal类并不是用来解决多线程环境下的共享变量问题,而是用来提供线程内部的共享变量,在多线程环境下,可以保证各个线程之间的变量互相隔离、相互独立。在线程中,可以通过get()/set()方法来访问变量。ThreadLocal实例通常来说都是private static类型的,它们希望将状态与线程进行关联。这种变量在线程的生命周期内起作用,可以减少同一个线程内多个函数或者组件之间一些公共变量的传递的复杂度。
二、ThreadLocal使用演示
首先看一下ThreadLocal常见的方法的方法:
- get()方法:获取与当前线程关联的ThreadLocal中的值。
- set(T value)方法:设置与当前线程关联的ThreadLocal中的值。
- remove()方法:将当前线程与ThreadLocal中关联的变量的值删除,可以加快垃圾回收,减少内存的占用。
- withInitial(Supplier<? extends S> supplier)方法:提供一个Supplier的lamda表达式用来当做初始值,java8引入。
下面,使用一段代码演示ThreadLocal是否实现了线程隔离:
public class ThreadLocalTest {
public static void main(String[] args) {
// 声明线程数量
int threadNum = 2;
for (int i = 0; i < threadNum; i++) {
new MyTask().start();
}
}
}
class MyTask extends Thread {
// 声明一个ThreadLocal
private static ThreadLocal<Integer> tl = new ThreadLocal<>();
@Override
public void run() {
// 获取线程名称
String threadName = Thread.currentThread().getName();
// 第一次获取此线程在ThreadLocal中的值
Integer num = tl.get();
System.out.println(threadName + "第一次在ThreadLocal中获取了值:" + num);
// 获取一个随机值
num = new Random().nextInt(10);
// 将获取的随机值设置到ThreadLocal中
tl.set(num);
System.out.println(threadName + "在ThreadLocal中设置了值:" + num);
try {
TimeUnit.MILLISECONDS.sleep(200);
} catch (InterruptedException e) {
e.printStackTrace();
}
// 再次获取此线程在ThreadLocal中的值
num = tl.get();
System.out.println(threadName + "第二次在ThreadLocal中获取了值:" + num);
try {
TimeUnit.MILLISECONDS.sleep(200);
} catch (InterruptedException e) {
e.printStackTrace();
}
// 调用remove方法,防止内存泄漏
tl.remove();
}
}
运行结果:
Thread-0第一次在ThreadLocal中获取了值:null
Thread-1第一次在ThreadLocal中获取了值:null
Thread-1在ThreadLocal中设置了值:5
Thread-0在ThreadLocal中设置了值:2
Thread-1第二次在ThreadLocal中获取了值:5
Thread-0第二次在ThreadLocal中获取了值:2
可以看到,在MyTask中,ThreadLocal变量声明成了static类型,但是不同线程设置值后,获取的也是彼此设置的值,并没有出现被覆盖的现象,可以看到使用ThreadLocal确实实现了线程隔离的效果。
三、ThreadLocal源码浅析
- 成员变量:
private final int threadLocalHashCode = nextHashCode();
private static AtomicInteger nextHashCode = new AtomicInteger();
private static final int HASH_INCREMENT = 0x61c88647;
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
成员变量比较简单,nextHashCode 和HASH_INCREMENT 的存在好像都是为了初始化threadLocalHashCode ,threadLocalHashCode 有什么用呢,这个就相当于当前ThreadLocal的hashcode。
- 构造函数:
public ThreadLocal() {
}
ThreadLocal的构造函数比较简单,只有一个空参的构造函数。
- set方法:
public void set(T value) {
// 获取当前线程对象
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
可以看到,首先调用了getMap方法,点进去看一下:
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
可以看到这是一个线程的成员变量:
ThreadLocal.ThreadLocalMap threadLocals = null;
如果获取threadLocals 不为null的话直接调用ThreadLocal.ThreadLocalMap的set方法把值设置进去,否则就调用createMap方法:
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
createMap方法创建了一个ThreadLocal.ThreadLocalMap。这里先看一下ThreadLocal.ThreadLocalMap的相关源码:
成员变量:
/**
* 初始容量
*/
private static final int INITIAL_CAPACITY = 16;
/**
* Entry表,大小为2的n次幂
*/
private Entry[] table;
/**
* table大小
*/
private int size = 0;
/**
* 扩容阈值
*/
private int threshold; // Default to 0
这里多提一下,为什么初始化table的大小都是2的幂呢?这里设置值其实和HashMap中原理一样,在后续的键值对存储和取出过程中:
- 使用位运算替代取模,提升计算效率。
- 为了使不同 hash 值发生碰撞的概率更小,尽可能促使元素在哈希表中均匀地散列。
ThreadLocalMap是用来存储与线程关联的value的哈希表,它具有HashMap的部分特性,比如容量、扩容阈值等,它内部通过Entry类来存储key和value,Entry类的定义为:
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
Entry继承自WeakReference,通过上述源码super(k)可以知道,ThreadLocalMap是使用ThreadLocal的弱引用作为Key的。这里顺带也回顾一下java中的4种引用:
- 强引用:只要某个对象与强引用关联,那么JVM在内存不足的情况下,宁愿抛出outOfMemoryError错误,也不会回收此类对象。
- 软引用:java中使用SoftRefence来表示软引用,如果某个对象与软引用关联,那么JVM只会在内存不足的情况下回收该对象。
- 弱引用:java中使用WeakReference来表示弱引用。如果某个对象与弱引用关联,那么当JVM在进行垃圾回收时,无论内存是否充足,都会回收此类对象。
- 虚引用:java中使用PhantomReference来表示虚引用。虚引用就像形同虚设一样,就像某个对象没有引用与之关联一样。若某个对象与虚引用关联,那么在任何时候都可能被JVM回收掉。
继续跟踪createMap,查看ThreadLocalMap创建过程:
(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}
private void setThreshold(int len) {
threshold = len * 2 / 3;
}
可以看到创建过程也比较简单:
- 首先初始化table ;
- 使用firstKey的hashcode与table的大小进行运算,计算出存储位置的下标;
- 创建一个Entry,指向table中下标的位置,将size同时置为1;
- 计算扩容阈值。
- get方法:
public T get() {
// 获取当前线程对象
Thread t = Thread.currentThread();
// 获取当前线程的ThreadLocalMap 对象
ThreadLocalMap map = getMap(t);
if (map != null) { // 如果map不为空,则调用getEntry获取Entry
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
// 如果当前线程的ThreadLocalMap 对象为null,则调用setInitialValue方法
return setInitialValue();
}
private T setInitialValue() {
T value = initialValue(); // 获取初始化值
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
// 初始化当前线程对象的ThreadLocalMap
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}
// 初始化值
protected T initialValue() {
return null;
}
调用get方法时,如果当前线程没有初始化过ThreadLocalMap ,则初始化,并且返回null,否则调用getEntry方法:
private Entry getEntry(ThreadLocal<?> key) {
// 使用ThreadLocal的hashcode与table大小计算出存储位置下标
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
// 如果获取的Entry对象不为空,并且此时ThreadLocal作为key的引用对象也并未被回收
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
// 如果e不为null
while (e != null) {
ThreadLocal<?> k = e.get();
// 判断当前ThreadLocal对象与e持有的引用对象是否同一个
if (k == key)
return e;
// 如果e的引用已经被回收了
if (k == null)
expungeStaleEntry(i);
else
// 否则就把下标下移,获取下一个entry对象
i = nextIndex(i, len);
e = tab[i];
}
// 如果entry为null,直接返回null
return null;
}
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
// 清理下标staleSlot的entry
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;
// 重新hash一次table
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;
// Unlike Knuth 6.4 Algorithm R, we must scan until
// null because multiple entries could have been stale.
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}
通过get方法源码,可以看到,调用expungeStaleEntry这个方法,ThreadLocal内部会自动帮我们检测key的有效性和动态调整Entry的位置,这样可以有效防止内存泄漏。而ThreadLocal在set和get方法中都使用了expungeStaleEntry这个方法,那是不是说明ThreadLocal不存在内存泄漏呢,答案是否定的,因为我们不调用get和set方法时,还是可能存在内存泄漏问题的,所以我们还是需要每次手动调用remove方法,防止内存泄漏。
- remove方法:
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}
remove方法比较简单,通过ThreadLocal的hashCode计算出下标,然后逐个找到key引用为当前ThreadLocal对象的Entry对象,然后一一清除引用就行了。
总结
通过源码可以看到,ThreadLocal可以实现不同线程之前的隔离,因为ThreadLocalMap是每个线程持有的对象,不过由于ThreadLocalMap使用ThreadLocal的弱引用作为key,如果一个ThreadLocal没有外部关联的强引用,那么在虚拟机进行垃圾回收时,这个ThreadLocal会被回收,这样,ThreadLocalMap中就会出现key为null的Entry,这些key对应的value也就无法访问了,此时除非线程被回收,否则线程持有的ThreadLocalMap中的变量也是不会被回收的,可能造成内存泄漏。
建议将ThreadLocal变量定义成private static的,这样的话ThreadLocal的生命周期就更长,由于一直存在ThreadLocal的强引用,所以ThreadLocal也就不会被回收,也就能保证任何时候都能根据ThreadLocal的弱引用访问到Entry的value值,然后remove它,可以防止内存泄露。