ThreadLocal 内部实现、应用场景和内存泄漏

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u012834750/article/details/71646700

一、什么是ThreadLocal

首先明确一个概念,那就是ThreadLocal并不是用来并发控制访问一个共同对象,而是为了给每个线程分配一个只属于该线程的变量,顾名思义它是local variable(线程局部变量)。它的功用非常简单,就是为每一个使用该变量的线程都提供一个变量值的副本,是每一个线程都可以独立地改变自己的副本,而不会和其它线程的副本冲突,实现线程间的数据隔离。从线程的角度看,就好像每一个线程都完全拥有该变量。

set和get方法是ThreadLocal类中最常用的两个方法。,接下来 我们来看下ThreadLocal的内部实现:

set方法实现源码如下:

public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }


ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

//Thread类里默认threadLocals为null
class Thread implements Runnable{
    ThreadLocal.ThreadLocalMap threadLocals = null;
}

static class ThreadLocalMap {

        static class Entry extends WeakReference<ThreadLocal<?>> {

            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
  }


void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

Thread.currentThread得到当前线程,如果当前线程存在threadLocals这个变量不为空,那么根据当前的ThreadLocal实例作为key寻找在map中位置,然后用新的value值来替换旧值。

在ThreadLocal这个类中比较引人注目的应该是ThreadLocal->ThreadLocalMap->Entry这个类。这个类继承自WeakReference。

get方法实现源码如下:

public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

首先我们通过Thread.currentThread得到当前线程,然后获取当前线程的threadLocals变量,这个变量就是ThreadLocalMap类型的,如果这个变量map不为空,再获取ThreadLocalMap.Entry e,如果e不为空,则获取value值返回,否则在Map中初始化Entry,并返回初始值null。如果map为空,则创建并初始化map,并返回初始值null。


二、ThreadLocal应用场景

1、数据库连接池实现

jdbc连接数据库,如下所示:

Class.forName("com.mysql.jdbc.Driver");
java.sql.Connection conn = DriverManager.getConnection(jdbcUrl);

注意:一次Drivermanager.getConnection(jdbcurl)获得只是一个connection,并不能满足高并发情况。因为connection不是线程安全的,一个connection对应的是一个事物。

每次获得connection都需要浪费cpu资源和内存资源,是很浪费资源的。所以诞生了数据库连接池。数据库连接池实现原理如下:

pool.getConnection(),都是先从threadlocal里面拿的,如果threadlocal里面有,则用,保证线程里的多个dao操作,用的是同一个connection,以保证事务。如果新线程,则将新的connection放在threadlocal里,再get给到线程。

将connection放进threadlocal里的,以保证每个线程从连接池中获得的都是线程自己的connection。

Hibernate的数据库连接池源码实现:

 public class ConnectionPool implements IConnectionPool {  
    // 连接池配置属性  
    private DBbean dbBean;  
    private boolean isActive = false; // 连接池活动状态  
    private int contActive = 0;// 记录创建的总的连接数  

    // 空闲连接  
    private List<Connection> freeConnection = new Vector<Connection>();  
    // 活动连接  
    private List<Connection> activeConnection = new Vector<Connection>();  

 // 将线程和连接绑定,保证事务能统一执行
    private static ThreadLocal<Connection> threadLocal = new ThreadLocal<Connection>(); 

public ConnectionPool(DBbean dbBean) {  
        super();  
        this.dbBean = dbBean;  
        init();  
        cheackPool();  
    }  

    // 初始化  
    public void init() {  
        try {  
            Class.forName(dbBean.getDriverName());  
            for (int i = 0; i < dbBean.getInitConnections(); i++) {  
                Connection conn;  
                conn = newConnection();  
                // 初始化最小连接数  
                if (conn != null) {  
                    freeConnection.add(conn);  
                    contActive++;  
                }  
            }  
            isActive = true;  
        } catch (ClassNotFoundException e) {  
            e.printStackTrace();  
        } catch (SQLException e) {  
            e.printStackTrace();  
        }  
    }  

    // 获得当前连接  
    public Connection getCurrentConnecton(){  
        // 默认线程里面取  
        Connection conn = threadLocal.get();  
        if(!isValid(conn)){  
            conn = getConnection();  
        }  
        return conn;  
    }  

    // 获得连接  
    public synchronized Connection getConnection() {  
        Connection conn = null;  
        try {  
            // 判断是否超过最大连接数限制  
            if(contActive < this.dbBean.getMaxActiveConnections()){  
                if (freeConnection.size() > 0) {  
                    conn = freeConnection.get(0);  
                    if (conn != null) {  
                        threadLocal.set(conn);  
                    }  
                    freeConnection.remove(0);  
                } else {  
                    conn = newConnection();  
                }  

            }else{  
                // 继续获得连接,直到从新获得连接  
                wait(this.dbBean.getConnTimeOut());  
                conn = getConnection();  
            }  
            if (isValid(conn)) {  
                activeConnection.add(conn);  
                contActive ++;  
            }  
        } catch (SQLException e) {  
            e.printStackTrace();  
        } catch (ClassNotFoundException e) {  
            e.printStackTrace();  
        } catch (InterruptedException e) {  
            e.printStackTrace();  
        }  
        return conn;  
    }  

    // 获得新连接  
    private synchronized Connection newConnection()  
            throws ClassNotFoundException, SQLException {  
        Connection conn = null;  
        if (dbBean != null) {  
            Class.forName(dbBean.getDriverName());  
            conn = DriverManager.getConnection(dbBean.getUrl(),  
                    dbBean.getUserName(), dbBean.getPassword());  
        }  
        return conn;  
    }  

    // 释放连接  
    public synchronized void releaseConn(Connection conn) throws SQLException {  
        if (isValid(conn)&& !(freeConnection.size() > dbBean.getMaxConnections())) {  
            freeConnection.add(conn);  
            activeConnection.remove(conn);  
            contActive --;  
            threadLocal.remove();  
            // 唤醒所有正待等待的线程,去抢连接  
            notifyAll();  
        }  
    }  

    // 判断连接是否可用  
    private boolean isValid(Connection conn) {  
        try {  
            if (conn == null || conn.isClosed()) {  
                return false;  
            }  
        } catch (SQLException e) {  
            e.printStackTrace();  
        }  
        return true;  
    }  

    // 销毁连接池  
    public synchronized void destroy() {  
        for (Connection conn : freeConnection) {  
            try {  
                if (isValid(conn)) {  
                    conn.close();  
                }  
            } catch (SQLException e) {  
                e.printStackTrace();  
            }  
        }  
        for (Connection conn : activeConnection) {  
            try {  
                if (isValid(conn)) {  
                    conn.close();  
                }  
            } catch (SQLException e) {  
                e.printStackTrace();  
            }  
        }  
        isActive = false;  
        contActive = 0;  
    }  

    // 连接池状态  
    @Override  
    public boolean isActive() {  
        return isActive;  
    }  

    // 定时检查连接池情况  
    @Override  
    public void cheackPool() {  
        if(dbBean.isCheakPool()){  
            new Timer().schedule(new TimerTask() {  
            @Override  
            public void run() {  
            // 1.对线程里面的连接状态  
            // 2.连接池最小 最大连接数  
            // 3.其他状态进行检查,因为这里还需要写几个线程管理的类,暂时就不添加了  
            System.out.println("空线池连接数:"+freeConnection.size());  
            System.out.println("活动连接数::"+activeConnection.size());  
            System.out.println("总的连接数:"+contActive);  
                }  
            },dbBean.getLazyCheck(),dbBean.getPeriodCheck());  
        }  
    }  
}  

2、有时候ThreadLocal也可以用来避免一些参数传递,通过ThreadLocal来访问对象。

比如一个方法调用另一个方法时传入了8个参数,通过逐层调用到第N个方法,传入了其中一个参数,此时最后一个方法需要增加一个参数,第一个方法变成9个参数是自然的,但是这个时候,相关的方法都会受到牵连,使得代码变得臃肿不堪。这时候就可以将要添加的参数设置成线程本地变量,来避免参数传递。

上面提到的是ThreadLocal一种亡羊补牢的用途,不过也不是特别推荐使用的方式,它还有一些类似的方式用来使用,就是在框架级别有很多动态调用,调用过程中需要满足一些协议,虽然协议我们会尽量的通用,而很多扩展的参数在定义协议时是不容易考虑完全的以及版本也是随时在升级的,但是在框架扩展时也需要满足接口的通用性和向下兼容,而一些扩展的内容我们就需要ThreadLocal来做方便简单的支持。

简单来说,ThreadLocal是将一些复杂的系统扩展变成了简单定义,使得相关参数牵连的部分变得非常容易。

3、在某些情况下提升性能和安全。

用SimpleDateFormat这个对象,进行日期格式化。因为创建这个对象本身很费时的,而且我们也知道SimpleDateFormat本身不是线程安全的,也不能缓存一个共享的SimpleDateFormat实例,为此我们想到使用ThreadLocal来给每个线程缓存一个SimpleDateFormat实例,提高性能。同时因为每个Servlet会用到不同pattern的时间格式化类,所以我们对应每一种pattern生成了一个ThreadLocal实例。

public interface DateTimeFormat {
        String DATE_PATTERN = "yyyy-MM-dd";
        ThreadLocal<DateFormat> DATE_FORMAT = ThreadLocal.withInitial(() -> {
            return new SimpleDateFormat("yyyy-MM-dd");
        });
        String TIME_PATTERN = "HH:mm:ss";
        ThreadLocal<DateFormat> TIME_FORMAT = ThreadLocal.withInitial(() -> {
            return new SimpleDateFormat("HH:mm:ss");
        });
        String DATETIME_PATTERN = "yyyy-MM-dd HH:mm:ss";
        ThreadLocal<DateFormat> DATE_TIME_FORMAT = ThreadLocal.withInitial(() -> {
            return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        });
    }

为什么SimpleDateFormat不安全,可以参考此篇博文:

SimpleDateFormat线程不安全及解决办法

假如我们把SimpleDateFormat定义成static成员变量,那么多个thread之间会共享这个sdf对象, 所以Calendar对象也会共享。
假定线程A和线程B都进入了parse(text, pos) 方法, 线程B执行到calendar.clear()后,线程A执行到calendar.getTime(), 那么就会有问题。

如果不用static修饰,将SimpleDateFormat定义成局部变量:
每调用一次方法就会创建一个SimpleDateFormat对象,方法结束又要作为垃圾回收。加锁性能较差,每次都要等待锁释放后其他线程才能进入。那么最好的办法就是:使用ThreadLocal: 每个线程都将拥有自己的SimpleDateFormat对象副本。

附-SimpleDateFormat关键源码:

public class SimpleDateFormat extends DateFormat {  

    public Date parse(String text, ParsePosition pos){  
        calendar.clear(); // Clears all the time fields  
        // other logic ...  
        Date parsedDate = calendar.getTime();  
    }  
}  

abstract class DateFormat{  
    // other logic ...  
    protected Calendar calendar;  
    public Date parse(String source) throws ParseException{  
        ParsePosition pos = new ParsePosition(0);  
        Date result = parse(source, pos);  
        if (pos.index == 0)  
            throw new ParseException("Unparseable date: \"" + source + "\"" ,  
                pos.errorIndex);  
        return result;  
    }  
}  

三、内存泄漏问题

在上面提到过,每个thread中都存在一个map, map的类型是ThreadLocal.ThreadLocalMap. Map中的key为一个threadlocal实例. 这个Map的确使用了弱引用,不过弱引用只是针对key. 每个key都弱引用指向threadlocal. 当把threadlocal实例置为null以后,没有任何强引用指向threadlocal实例,所以threadlocal将会被gc回收. 但是,我们的value却不能回收,因为存在一条从current thread连接过来的强引用. 只有当前thread结束以后, current thread就不会存在栈中,强引用断开, Current Thread, Map, value将全部被GC回收。

所以得出一个结论就是只要这个线程对象被gc回收,就不会出现内存泄露,但在threadLocal设为null和线程结束这段时间不会被回收的,就发生了我们认为的内存泄露。其实这是一个对概念理解的不一致,也没什么好争论的。最要命的是线程对象不被回收的情况,这就发生了真正意义上的内存泄露。比如使用线程池的时候,线程结束是不会销毁的,会再次使用的。就可能出现内存泄露。

阅读更多
换一批

没有更多推荐了,返回首页