并发编程--ThreadLocal


上节我们简单使用过ThreadLocal,本节我们讨论ThreadLocal实现的原理,如何做到每个线程有一份属于自己的数据备份

ThreadLocal类图

在这里插入图片描述
从类图我们可以很清楚的看到 ThreadLocal 的内部结构 (JDK8_201)

  • 1个无参构造函数
  • 4个public函数 其中withInitial()为类级别的静态方法
  • 2个内部类 ThreadLocalMap 与 SuppliedThreadLocal

ThreadLocal

接下来我们重点分析 get() set() 与remove()方法

get()

public T get() {
    Thread t = Thread.currentThread(); //获取当前线程
    ThreadLocalMap map = getMap(t); // 取得当前线程的ThreadLocalMap实例,就是Thread类的属性 threadLocals
    if (map != null) {//map为空调用初始化方法
    	// 这里可以看到最后操作的就是ThreadLocalMap 实例
        ThreadLocalMap.Entry e = map.getEntry(this);//通过ThreadLocalMap获取存储的值,具体后续分析
        if (e != null) {//环形数组中没有找到对应数据也会调用初始化方法
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    //如果从ThreadLocalMap未找到 获取初始值
    return setInitialValue();
}

ThreadLocalMap getMap(Thread t) {
    //获取Thread类的属性 threadLocals 类型为ThreadLocal.ThreadLocalMap
    return t.threadLocals; //threadLocals中保存线程的所有的线程本地变量
}

private T setInitialValue() {
    T value = initialValue();//调用initialValue方法获取,此方法默认返回null,可重写
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);//设置初始值 操作的是ThreadLocalMap 
    else
        createMap(t, value);//初始化ThreadLocalMap
    return value; //返回初始值
}

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);//调用ThreadLocalMap构造函数 参考下方构造函数
}

set()

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value); //操作的是ThreadLocalMap
    else
        createMap(t, value);
}

remove()

public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);//操作的是ThreadLocalMap
     }

withInitial()

get() 方法我们可以看到,当ThreadLocal实例第一次调用get()时,如果ThreadLocalMap没有初始化或者环形数组中没有找到数据时,最终会调用ThreadLocal的initialValue()方法

所以我们可以重写initialValue提供初始值

ThreadLocal<Integer> threadLocal = new ThreadLocal(){
    @Override
    protected Object initialValue() {
        return 0;
    }
};

withInitial()方法就是这么做的

public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
    return new SuppliedThreadLocal<>(supplier);//返回一个SuppliedThreadLocal对象 SuppliedThreadLocal是一个内部类
}

//SuppliedThreadLocal对象内部继承ThreadLocal并重写了initialValue()方法
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

    private final Supplier<? extends T> supplier;

    SuppliedThreadLocal(Supplier<? extends T> supplier) {
        this.supplier = Objects.requireNonNull(supplier);//不能为Null
    }

    @Override
    protected T initialValue() {
    	/**
         * 调用supplier接口的get()方法
         * supplier属性通过构造函数传入不能为空
         * Supplier是一个FunctionalInterface 支持lambada表达式
         */
        return supplier.get(); 
    }
}

所以我们还可以这样提供初始值

ThreadLocal<Integer> threadLocal = ThreadLocal.withInitial(()-> 1);

ThreadLocalMap

从上面可以看到对ThreadLocal的操作都会作用到ThreadLocalMap上,我们具体分析ThreadLocalMap
ThreadLocalMap提供了一种为ThreadLocal定制的高效实现,并且自带一种基于弱引用的垃圾清理机制。

类图

在这里插入图片描述

从图中我们可以看到ThreadLocalMap中有一个Entry数组,数组里面存放的是Entry对象,
Entry对象可以放两个属性一个是ThreadLocal对象(弱引用),另一个就是要保存的线程本地变量
ThreadLocalMap可以简单地将它的key视作ThreadLocal,value为代码中放入的值(实际上key并不是ThreadLocal本身,而是它的一个弱引用)

ThreadLocalMap构造函数

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    table = new Entry[INITIAL_CAPACITY];//初始化一个长度为16的Entry数组
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);//定位元素的下标
    table[i] = new Entry(firstKey, firstValue);//把值放到数组对应位置
    size = 1;
    setThreshold(INITIAL_CAPACITY);//设置阀值为数组长度2/3,超过这个值就要扩容数组 再Hash
}

可以看到在第一次get或者set()的时候会初始化ThreadLocalMap

另外我们要注意的是ThreadLocalMap内部的Entry数组是环形数组

在这里插入图片描述

private static int nextIndex(int i, int len) {
    return ((i + 1 < len) ? i + 1 : 0); //环形数组的下一个索引
}

private static int prevIndex(int i, int len) {
    return ((i - 1 >= 0) ? i - 1 : len - 1); //环形数组的上一个索引
}

set()

private void set(ThreadLocal<?> key, Object value) {

    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);//定位元素下标

    for (Entry e = tab[i];
         e != null; //循环结束条件1 元素为空
         e = tab[i = nextIndex(i, len)]) { //从位置i一直向后线性探测寻找,只要元素不为空
        ThreadLocal<?> k = e.get();
		
        if (k == key) { //元素都不为空,找到对应的key,更新value 循环结束条件2
            e.value = value;
            return;
        }
		
        if (k == null) { //元素都不为空, 弱引用失效,调用replaceStaleEntry替换旧值 循环结束3
            replaceStaleEntry(key, value, i);
            return;
        }
    }
  
    tab[i] = new Entry(key, value);//循环条件1不满足,有空位值,把值放到第一个空位置上
    int sz = ++size; //数量加1
    if (!cleanSomeSlots(i, sz) && sz >= threshold) //未清理数据&数量大于阀值
        rehash();
}

//替换失效位置
private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) { 
    Entry[] tab = table;
    int len = tab.length;
    Entry e;
    
    int slotToExpunge = staleSlot; //staleSlot就是弱引用失效的位置索引
    
    //从要删除的位置向前找,找到最前一个弱引用失效的位置,把并索引赋给slotToExpunge 
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null; //此处为空会结束循环
         i = prevIndex(i, len))
        if (e.get() == null) //失效位置
            slotToExpunge = i; //赋值,备注:如果都不为空并且没有失效的位置,i就是当前位置
	
    for (int i = nextIndex(staleSlot, len); //向后线性探测
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        if (k == key) { //找到有效的key
            e.value = value; //更新key对应的value
            //下面两句交换失效位置与有效key的位置元素
            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;
			
			/*
             * 如果在整个扫描过程中(包括函数一开始的向前扫描与i之前的向后扫描)
             * 找到了之前的失效位置则以那个位置作为清理的起点,
             * 否则则以当前的i作为清理起点,此时的i为交换位置后的失效位置索引
             */
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
                
            // 从slotToExpunge开始做一次删除失效元素,再做一次清理失效位置
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }
        
        // 如果当前的位置已经失效,并且向前扫描过程中没有失效位置,则更新slotToExpunge为当前位置索引
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }
    
	// 如果key在table中不存在,则在失效位置放一个替换掉失效元素
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);
	
	// 在探测过程中如果发现有任何失效位置,则做一次清理
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

//删除失效元素
private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    
    tab[staleSlot].value = null;//失效位置value置为空
    tab[staleSlot] = null; //失效位置为空
    size--;

    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) { //碰到失效位置
            e.value = null;//失效位置value置为空
            tab[i] = null;//失效位置为空
            size--;
        } else { //没有失效的位置
            int h = k.threadLocalHashCode & (len - 1); //重新计算索引位置
            if (h != i) { //如果与当前位置不相同
                tab[i] = null;把当前位置置为空

                while (tab[h] != null)//从h开始寻找下一个空位置
                    h = nextIndex(h, len);
                tab[h] = e; 把当前元素放到空位置
            }
        }
    }
    //返回失效位置后第一个空位置索引
    return i;
}

//清理失效位置 没有要清理的数据时返回false
private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len); //从i开始向后连续探测
        Entry e = tab[i];
        if (e != null && e.get() == null) { //是否有失效的位置元素
            n = len; //有失效元素重置n
            removed = true;
            i = expungeStaleEntry(i); //返回i后的第一个空索引位置
        }
    } while ( (n >>>= 1) != 0); //n 控制扫描次数 正常情况下如果log^n次扫描没有发现失效位置,函数就结束了
    return removed;
}

//再Hash
private void rehash() {
    expungeStaleEntries();//全面清理
    // size会变小,因为清理会使size减少
    if (size >= threshold - threshold / 4) //size只要大于3/4阀值,也就是len/2
        resize();//扩容
}

//全面清理
private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        //找到所有的失效位置,调用删除失效元素
        if (e != null && e.get() == null) 
            expungeStaleEntry(j);
    }
}

//扩容
private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2; //扩容2倍
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null) //从h开始寻找下一个空位置
                    h = nextIndex(h, newLen);
                newTab[h] = e;//把值放到空位置
                count++;
            }
        }
    }

    setThreshold(newLen);
    size = count;
    table = newTab;
}

set()方法的过程

  • 探测过程中元素都有效,并且顺利找到key所在的位置,直接替换即可
  • 探测过程中发现有失效位置,调用replaceStaleEntry,效果是最终一定会把key和value放在这个位置,并且会尽可能清理无效元素
    • 在replaceStaleEntry过程中,如果找到了key,value替换新值,把此元素与失效元素交换位置
    • 在replaceStaleEntry过程中,没有找到key,直接在失效位置替换为新entry
  • 探测没有发现key,有空位值,把值放到第一个空位置上,这也是线性探测法的一部分。做一次清理无效位置,如果没清理出去key,并且当前table大小已经超过阈值了,则做一次rehash,rehash函数会调用一次全量清理方法也即expungeStaleEntries,如果完了之后table大小超过了threshold - threshold / 4,则进行2倍扩容

get()

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);//定位数组标
    Entry e = table[i];
    if (e != null && e.get() == key) //元素未失效 key相同 返回数组当前位置元素
        return e;
    else //能进到这里说明此位置失效或者为空
        return getEntryAfterMiss(key, i, e); //未找到继续线性探测
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;
	
	//只要不为空继续向后查找
    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key) //找到相同key 返回
            return e;
        if (k == null) //位置失效
            expungeStaleEntry(i); //删除失效位置 参考set()方法中的解释
        else
            i = nextIndex(i, len); //下一个索引
        e = tab[i];
    }
    return null;
}

remove()

private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            e.clear(); //断开引用
            expungeStaleEntry(i);//删除失效位置
            return;
        }
    }
}

//e.clear其实调用了父类 Reference 的方法
public void clear() {
 	//这个referent就是entry的key,也就是ThreadLocal实例 
    this.referent = null;
}

参考资料

ThreadLocal源码解读
Java ThreadLocalMap 源码解析

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值