写在开头
在 Java 的多线程模块中,ThreadLocal是经常被提问到的一个知识点。
1.什么是 ThreadLocal
总结:什么是ThreadLocal? 可以概括为以下几个方面。
●它能让线程拥有了自己内部独享的变量
● 每一个线程可以通过get、set方法去进行操作
● 可以覆盖initialValue方法指定线程独享的值
● 通常会用来修饰类里private static final的属性,为线程设置一些状态信息,例如user ID或者Transaction ID
● 每一个线程都有一个指向threadLocal实例的弱引用
,只要线程一直存活或者该threadLocal实例能被访问到,都不会被垃圾回收清理掉
private static final ThreadLocal<T> threadLocal = new ThreadLocal<T>();
ThreadLocal 为我们提供了线程安全的另一种思路,我们平常说的线程安全主要是保证共享数据的并发访问问题,通过sychronized锁或者CAS无锁策略来保证数据的一致性。
2.ThreadLocal源码
2.1 实现思路
2.2 分析前你该知道
ThreadLocalMap 规定了 table 的大小必须是2的幂次方
/**
* The initial capacity -- MUST be a power of two.
*/
private static final int INITIAL_CAPACITY = 16;
2.3 源码分析
接下来就从这三个方法入手,来了解 ThreadLocal 的源码实现。
2.3.1.set(T value)
将set(T value)源码之前先了解一下可能出现的参数
1.1 ThreadLocalMap类的声明
static class ThreadLocalMap {
2
3 // hash map中的entry继承自弱引用WeakReference,指向threadLocal对象
4 // 对于key为null的entry,说明不再需要访问,会从table表中清理掉
5 // 这种entry被成为“stale entries”
6 static class Entry extends WeakReference<ThreadLocal<?>> {
7 /** The value associated with this ThreadLocal. */
8 Object value;
9
10 Entry(ThreadLocal<?> k, Object v) {
11 super(k);
12 value = v;
13 }
14 }
15
16 private static final int INITIAL_CAPACITY = 16;//必须是2的幂次方
17
18 private Entry[] table;//是一个Entry[]数组
19
20 private int size = 0;
21
22 private int threshold; // Default to 0 扩容使用
23
24 private void setThreshold(int len) {
25 threshold = len * 2 / 3;
26 }
27
28 ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
29 table = new Entry[INITIAL_CAPACITY];
30 int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
31 table[i] = new Entry(firstKey, firstValue);
32 size = 1;
33 setThreshold(INITIAL_CAPACITY);
34 }
35}
1.2 set(T value) 方法源码分析
public void set(T value) {
//获取当前线程(调用者线程)
Thread t = Thread.currentThread();
//以当前线程作为key值,去查找对应的线程变量,找到对应的Map
ThreadLocalMap map = getMap(t); //返回来的是一个 ThreadLocal.ThreadLocalMap对象
//如果map不等于null,就直接添加本地变量,key为当前线程,值为要添加的变量值
if (map != null)
//在下面1.3中分析
map.set(this, value);
//如果 map == null,说明是首次添加,需要首先创建对应的Map
else
//创建Map方法,向下看
createMap(t, value);
}
void createMap(Thread t, T firstValue) {
//使用构造器的方式创建Map,源码分析继续向下看
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
//初始化table
table = new Entry[INITIAL_CAPACITY];
//通过公式计算得到下标
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
//当前线程为key,值为value,组装Entry后,赋值到table数组某个坑位中
table[i] = new Entry(firstKey, firstValue);
size = 1;
//扩容相关
setThreshold(INITIAL_CAPACITY);
}
1.3 map.set(this, value);方法实现
private void set(ThreadLocal<?> key, Object value) {
// We don't use a fast path as with get() because it is at
// least as common to use set() to create new entries as
// it is to replace existing ones, in which case, a fast
// path would fail more often than not.
Entry[] tab = table;
int len = tab.length;
//获取table下标
int i = key.threadLocalHashCode & (len-1);
//for循环,循环遍历判断当前坑位是否有值,有值的话开始比较,key相同的话,值覆盖;key为空的话,赋值;
//key不相同的话,使用nextIndex()方法,下标 i+1,继续判断坑位是否为空,为空赋值,不为空继续判断,直到扩容(此处不介绍扩容)
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
if (k == key) {
e.value = value;
return;
}
//因为Entry使用的是弱引用,在某些情况下,它会被JVM任务是无效引用而回收,所以k可能为null
//(Entry弱引用介绍,请继续向下读)
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
//nextIndex 算法
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}
set源码分析小结
2.3.2 get()
如果你已经理解了set(T value)方法的实现,接下来的get()方法就更简单了。
2.1 get() 源码基础实现
public T get() {
//获取当前线程
Thread t = Thread.currentThread();
//从当前线程中获取到 ThreadLocalMap
ThreadLocalMap map = getMap(t);
if (map != null) {
//从ThreadLocalMap中,根据key找出当前线程所对应的Entry
//(具体实现方法介绍,参考下文2.2)
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
//如果Entry不为空,直接返回value值
T result = (T)e.value;
return result;
}
}
//否则,调用setInitialValue()方法,设置初始值并返回(在Entry[]数组上指定下标,设置值为null)
return setInitialValue();
}
private T setInitialValue() {
T value = initialValue();//此处initialValue()返回为 null,所以默认value为null
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
//set()/creatMap()方法,同之前介绍的一样,不再介绍
map.set(this, value);
else
createMap(t, value);
return value;
}
2.2 map.getEntry(this);方法实现
private Entry getEntry(ThreadLocal<?> key) {
//获取Entry[]数组下标
int i = key.threadLocalHashCode & (table.length - 1);
//找到指定Entry
Entry e = table[i];
//Entry不为空,并且key==当前线程
if (e != null && e.get() == key)
//直接返回当前Entry
return e;
else
//反之调用 getEntryAfterMiss()
return getEntryAfterMiss(key, i, e);
}
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
//Entry不为空
while (e != null) {
ThreadLocal<?> k = e.get();
//key==当前线程,返回当前 Entry
if (k == key)
return e;
//key为空,重新rehash(因为Entry使用的是弱引用,在某些情况下,它会被JVM任务是无效引用而回收,所以需要重新rehash。此处不做分析)
if (k == null)
expungeStaleEntry(i);
else//否则,下标+1,继续遍历查找
i = nextIndex(i, len);
e = tab[i];
}
return null;
}
为了便于理解,笔者特地画了一个时序图,请看:
get方法时序图
get源码分析小结
2.3.3.remove()
remove() 实现,也是比较简单的
3.1 remove() 源码基础实现
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);//源码见下面
}
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}
就一句话,获取当前线程内部的ThreadLocalMap,存在则从map中删除这个ThreadLocal对象。
2.3.3 再谈引用
3.ThreadLocal使用场景
Zuul核心原理是什么?就是将请求放入过滤器链中经过一个个过滤器的处理,过滤器之间没有直接的调用关系,处理的结果都是存放在RequestContext里传递的,而这个RequestContext就是一个ThreadLocal类型的对象啊!!!
public class RequestContext extends ConcurrentHashMap<String, Object> {
2
3 protected static final ThreadLocal<? extends RequestContext> threadLocal = new ThreadLocal<RequestContext>() {
4 @Override
5 protected RequestContext initialValue() {
6 try {
7 return contextClass.newInstance();
8 } catch (Throwable e) {
9 throw new RuntimeException(e);
10 }
11 }
12 };
13
14 public static RequestContext getCurrentContext() {
15 if (testContext != null) return testContext;
16
17 RequestContext context = threadLocal.get();
18 return context;
19 }
20}
以Zuul中前置过滤器DebugFilter为例:
1public class DebugFilter extends ZuulFilter {
2
3 @Override
4 public Object run() {
5 // 获取ThreadLocal对象RequestContext
6 RequestContext ctx = RequestContext.getCurrentContext();
7 // 它是一个map,可以放入数据,给后面的过滤器使用
8 ctx.setDebugRouting(true);
9 ctx.setDebugRequest(true);
10 return null;
11 }
12}
4.ThreadLocal使用示例
4.1 使用共享变量方式
/**
* TODO 多线程示例
*
* @author liuzebiao
* @Date 2020-5-8 11:32
*/
public class ThreadLocalDemo {
// 1.使用原子类保证多线程原子操作
// public static AtomicInteger num = new AtomicInteger(0);
// 多线程共享变量num
private static int num = 1;
/**
* 线程方法
*/
static class ThreadDemo extends Thread{
private String name;
public ThreadDemo(String name) {
this.name = name;
}
@Override
public void run() {
for (int i = 0; i < 3; i++) {
try {
System.out.println(name + "------------>" + num++);
// System.out.println(name + "------------>" + num.addAndGet(1));
TimeUnit.SECONDS.sleep(3);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
/*@Override
public void run() {
for (int i = 0; i < 3; i++) {
try {
//2.使用synchronzied加锁,解决多线程原子操作
synchronized (ThreadLocalDemo.class) {
System.out.println(name + "------------>" + num++);
TimeUnit.SECONDS.sleep(3);
}
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}*/
}
public static void main(String[] args) {
ExecutorService executor = Executors.newCachedThreadPool();
for (int i = 0; i < 3; i++) {
ThreadDemo task01 = new ThreadDemo("线程"+i);
executor.execute(task01);
}
executor.shutdown();
}
}
4.2 使用ThreadLocal方式
/**
* TODO 定义 THREAD_LOCAL
*
* @author liuzebiao
* @Date 2020-5-8 11:13
*/
public class THREAD_LOCAL {
// 定义threadLocal
public static ThreadLocal<Integer> threadLocal = new ThreadLocal<Integer>();
/**
* set()方法
*/
public static void set(Integer value){
threadLocal.set(value);
}
/**
* get()方法
*/
public static Integer get() {
return threadLocal.get();
}
/**
* remove()方法
*/
public static void remove(){
threadLocal.remove();
}
}
/**
* TODO 多线程使用ThreadLocal示例
*
* @author liuzebiao
* @Date 2020-5-7 14:30
*/
public class ThreadLocalDemo01 {
//直接在类中定义ThreadLocal,也是OK的
//private static final ThreadLocal<Integer> threadLocal = new ThreadLocal<>();
/**
* 线程方法
*/
static class ThreadDemo extends Thread{
private Integer num;
public ThreadDemo(Integer num) {
this.num = num;
}
@Override
public void run() {
//建议及时remove()
//尤其是使用线程池,一定要remove(),否则会由于线程重用而导致的数据不同的异常
THREAD_LOCAL.remove();
System.out.println(Thread.currentThread().getName()+ ":默认值为:"+THREAD_LOCAL.get());
THREAD_LOCAL.set(num);
System.out.println(Thread.currentThread().getName()+ ":设置后,值为:"+THREAD_LOCAL.get());
}
}
public static void main(String[] args) throws InterruptedException {
//使用new Thread().start() 的方式启动线程
//线程1
new Thread(new ThreadDemo(100)).start();
//休眠5s
TimeUnit.SECONDS.sleep(5);
//线程2
new Thread(new ThreadDemo(200)).start();
}
}
5.谈谈ThreadLocal的设计与不足
一个内存泄漏的例子:
1public class MemoryLeak {
2
3 public static void main(String[] args) {
4 new Thread(new Runnable() {
5 @Override
6 public void run() {
7 for (int i = 0; i < 1000; i++) {
8 TestClass t = new TestClass(i);
9 t.printId();
10 t = null;
11 }
12 }
13 }).start();
14 }
15
16 static class TestClass{
17 private int id;
18 private int[] arr;
19 private ThreadLocal<TestClass> threadLocal;
20 TestClass(int id){
21 this.id = id;
22 arr = new int[1000000];
23 threadLocal = new ThreadLocal<>();
24 threadLocal.set(this);
25 }
26
27 public void printId(){
28 System.out.println(threadLocal.get().id);
29 }
30 }
31}
运行结果:
10
21
32
43
5...省略...
6440
7441
8442
9443
10444
11Exception in thread "Thread-0" java.lang.OutOfMemoryError: Java heap space
12 at com.gentlemanqc.MemoryLeak$TestClass.<init>(MemoryLeak.java:33)
13 at com.gentlemanqc.MemoryLeak$1.run(MemoryLeak.java:16)
14 at java.lang.Thread.run(Thread.java:745)
对上述代码稍作修改,请看:
1public class MemoryLeak {
2
3 public static void main(String[] args) {
4 new Thread(new Runnable() {
5 @Override
6 public void run() {
7 for (int i = 0; i < 1000; i++) {
8 TestClass t = new TestClass(i);
9 t.printId();
10 t.threadLocal.remove();
11 }
12 }
13 }).start();
14 }
15
16 static class TestClass{
17 private int id;
18 private int[] arr;
19 private ThreadLocal<TestClass> threadLocal;
20 TestClass(int id){
21 this.id = id;
22 arr = new int[1000000];
23 threadLocal = new ThreadLocal<>();
24 threadLocal.set(this);
25 }
26
27 public void printId(){
28 System.out.println(threadLocal.get().id);
29 }
30 }
31}
运行结果:
10
21
32
43
5...省略...
6996
7997
8998
9999
一个内存泄漏,一个正常完成,对比代码只有一处不同:t = null改为了t.threadLocal.remove();哇,究竟是什么原因呢?神奇的remove!!!
内存溢出问题解答
至此,该做的铺垫都已经完成了,此时,我们可以来看看上面那个内存泄漏的例子。示例中执行一次for循环里的代码后,对应的内存状态:
调用t=null后,虽然无法再通过t访问内存地址,但是当前线程依旧存活
,可以通过thread指向的内存地址,访问到Thread对象,从而访问到ThreadLocalMap对象,访问到value指向的内存空间,访问到arr指向的内存空间,从而导致Java垃圾回收并不会回收int[1000000]@541这一片空间。那么随着循环多次之后,不被回收的堆空间越来越大,最后抛出java.lang.OutOfMemoryError: Java heap space。 您问:那为什么调用t.threadLocal.remove()就可以呢? 我答:这就得看remove方法里究竟做了什么了,请看:
是不是恍然大悟?来看下调用remove方法之后的内存状态:
因为remove方法将referent和value都被设置为null,所以ThreadLocal@540和Memory$TestClass@538对应的内存地址都变成不可达,Java垃圾回收自然就会回收这片内存,从而不会出现内存泄漏的错误。
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200526204826669.gif#pic_center#pic_center)
参考资料:
https://blog.csdn.net/lzb348110175/article/details/105970725
http://www.iocoder.cn/JDK/ThreadLocal/
欢迎关注公众号Java技术大本营,会不定期分享BAT面试资料等福利。