ThreadLocal用法与实现原理
1. 对ThreadLocal的理解
Java API中对ThreadLocal的描述是:该类提供了线程局部 (thread-local) 变量。这些变量不同于它们的普通对应物,因为访问某个变量(通过其 get 或 set 方法)的每个线程都有自己的局部变量,它独立于变量的初始化副本。ThreadLocal 实例通常是类中的 private static 字段,它们希望将状态与某一个线程(例如,用户 ID 或事务 ID)相关联。
我们先来看一个例子:
package com.threadlocal.demo;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
/**
* 数据库操作辅助类
*
* @author 小明
*
*/
public class DbSession {
private static String DRIVER = "com.mysql.jdbc.Driver"; // 驱动字符串
private static String URL = "jdbc:mysql:///test"; // 连接字符串
private static String USER = "root"; // 用户名
private static String PASSWORD = "123456"; // 密码
private static Connection connection; // 连接对象
/**
* 加载驱动
*/
static {
try {
Class.forName(DRIVER);
} catch (ClassNotFoundException e) {
e.printStackTrace();
}
}
/**
* 打开数据库连接
*
* @return 连接对象
*/
public static Connection openConnection() {
if (connection == null) {
try {
connection = DriverManager.getConnection(URL, USER, PASSWORD);
} catch (SQLException e) {
e.printStackTrace();
}
}
return connection;
}
/**
* 关闭连接资源
*/
public static void closeConnection() {
if (connection != null) {
try {
connection.close();
} catch (SQLException e) {
e.printStackTrace();
}
}
}
}
这是一个数据库连接管理类,在单线程中使用是没有任何问题的,但是在多线程中使用就会出现线程安全问题:这里打开和关闭连接的2个方法都没有进行同步,很可能在openConnection()方法中会多次创建connection对象。由于connection是共享变量,那么就有必要在调用connection的地方使用同步来保障线程安全,因为很可能一个线程在使用connection进行数据库操作(如读数据),而另外一个线程调用closeConnection()关闭链接。这时,我们可以使得线程同步来解决这个问题。
那么我们要实现线程同步,使用同步方法或是同步块的作法是否可行呢?我们想像一下,在线程同步时,一个线程在使用connection进行数据库操作的时候,其他线程只有等待,这就将大大影响程序执行效率。
我们实现线程同步,其主要目的是保障共享资源的线程安全,那么这儿的connection连接资源是否真的需要共享呢?事实上,这是不需要的。假如每个线程中都有一个connection变量,各个线程之间对connection变量的访问实际上是没有依赖关系的,是相互独立的,即一个线程不需要关心其他线程是否对这个connection进行了修改。
既然是这样,那我们就不使用静态成员吧,将所有变量与方法的static都去掉。这样做又会不会有什么问题呢?
这时,如果我们要使用连接对象来操作数据,就得先创建DbSession对象,然后通过对象来调用相应的建立连接、关闭连接的方法。这就又出现了一个问题:服务器压力增大,并且严重影响程序执行性能。由于在方法中需要频繁的打开和关闭数据库连接,这样不仅严重影响程序执行效率,还可能导致服务器压力巨大,这就好比我们要过一条河,每次过河前先收集原材料在河上建一座桥,过河之后又将桥拆掉,那么在建桥和拆桥的时候,所消耗的资源是非常多的。
那我们到底使用什么方法能够使得这个问题比较完美的解决呢,这就是我们的ThreadLocal类。
2. 深入理解ThreadLocal
从Java API来看,ThreadLocal主要表达了下面几种观点:
- ThreadLocal不是线程,是线程的一个局部变量,可以先简单理解为线程类的属性。
- 每个线程有自己的一个ThreadLocal,它是变量的一个副本(也称拷贝),所以修改它不影响其他线程。
- ThreadLocal在类中通常定义为静态类变量。
ThreadLocal在每个线程中对该变量会创建一个副本,即每个线程内部都会有一个该变量,且在线程内部任何地方都可以使用,线程之间互不影响,这样一来就不存在线程安全问题,也不会严重影响程序执行性能。
需要注意的是,虽然ThreadLocal能够解决上面第1节所说的问题,但是由于在每个线程中都创建了副本,所以要考虑它对资源的消耗,比如内存的占用会比不使用ThreadLocal要大,这是典型的“以空间换时间”的设计方式,而我们以前用到的synchronized同步是“以时间换空间”的设计方式。
2.1 常用方法介绍
ThreadLocal方法介绍:
返回类型 | 方法 | 说明 |
---|---|---|
T | get() | 返回此线程局部变量的当前线程副本中的值。 |
protected T | initialValue() | 返回此线程局部变量的当前线程的“初始值”,一般是用来在使用时进行重写的,它是一个延迟加载方法。 |
void | remove() | 移除此线程局部变量当前线程的值。 |
void | set(T value) | 将此线程局部变量的当前线程副本中的值设置为指定值。 |
改进后的数据库操作辅助类:
package com.threadlocal.demo;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
/**
* 数据库操作辅助类
*
* @author 小明
*
*/
public class DbSession {
private static String DRIVER = "com.mysql.jdbc.Driver"; // 驱动字符串
private static String URL = "jdbc:mysql:///test"; // 连接字符串
private static String USER = "root"; // 用户名
private static String PASSWORD = "123456"; // 密码
private static ThreadLocal<Connection> threadLocal = new ThreadLocal<Connection>(); // ThreadLocal对象
/**
* 加载驱动
*/
static {
try {
Class.forName(DRIVER);
} catch (ClassNotFoundException e) {
e.printStackTrace();
}
}
/**
* 打开数据库连接
*
* @return 连接对象
*/
public static Connection openConnection() {
Connection connection = threadLocal.get(); // 获取ThreadLocal中保存的连接对象
if (connection == null) {
try {
connection = DriverManager.getConnection(URL, USER, PASSWORD); // 创建连接对象
threadLocal.set(connection); // 将连接对象保存到ThreadLocal对象中
} catch (SQLException e) {
e.printStackTrace();
}
}
return connection;
}
/**
* 关闭连接资源
*/
public static void closeConnection() {
Connection connection = threadLocal.get(); // 获取ThreadLocal中保存的连接对象
if (connection != null) { // 不为空则释放资源
try {
threadLocal.set(null); // 将ThreadLocal中的连接对象置空
connection.close(); // 关闭连接对象
} catch (SQLException e) {
e.printStackTrace();
}
}
}
}
2.2 深入源码(JDK1.6)
我们先看一下get()方法:
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null)
return (T)e.value;
}
return setInitialValue();
}
先取得当前线程,然后通过getMap(t)方法获取到一个map,map的类型为ThreadLocalMap。map不为空,则获取到key-value键值对,如果获取成功,则返回value值。如果map为空,则调用setInitialValue()方法返回value。
再进一层,我们看看getMap()作了什么:
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
在getMap中,返回当前线程t中的一个成员变量threadLocals,threadLocals又是什么?
ThreadLocal.ThreadLocalMap threadLocals = null;
它实际上是一个ThreadLocalMap,是ThreadLocal类的一个静态内部类,我们继续取部分ThreadLocalMap的实现:
static class ThreadLocalMap {
private Entry[] table;
static class Entry extends WeakReference<ThreadLocal> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal k, Object v) {
super(k);
value = v;
}
}
private Entry getEntry(ThreadLocal key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}
private void set(ThreadLocal key, Object value) {
// We don't use a fast path as with get() because it is at
// least as common to use set() to create new entries as
// it is to replace existing ones, in which case, a fast
// path would fail more often than not.
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal k = e.get();
if (k == key) {
e.value = value;
return;
}
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
}
我们看到ThreadLocalMap的Entry继承了WeakReference,使用ThreadLocal作为键。通过getEntry()方法获取到key-value键值对。set()方法将key-value键值对映射保存到table数组中,key存在,则替换value,key不存在,则保存新的映射。
再继续看setInitialValue()方法:
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}
initialValue()方法中返回null值,接下来仍然会去获取ThreadLocalMap对象,不为空,则设置键值对,若为空,再创建ThreadLocalMap对象:
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
接着再来看一下ThreadLocal的set()方法:
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
set()方法相对比较简单:获取当前线程的引用,获取该线程对应的map,如果map存在更新缓存值,否则创建并存储。
至此,可能部分读者已经明白了ThreadLocal是如何为每个线程创建变量的副本的:
在每个线程Thread内部有一个ThreadLocal.ThreadLocalMap类型的成员变量threadLocals,这个threadLocals就是用来存储实际的变量副本的,键(key)为当前ThreadLocal变量,值(value)为变量副本(即T类型的变量)。
初始时,在Thread里面,threadLocals为空,当通过ThreadLocal调用get()方法或者set()方法,就会对Thread类中的threadLocals进行初始化,并且以当前ThreadLocal变量为key,以ThreadLocal要保存的副本变量为value,存到threadLocals中。
然后在当前线程里面,如果要使用副本变量,就可以通过get()方法在threadLocals里面查找。
最常见的ThreadLocal使用场景是用来解决数据库连接、Session管理等问题。
2.3 示例
下面我们来看一个示例,重现一个关于日期解析的问题,在重现这个问题之前还是要先来看一下SimpleDateFormat类中的parse()方法:
public Date parse(String text, ParsePosition pos) {
// …… // 处理
calendar.clear(); // 清空所有时间字段值
// …… // 处理
parsedDate = calendar.getTime(); // 获取时间
// …… // 处理
}
这儿我只抽取出来这两条语句,其中calendar是一个Calendar对象引用,它用来储存和这个SimpleDateFormat相关的日期信息。如果SimpleDateFormat是static的,那么多个Thread之间就会共享这个SimpleDateFormat,同时也就共享这个Calendar的引用。
问题重现:
package com.threadlocal.demo;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Date;
/**
* 日期解析示例
*
* @author 小明
*
*/
public class DateParseDemo {
public static void main(String[] args) {
// 启动线程1,解析"2015-10-1 00:00:00",休眠2秒钟
new DateParseThread("线程1", "2015-10-1 00:00:00", true, 2000).start();
// 启动线程2,解析"2012-3-8 15:37:22",不休眠
new DateParseThread("线程2", "2012-3-8 15:37:22", false, 0).start();
}
}
/**
* 日期解析线程类
*
* @author 小明
*
*/
class DateParseThread extends Thread {
private String name; // 线程名称
private String dateString; // 要解析的日期字符串
private boolean isSleep; // 是否休眠
private long sleepTime; // 休眠时长
public DateParseThread(String name, String dateString, boolean isSleep,
long sleepTime) {
super();
this.name = name;
this.dateString = dateString;
this.isSleep = isSleep;
this.sleepTime = sleepTime;
}
@Override
public void run() {
if (isSleep) {
try {
Thread.sleep(sleepTime);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
Date date = DateParseUtil.parse(dateString);
System.out.println("线程:" + this.name + ",日期时间为:" + date);
}
}
/**
* 日期解析辅助类
*
* @author 小明
*
*/
class DateParseUtil {
private static final String PATTERN = "yyyy-MM-dd HH:mm:ss"; // 格式
private static SimpleDateFormat sdf; // SimpleDateFormat对象
static {
sdf = new SimpleDateFormat(PATTERN); // 创建基于给定模式的SimpleDateFormat对象
}
/**
* 日期解析
*
* @param dateString
* 待解析字符串
* @return 日期
*/
public static Date parse(String dateString) {
try {
return sdf.parse(dateString);
} catch (ParseException e) {
e.printStackTrace();
}
return null;
}
}
我们在执行前先在 calendar.clear()
和 calendar.getTime()
之间打个断点,然后使用Debug模式来执行这段代码。
线程1启动之后进入sleep(休眠)状态,线程2启动起来后卡在断点处(暂停执行),这时calendar的日期时间是:”2012-3-8 15:37:22”。当线程1从休眠中醒来后接着向下执行,当执行到断点处时,因为calendar是被共享的资源,所以它的日期时间又变为:”2015-10-1 00:00:00”。最后让两个线程断点继续执行,结果可想而知了:
线程:线程1,日期时间为:Thu Oct 01 00:00:00 CST 2015
线程:线程2,日期时间为:Thu Oct 01 00:00:00 CST 2015
在实际业务中,我们不会是在Debug模式下运行,但如果线程1调用了sdf.parse(),并且进行了calendar.clear()后还未执行calendar.getTime()的时候,线程2又调用了sdf.parse(),这时候线程2也执行了sdf.clear()方法,这样就导致线程1的calendar数据被清空了(实际上线程1,2同时被清空了);又或者当线程1执行了calendar.clear()后被挂起,这时候线程2开始调用sdf.parse()并顺利结束,这样线程1的calendar内存储的日期时间就变成了后来线程2设置的calendar的日期时间值。
那么我们怎么解决这种问题呢,最简单地就是将静态的SimpleDateFormat改为实例SimpleDateFormat,这样每个线程都会有一个自己的SimpleDateFormat实例。但使用这种方法,在高并发的情况下会大量的创建SimpleDateFormat对象以及销毁SimpleDateFormat对象,这样是非常耗费资源的。
我们就可以使用ThreadLocal来优化,将DateParseUtil代码修改如下:
/**
* 日期解析辅助类
*
* @author 小明
*
*/
class DateParseUtil {
private static final String PATTERN = "yyyy-MM-dd HH:mm:ss"; // 格式
private static ThreadLocal<SimpleDateFormat> threadLocal = new ThreadLocal<SimpleDateFormat>();
/**
* 获取SimpleDateFormat对象
*
* @return SimpleDateFormat对象
*/
public static SimpleDateFormat getSimpleDateFormat() {
SimpleDateFormat simpleDateFormat = threadLocal.get();
if (simpleDateFormat == null) {
simpleDateFormat = new SimpleDateFormat(PATTERN);
threadLocal.set(simpleDateFormat);
}
return simpleDateFormat;
}
/**
* 日期解析
*
* @param dateString
* 待解析字符串
* @return 日期
*/
public static Date parse(String dateString) {
try {
return getSimpleDateFormat().parse(dateString);
} catch (ParseException e) {
e.printStackTrace();
}
return null;
}
}
重新在Debug模式下运行,结果为:
线程:线程2,日期时间为:Thu Mar 08 15:37:22 CST 2012
线程:线程1,日期时间为:Thu Oct 01 00:00:00 CST 2015
这样,在两个线程中解析的时间就正确了。
当然,DateParseUtil类也可以修改为:
/**
* 日期解析辅助类
*
* @author 小明
*
*/
class DateParseUtil {
private static final String PATTERN = "yyyy-MM-dd HH:mm:ss"; // 格式
private static ThreadLocal<SimpleDateFormat> threadLocal = new ThreadLocal<SimpleDateFormat>() {
protected SimpleDateFormat initialValue() {
return new SimpleDateFormat(PATTERN);
};
};
/**
* 获取SimpleDateFormat对象
*
* @return SimpleDateFormat对象
*/
public static SimpleDateFormat getSimpleDateFormat() {
return threadLocal.get();
}
/**
* 日期解析
*
* @param dateString
* 待解析字符串
* @return 日期
*/
public static Date parse(String dateString) {
try {
return getSimpleDateFormat().parse(dateString);
} catch (ParseException e) {
e.printStackTrace();
}
return null;
}
}
这里创建一个ThreadLocal类变量,创建时用了一个匿名类,覆盖了initialValue()方法,主要作用是创建时初始化实例。