ThreadLocal源码剖析

引言

ThreadLocal应该算是一个面试中问的比较多的类,这个类主要用途在于只使用一个对象的引用,便能够获取在多个线程中储存变量的副本,下面,便来详细剖析下这个类的实现原理,及其中需要注意到的一些地方。

对象的创建初始化

	private final int threadLocalHashCode = nextHashCode();

    private static AtomicInteger nextHashCode =
            new AtomicInteger();


    private static final int HASH_INCREMENT = 0x61c88647;


    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

从上面可以看到,threadLocalHashCode保存了当前对象的hashcode,而hashcode是由nextHashCode()方法获取的,每次在原始值上增加了一个HASH_INCREMENT,之所以使用这个值是因为这样做能减少散列表的冲突。那么这个hashcode在哪里使用呢,请看接下来的代码。

   public void set(T value) {
       Thread t = Thread.currentThread();
       ThreadLocalMap map = getMap(t);
       if (map != null)
           map.set(this, value);
       else
           createMap(t, value);
   }
   
   void createMap(Thread t, T firstValue) {
       t.threadLocals = new ThreadLocalMap(this, firstValue);
   }
   
   ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
           table = new Entry[INITIAL_CAPACITY];
           int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
           table[i] = new Entry(firstKey, firstValue);
           size = 1;
           setThreshold(INITIAL_CAPACITY);
   }
   
   ThreadLocalMap getMap(Thread t) {
       return t.threadLocals;
   }

源码解析

如图,在调用ThreadLocal里set()方法时候,如果当前Thread内部的threadLocals对象没有创建,则会new一个ThreadLocalMap对象,然后将当前ThreadLocal对象的内部的hashcode,对map的大小取余,然后放入map中。这样一看似乎很让人疑惑,既然一个线程中只有一个ThreadLocal的值,为什么非要用map进行存储呢,答案是,在一个线程中,可以new多个ThreadLocal对象,用它们保存不同的值,既然如此,为了快速找到某个ThreadLocal对象内存储的值,自然就使用散列表降低时间复杂度了。
接下来来看看内部类ThreadLocalMap的结构

	static class Entry extends WeakReference<ThreadLocal<?>> {
           /** The value associated with this ThreadLocal. */
           Object value;

           Entry(ThreadLocal<?> k, Object v) {
               super(k);
               value = v;
           }
       }

       private static final int INITIAL_CAPACITY = 16;

       private Entry[] table;

       private int size = 0;

       private int threshold; // Default to 0

       private void setThreshold(int len) {
           threshold = len * 2 / 3;
       }

       private static int nextIndex(int i, int len) {
           return ((i + 1 < len) ? i + 1 : 0);
       }

       private static int prevIndex(int i, int len) {
           return ((i - 1 >= 0) ? i - 1 : len - 1);
       }

如图,此类继承了WeakReference类,也就是弱引用,它的特点是在JVM进行GC时,会专门有一条守护线程对这个类及其子类进行回收,对于弱引用的对象,每次GC当其他地方没有对它的引用时,无论内存空间足够充足,都会被进行回收。这里把ThreadLocal对象作为Entity对象的key,而set的值作为value,key弱引用对象,则进行垃圾回收之时,就会出现Entity对象存在,而内部key为null的结果,做出这种设计的主要原因还是为了帮助gc。
而从内部方法可以看出,保存value的是一个数组,且在数组内进行下标遍历之时采用的一个循环操作。其中threshold为下次扩容时的大小,即为数组size * factor,默认初始化为16,扩容因子为3分之2.
接下来看看private的set方法,看看具体如何实现的。

	private void set(ThreadLocal<?> key, Object value) {
         
         Entry[] tab = table;
         int len = tab.length;
         // 当前threadLocal在数组中应该处在的位置
         int i = key.threadLocalHashCode & (len-1);

         // 当前元素不为null时候进行遍历
         for (Entry e = tab[i];
              e != null;
              e = tab[i = nextIndex(i, len)]) {
             ThreadLocal<?> k = e.get();

             // 找到了直接设置当前值,返回
             if (k == key) {
                 e.value = value;
                 return;
             }

             // 若key为null,说明当前线程内有其他的ThreadValue对象的值被垃圾回收了,已无外部其他途径能访问到该value
             // 对key及value进行替换
             if (k == null) {
                 replaceStaleEntry(key, value, i);
                 return;
             }
         }

         // 这个时候i为数组中为null的位置,放入
         tab[i] = new Entry(key, value);
         int sz = ++size;
         // 清理被垃圾回收了key的entity对象,若没有清理成功,且size超过了最小扩容大小
         // 进行扩容
         if (!cleanSomeSlots(i, sz) && sz >= threshold)
             rehash();
     }

如图,本方法主要是以hash值进行散列表内的定位,若是发生了hash冲突,但冲突的ThreadLocal对象的key已经被垃圾回收时,对entity对象的key及value进行替换,这种情况举个例子,在一个方法里声明了一个局部变量ThreadLocal,进行了值的设定,当方法执行完成后,ThreadLocal变量已经无法进行访问,此时唯有ThreadLocalMap里存有对它的弱引用,但此时内部value已经无法通过外界进行访问了,则gc时会被回收,防止内存泄露。若是已经设置过值了,则直接重新设为最新值就行。
若是跳出循环,则在寻找过程中未有GC对象,且没有set过值,当前i值为hash位置后第一个空槽,在这里放入对象。
之后又一步检测,设计到了cleanSomeSlots方法,这个方法主要是清除被回收了key的entity对象,返回值为是否被回收,若是没有回收,且当前大小达到了扩容容量,则进行重新hash扩容。下面来看看这两个方法的具体实现。

private boolean cleanSomeSlots(int i, int n) {
          boolean removed = false;
          Entry[] tab = table;
          int len = tab.length;
          // 进行lgn次的扫描,清理一些零散区段的key被gc回收的entity
          // 若扫描到了,则n又会被重置为数组长度,继续扫描下去,直到回收完成
          do {
              i = nextIndex(i, len);
              Entry e = tab[i];
              if (e != null && e.get() == null) {
                  n = len;
                  removed = true;
                  i = expungeStaleEntry(i);
              }
          } while ( (n >>>= 1) != 0);
          // 返回是否进行了entity的移除
          return removed;
      }
      
      private int expungeStaleEntry(int staleSlot) {
          Entry[] tab = table;
          int len = tab.length;

          // 把当前entity位置设为null帮助gc
          tab[staleSlot].value = null;
          tab[staleSlot] = null;
          size--;

          Entry e;
          int i;
          // 一直遍历到一个为null的槽
          for (i = nextIndex(staleSlot, len);
               (e = tab[i]) != null;
               i = nextIndex(i, len)) {
              // 遍历过程中,若遇见了其他key被回收的entity,设为null,腾出空位并帮助垃圾回收
              ThreadLocal<?> k = e.get();
              if (k == null) {
                  e.value = null;
                  tab[i] = null;
                  size--;
              } else {
                  // 若为存在key的实体,则计算当前hash值
                  int h = k.threadLocalHashCode & (len - 1);
                  // 如果该实体不是保存在其hash值应放的位置,则从应放的位置开始,寻找到下一个空位,避免get或set改值时,前面有key为null的entity被回收了,导致get获取到null值或者set对值进行重复设置
                  // 并将该实体挪到其中
                  if (h != i) {
                      tab[i] = null;
                      while (tab[h] != null)
                          h = nextIndex(h, len);
                      tab[h] = e;
                  }
              }
          }
          // 返回下标为i的null槽
          return i;
      }

cleanSomeSlots()方法对底层数组进行了至少lgn次的扫描来进行key为null的entity的回收,key为null情况之前也说过,是外部已经没有那个key,也就是当前线程内ThreadLocal对象的引用的时候,会在gc时进行回收,n为数组长度,之所以说是至少,是因为若是回收成功了,会重置循环次数。这里涉及到了一个expungeStaleEntry()方法,这个方法参数为当前key为null的entity下标,方法内部对于被回收的entity对象的槽进行置null处理,并对中途碰到的没有回收key的非空槽进行了重新hash处理,因为不重新hash,很可能会当有key为null的entity被回收了,然后槽被置为null,这个时候get,set方法会出现问题。打个比方,若当前ThreadLocal对象的hashcode定位在数组中下标为1,但是下标为1,2的位置都被占用了,则只能放到3中,后来下标1,2的槽都被进行了回收,置为null,这个时候调用get(),就会出现1为null的情况,便会返回null值,调用set(),则会出现槽1,3中各有一个值的情况,很明显是错的。
最后方法返回了一个为null的槽,然后获取其下一个位置再次循环进行清除。
接下来让我们看看rehash()方法。源码如下

	  private void rehash() {
          expungeStaleEntries();

          // Use lower threshold for doubling to avoid hysteresis
          if (size >= threshold - threshold / 4)
              resize();
      }

      private void resize() {
          Entry[] oldTab = table;
          int oldLen = oldTab.length;
          // 双倍扩容
          int newLen = oldLen * 2;
          Entry[] newTab = new Entry[newLen];
          int count = 0;

          for (int j = 0; j < oldLen; ++j) {
              Entry e = oldTab[j];
              if (e != null) {
                  // 保险起见,再次判定,进行key被GC的entity的清除
                  ThreadLocal<?> k = e.get();
                  if (k == null) {
                      e.value = null; // Help the GC
                  } else {
                      // 拉链法放置
                      int h = k.threadLocalHashCode & (newLen - 1);
                      while (newTab[h] != null)
                          h = nextIndex(h, newLen);
                      newTab[h] = e;
                      count++;
                  }
              }
          }

          setThreshold(newLen);
          size = count;
          table = newTab;
      }


      private void expungeStaleEntries() {
          Entry[] tab = table;
          int len = tab.length;
          // 循环整个数组,清除所有被GC的entity
          for (int j = 0; j < len; j++) {
              Entry e = tab[j];
              if (e != null && e.get() == null)
                  expungeStaleEntry(j);
          }
      }

首先在rehash()方法中调用了expungeStaleEntries()方法,点进去可以看到,这个方法作用很明显,循环这个数组,把所有的被回收了key的entity清除。然后调用了resize() 方法,这个方法对当前数组进行了双倍扩容,然后使用了setThreshold()方法,threshold值设置为下次扩容时的size大小,也就是数组大小乘上扩容因子。
在扩容中又保险起见进行了对被回收的entity置空的操作,并且使用拉链法,也就是冲突了则放入下一个null槽的方法,对原先整个数组进行了重新hash。
实际上因为map本身是附属于一个thread,对map的所有操作均不会存在线程安全问题,唯一可能有问题的ThreadLocal的hashcode也使用了CAS进行运算,因此我们完全可以安心使用它,无需担心线程安全问题。
set()方法的整个流程我们已经看完了,接下来看看get方法是如何实现了,代码如下。

	public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }
    
    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

在get()方法中,若是还没有设置值,则会调用setInitialValue()方法,若没初始化map,则会初始化ThreadLocalMap,这里面调用了initialValue()方法,这个方法本身是一个空方法,作用是我们可以重写这个方法,然后若不调用set方法,直接使用get就能够从threadLocal中拿出我们的初始值,当然以后调用了set方法也会将该值覆盖。这里面有一个getEntity()方法,让我们看下它是如何获取到当前线程的ThreadLocal对象保存的值。
代码如下

		private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }
        
        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }

其中从传入的当前线程中取出了ThreadLocalMap,然后计算当前ThreadLocal对象的hashCode,如果没有发生冲突,则直接返回,否则调用getEntryAfterMiss()方法进行遍历。
getEntryAfterMiss()中一直遍历数组,直到槽为null或找到停止,其中又调用expungeStaleEntry方法对遍历中的遇到的被回收了的Entity进行清除,已经冲突hashcode节点进行重新hash处理。这里大家会有一个问题,如果get的节点正好冲突了,然后被重新hash处理了,那么是不是找不到了。
实际上,我们当前找的节点本来就是被冲突的节点,并且扫过的节点都是实际存在,未被gc的节点,因此当发生gc后,当前循环所走的后面的节点必然有我们要找的节点。
最后看下remove()方法。

private void remove(ThreadLocal<?> key) {
          Entry[] tab = table;
          int len = tab.length;
          int i = key.threadLocalHashCode & (len-1);
          for (Entry e = tab[i];
               e != null;
               e = tab[i = nextIndex(i, len)]) {
              if (e.get() == key) {
                  e.clear();
                  expungeStaleEntry(i);
                  return;
              }
          }
      }

方法很简单,也就是一路找到需要remove的节点,遍历路上顺带进行被gc节点的清除而已。

总结

  1. 对于ThreadLocal来说,它的本质是依赖于当前线程内的一个内部对象ThreadLocalMap来进行值的保存,因此不同线程内有不同的map,并且没有任何交集,不会有线程安全问题。
  2. 在set,get,remove方法中都有涉及到被GC节点的清除工作,所以基本上不可能会内存泄露严重。
  3. 之所以使用map进行值的保存,是因为考虑到一个线程中可能会使用多可ThreadLocal实例进行值的保存。
  4. map使用的散列方法是在冲突后,寻找后面第一个为null的槽,进行放入。
  5. map的生命周期与thread相绑定,因此对于内部保存的所有节点的value及key,都具有引用,但是key是弱引用对象,若在外界已不再持有key,即ThreadLocal对象,则会在下次gc时被垃圾回收掉,但是由于entity及内部value还存在,所以需要在代码中做额外的处理。
  6. 除非是在方法的栈中创造大量的ThreadLocal局部变量进行值的设置,否则这些回收相关代码几乎毫无意义。(会有这么傻的人吗)
  7. ThreadLocal基本是一个可以安心使用的方法。(除非用的人脑子太清奇,本来即使在局部方法内创建ThreadLocal,调用了的get,set方法也会进行部分垃圾回收,所以就是有这么做的人,也顶天就泄露一点点内存,唯一问题在于快进行垃圾回收之前大量调用此局部方法,之后再也不使用ThreadLocal,则会造成比较严重的内存泄露)。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值