深入理解ConcurrentHashMap1.7源码

1. 概述

HashMap在我们的日常生活中使用很多,但是它不是线程安全的。我们可以使用HashTable来代替,主要实现方式是在方法中加入synchronized,所以效率也比较低。因此,对于键值对,我们可以尝试使用ConcurrentHashMap来解决线程安全的问题。

ConcurrentHashMap 1.7版本是采用的数组+分段锁的方式来实现的,如下图所示:
在这里插入图片描述

2. 成员变量

    static final class HashEntry<K,V> {
        final int hash;  // 哈希值
        final K key;  // key
        volatile V value;  // value
        volatile HashEntry<K,V> next;  // 下一个节点
    }

和HashMap1.7一样,对于出现哈希冲突的键值对,也是采用链表的方式链接起来。因此ConcurrentHashMap也定义了个节点类HashEntry。

我们刚刚提到,ConcurrentHashMap1.7采用分段锁的方式,我们接下来就看看段Segment的具体定义:

    // Segment静态内部类,继承自ReentrantLock
    static final class Segment<K,V> extends ReentrantLock implements Serializable {

        // HashEntry节点类数组
        transient volatile HashEntry<K,V>[] table;

        // Segment内节点总数
        transient int count;

        // 修改次数,线程不安全的时候,启用fail-fast机制
        transient int modCount;

        // 阈值
        transient int threshold;

        // 负载因子
        final float loadFactor;

从代码中可以看到,Segment是继承自ReentrantLock的,需要完成锁的一些操作。其它的成员变量就和普通的HashMap没有什么两样,也是拥有一个节点数组。

我们接下来再看下ConcurrentHashMap的成员变量:

    // 默认的初始容量
    static final int DEFAULT_INITIAL_CAPACITY = 16;

    // 默认负载因子
    static final float DEFAULT_LOAD_FACTOR = 0.75f;

    // 默认的并发数
    static final int DEFAULT_CONCURRENCY_LEVEL = 16;

    // 最大容量
    static final int MAXIMUM_CAPACITY = 1 << 30;

    // 每个Segment中的数组最小容量
    static final int MIN_SEGMENT_TABLE_CAPACITY = 2;

    // Segment最大数目
    static final int MAX_SEGMENTS = 1 << 16; // slightly conservative

    // 计算size的最大重试次数
    static final int RETRIES_BEFORE_LOCK = 2;

    // 用于计算第几个Segment的掩码值
    final int segmentMask;

    // segment偏移量
    final int segmentShift;

    // Segment数组
    final Segment<K,V>[] segments;

在ConcurrentHashMap的成员变量中,我们可以看到是只包含Segment数组的,每个Segment内部又各自包含HashEntry数组,也就是ConcurrentHashMap先将所有的键值对分段,再分具体的桶。

3. 构造方法

    // 3个参数的构造方法
    public ConcurrentHashMap(int initialCapacity,
                             float loadFactor, int concurrencyLevel) {
        // 参数不合法
        if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
            throw new IllegalArgumentException();
        if (concurrencyLevel > MAX_SEGMENTS)
            concurrencyLevel = MAX_SEGMENTS;
        // 根据并发度来具体创建多少个Segment
        int sshift = 0; // Segment位偏移
        int ssize = 1;  // Segment的个数
        // 保证Segment的个数必须为2的n次幂
        while (ssize < concurrencyLevel) {
            ++sshift;
            ssize <<= 1;  // 左移,不断乘2
        }
        this.segmentShift = 32 - sshift;
        this.segmentMask = ssize - 1;
        if (initialCapacity > MAXIMUM_CAPACITY)
            initialCapacity = MAXIMUM_CAPACITY;
        int c = initialCapacity / ssize;  // c是每个Segment分配到的大小
        if (c * ssize < initialCapacity)  
            ++c;
        // 保证每个Segment中桶的个数也必须为2的n次幂
        int cap = MIN_SEGMENT_TABLE_CAPACITY;
        while (cap < c)  
            cap <<= 1;
        // 创建Segment s0以及其中的HashEntry数组,用来作为其他Segment的样例
        Segment<K,V> s0 =
            new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                             (HashEntry<K,V>[])new HashEntry[cap]);
        Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
        UNSAFE.putOrderedObject(ss, SBASE, s0); // CAS设置
        this.segments = ss;  
    }

    // 2个参数的构造方法
    public ConcurrentHashMap(int initialCapacity, float loadFactor) {
        this(initialCapacity, loadFactor, DEFAULT_CONCURRENCY_LEVEL);
    }

    // 1个参数的构造方法
    public ConcurrentHashMap(int initialCapacity) {
        this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
    }

    // 无参构造方法
    public ConcurrentHashMap() {
        this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
    }

    // Map迁移
    public ConcurrentHashMap(Map<? extends K, ? extends V> m) {
        this(Math.max((int) (m.size() / DEFAULT_LOAD_FACTOR) + 1,
                      DEFAULT_INITIAL_CAPACITY),
             DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
        putAll(m);
    }

从参数最多的构造方法中我们可以看到,是根据并发度的大小来设定Segment的数量,再将initCapacity均摊到每个Segment中。此外,在创建的时候并没有一次性创建出所有Segment和其中的HashEntry数组,而是采用懒加载的方式,只创建了第一个Segment和其中的数组作为样例,当后期访问到再按照第一个Segment为样例进行创建。

4. put方法

    // put方法
    public V put(K key, V value) {
        Segment<K,V> s;
        // value不能为null
        if (value == null)
            throw new NullPointerException();
        // 计算hash值
        int hash = hash(key.hashCode());
        // 计算出要查询的key所在的segment
        int j = (hash >>> segmentShift) & segmentMask;
        // 如果要查询的Segment还没有创建,就调用ensureSegment创建
        if ((s = (Segment<K,V>)UNSAFE.getObjec
             (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
            s = ensureSegment(j);
        // 将新的键值对插入到Segment中
        return s.put(key, hash, value, false);
    }
    // 确保被访问的段已经创建
    private Segment<K,V> ensureSegment(int k) {
        final Segment<K,V>[] ss = this.segments;  // 获取所有的Segment
        long u = (k << SSHIFT) + SBASE; // 段偏移量
        Segment<K,V> seg;
        // 如果查询的第u个段没有被创建
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
            Segment<K,V> proto = ss[0]; // 取出第0个Segment作为模版
            int cap = proto.table.length;  // 第0个段的HashEntry容量
            float lf = proto.loadFactor;  // 负载因子
            int threshold = (int)(cap * lf);  // 阈值
            HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];  // 仿照第一个段的hashEntry数组大小创建
            if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                == null) { // 再次检查是否没创建
                // 创建第u个段
                Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
                while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                       == null) {
                    // CAS将第u个段设置到Segment数组中
                    if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                        break;
                }
            }
        }
        return seg;
    }

    // 往Segment的HashEntry数组中添加键值对
    final V put(K key, int hash, V value, boolean onlyIfAbsent) {
        // 尝试获得锁
        HashEntry<K,V> node = tryLock() ? null :
            scanAndLockForPut(key, hash, value);
        V oldValue;
        try {
            // 获得Segment中的HashEntry数组
            HashEntry<K,V>[] tab = table;
            // 要插入的位置
            int index = (tab.length - 1) & hash;
            // index位置上的第一个HashEntry
            HashEntry<K,V> first = entryAt(tab, index);
            // 遍历
            for (HashEntry<K,V> e = first;;) {
                if (e != null) {
                    K k;
                    // 如果找到了key相同的,就覆盖值
                    if ((k = e.key) == key ||
                        (e.hash == hash && key.equals(k))) {
                        oldValue = e.value;
                        if (!onlyIfAbsent) {
                            e.value = value;
                            ++modCount;
                        }
                        break;
                    }
                    e = e.next;
                }
                else {  // 没找到相同的
                    if (node != null)  
                        node.setNext(first);  // 头插法插入
                    else   // 如果桶中为null
                        node = new HashEntry<K,V>(hash, key, value, first);   // 创建HashEntry
                    int c = count + 1;
                    if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                        rehash(node);  // 超过阈值进行扩容
                    else
                        setEntryAt(tab, index, node);  // 设置index位置新元素
                    ++modCount;
                    count = c;
                    oldValue = null;
                    break;
                }
            }
        } finally {
            unlock();  // 解锁
        }
        return oldValue;
    }

    // 加锁
    private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
        // 得到相应位置上的第一个HashEntry
        HashEntry<K,V> first = entryForHash(this, hash);
        HashEntry<K,V> e = first;
        HashEntry<K,V> node = null;
        int retries = -1; // negative while locating node
        // 尝试次数
        while (!tryLock()) {
            HashEntry<K,V> f; // to recheck first below
            if (retries < 0) {
                if (e == null) {   // 如果位置上本身没有HashEntry
                    if (node == null) // speculatively create node
                        node = new HashEntry<K,V>(hash, key, value, null);   // 创建HashEntry
                    retries = 0;
                }
                else if (key.equals(e.key))  // 找到了一样的
                    retries = 0;
                else
                    e = e.next;   // 寻找下一个节点
            }
            else if (++retries > MAX_SCAN_RETRIES) {
                lock();    // 如果超过了最大的尝试次数,就直接进入队列排队
                break;
            }
            // 第一个节点发生变化的话就重新获取
            else if ((retries & 1) == 0 &&
                        (f = entryForHash(this, hash)) != first) {
                e = first = f; // re-traverse if entry changed
                retries = -1;
            }
        }
        return node;
    }

put方法的整体思路就是先根据Hash值找到对应的段,再在段中的HashEntry数组中进行寻找。其中代码

int j = (hash >>> segmentShift) & segmentMask;
int index = (tab.length - 1) & hash;

就是计算所在的段和段中的位置,计算方式可见探究ConcurrentHashMap中键值对在Segment[]的下标如何确定。我们来梳理下整个put方法的流程:

  1. 计算hash值,计算所在的段;
  2. 如果段还没被初始化,就根据第0个段为模版进行创建;
  3. 尝试加锁;
  4. 到Segment中HashEntry数组的对应位置去寻找是否有相同的key,有就直接覆盖。没有就利用头插法插入;
  5. 检查是否需要扩容,超过阈值则进行扩容
  6. 解锁

5. rehash方法

    // Segment内HashEntry数组的扩容方法
    private void rehash(HashEntry<K,V> node) {
        HashEntry<K,V>[] oldTable = table;  // 旧数组
        int oldCapacity = oldTable.length;  // 旧长度
        int newCapacity = oldCapacity << 1;  // 新容量为原来的2倍
        threshold = (int)(newCapacity * loadFactor);  // 新阈值
        HashEntry<K,V>[] newTable =
            (HashEntry<K,V>[]) new HashEntry[newCapacity]; // 创建新数组
        int sizeMask = newCapacity - 1;
        for (int i = 0; i < oldCapacity ; i++) {  // 遍历旧数组
            HashEntry<K,V> e = oldTable[i];  
            if (e != null) {  // 如果e不为null
                HashEntry<K,V> next = e.next;  // next
                int idx = e.hash & sizeMask;  // 计算在新数组中的位置
                if (next == null)   //  如果本身只有一个HashEntry
                    newTable[idx] = e;  // 直接插入到新数组
                else { // Reuse consecutive sequence at same slot
                    HashEntry<K,V> lastRun = e;  // 最后一次运行的HashEntry
                    int lastIdx = idx;  // lastRun对应要插入的位置
                    for (HashEntry<K,V> last = next;
                            last != null;
                            last = last.next) {
                        int k = last.hash & sizeMask;
                        if (k != lastIdx) {   // 如果和lastIdx不一样
                            lastIdx = k;    // 进行替换
                            lastRun = last;
                        }
                    }
                    // lastRun保证了后面的HashEntry在新数组中都是相同位置,减少了循环次数
                    newTable[lastIdx] = lastRun;
                    // Clone remaining nodes
                    // 从开头到lastRun
                    for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                        // 重新计算要插入的位置
                        V v = p.value;
                        int h = p.hash;
                        int k = h & sizeMask;
                        HashEntry<K,V> n = newTable[k];
                        // 迁移到新数组上
                        newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                    }
                }
            }
        }
        // 头插法插入新元素
        int nodeIndex = node.hash & sizeMask; // add the new node
        node.setNext(newTable[nodeIndex]);
        newTable[nodeIndex] = node;
        table = newTable;
    }

在ConcurrentHashMap1.7的扩容方法中,扩容大小是原来的两倍。首先创建新的HashEntry数组,然后逐个遍历旧数组中的HashEntry,进行迁移即可。需要注意的是,扩容的方法是针对Segment中的HashEntry数组的,不是对Segment进行扩容。

6. get方法

    public V get(Object key) {
        Segment<K,V> s; 
        HashEntry<K,V>[] tab;
        int h = hash(key.hashCode());   // 计算hash值
        long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;   // 计算key所在的段
        if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&    
            (tab = s.table) != null) {   // 获取段
            // 获取段中对应的桶,并遍历桶中的HashEntry
            for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
                     (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);   
                 e != null; e = e.next) {
                K k;
                // 找到了相同的key,就返回
                if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                    return e.value;
            }
        }
        return null;
    }

get方法也比较简单,没有加锁,就先去寻找所在的段,然后再在段中寻找,找到就返回。

7. remove方法

    // 按照key删除的remove方法
    public V remove(Object key) {
        int hash = hash(key.hashCode());
        // 找到所在的段
        Segment<K,V> s = segmentForHash(hash);
        // remove方法
        return s == null ? null : s.remove(key, hash, null);
    }

    final V remove(Object key, int hash, Object value) {
        if (!tryLock())   // 进行加锁
            scanAndLock(key, hash);
        V oldValue = null;
        try {
            HashEntry<K,V>[] tab = table;  // 获取HashEntry数组
            int index = (tab.length - 1) & hash;  // 计算位置
            HashEntry<K,V> e = entryAt(tab, index);  // 找到index中第一个
            HashEntry<K,V> pred = null;
            while (e != null) {
                K k;
                HashEntry<K,V> next = e.next;
                // 如果找到了就进行删除
                if ((k = e.key) == key ||
                    (e.hash == hash && key.equals(k))) {
                    V v = e.value;
                    if (value == null || value == v || value.equals(v)) {
                        if (pred == null)  // 删除的是第一个节点
                            setEntryAt(tab, index, next);
                        else  // 删除的是中间节点
                            pred.setNext(next);
                        ++modCount;
                        --count;
                        oldValue = v;
                    }
                    break;
                }
                pred = e;
                e = next;
            }
        } finally {
            unlock();  // 解锁
        }
        return oldValue;
    }

    private void scanAndLock(Object key, int hash) {
        // similar to but simpler than scanAndLockForPut
        HashEntry<K,V> first = entryForHash(this, hash);
        HashEntry<K,V> e = first;
        int retries = -1;
        while (!tryLock()) {   // 查实加锁
            HashEntry<K,V> f;
            if (retries < 0) {
                if (e == null || key.equals(e.key))  // 找到了
                    retries = 0;
                else  // 不断往后寻找
                    e = e.next;
            }
            else if (++retries > MAX_SCAN_RETRIES) {  // 如果超过最大尝试次数
                lock();   // 进入队列去排队
                break;
            }
            // 如果发现桶中第一个节点发生改变,就重新开始
            else if ((retries & 1) == 0 &&
                        (f = entryForHash(this, hash)) != first) {
                e = first = f;
                retries = -1;
            }
        }
    }

remove方法是先通过key计算出可能在的Segment,然后到Segment中的HashEntry数组中去寻找,找到就删除节点即可。在删除过程中需要进行加锁。

8. size方法

    public int size() {
        final Segment<K,V>[] segments = this.segments;  // 获取Segment数组
        int size;   // 初始化size
        boolean overflow; // 是否溢出
        long sum;     // 计算modCount修改次数
        long last = 0L;   // previous sum
        int retries = -1; // 重试次数
        try {
            for (;;) {
                // 如果超过了一定的重试次数,就同时把所有段都锁起来
                if (retries++ == RETRIES_BEFORE_LOCK) {
                    for (int j = 0; j < segments.length; ++j)
                        ensureSegment(j).lock(); // force creation
                }
                sum = 0L;
                size = 0;
                overflow = false;
                // 遍历每一个段
                for (int j = 0; j < segments.length; ++j) {
                    // 获取段
                    Segment<K,V> seg = segmentAt(segments, j);
                    // 计算每个段中的HashEntry个数
                    if (seg != null) {
                        sum += seg.modCount;
                        int c = seg.count;
                        if (c < 0 || (size += c) < 0)
                            overflow = true;
                    }
                }
                // 如果发现当前的modCount和上一次一样,说明没有线程安全问题
                if (sum == last)
                    break;  // 就可以结束了
                last = sum;   // 如果发现不一样,说明有线程修改了
            }
        } finally {
            // 解锁
            if (retries > RETRIES_BEFORE_LOCK) {
                for (int j = 0; j < segments.length; ++j)
                    segmentAt(segments, j).unlock();
            }
        }
        return overflow ? Integer.MAX_VALUE : size;
    }

size方法的思路是先假设没有线程安全问题,进行一定次数的尝试,如果本次计数时的修改次数和上一次一样,那就认为这个size是可信的,就返回。如果发现modCount不一样,那就说明有线程在本线程计数的时候进行了修改,需要重新计数。如果当重试次数超过一定次数了,说明线程竞争激烈,就会去把所有的Segment同时加锁,保证size计算没问题,这时的并发效率就很低了。

9. 总结

ConcurrentHashMap1.7的设计思想还是很精妙的,值得我们学习。它将所有的HashEntry分配到多个Segment上,当进行put,remove等修改操作的时候,不需要锁整个ConcurrentHashMap,只需要锁修改的HashEntry所在的段,一定程度上提高了并发的效率。

对于ConcurrentHashMap1.7的扩容,只能对Segment内的HashEntry数组进行扩容,不能增加Segment的个数。

参考文章:
ConcurrentHashMap1.7 最最最最最详细源码分析
探究ConcurrentHashMap中键值对在Segment[]的下标如何确定
翻了ConcurrentHashMap1.7 和1.8的源码,我总结了它们的主要区别。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值