ConcurrentHashMap实现原理(很详细)

ConcurrentHashMap实现原理

ConcurrentHashMap在1.7和1.8实现原理不同,就好像HashMap的不同实现一样。

我们先看看为什么HashMap不安全

不安全的HashMap

JDK1.7

我们知道1.7的HashMap底层是数组+链表,那么他不安全的地方就在于链表可能会成环。

实际上这个情况是发生在HashMap的扩容过程中:

void transfer(Entry[] newTable, boolean rehash) {
    int newCapacity = newTable.length;
    //table 是hashmap存储元素的链表数组,对这个数组进行循环
    for (Entry<K,V> e : table) {
        //每次循环拿出链表,链表内的节点不为空的时候
        while(null != e) {
            //此时next指向的是2
            Entry<K,V> next = e.next;
            //判断是否需要rehash解决冲突
            if (rehash) {
                //如果他的哈希不同就不用rehash,否则rehash
                e.hash = null == e.key ? 0 : hash(e.key);
            }
            //计算扩容应该存储的新的下标的位置
            int i = indexFor(e.hash, newCapacity);
            //先将当前的指针指向计算后的位置上的链表,也就是头插
            e.next = newTable[i];
            //然后将当前计算好的位置与当前链表做连接
            newTable[i] = e;
            //进行下次循环
            e = next;
        }
    }
}

1.7采用的扩容过程是链表的头插法,也就是说会将链表倒序放在适合的位置上,那为什么会出现线程不安全,也就是成环呢?我们试想下面的情景(扩容后计算的新下标i和之前该链表所在的位置相同,也就是不变的情况下,并且每个线程已经将该条链所有的节点已经插入完毕):

假设我们此时的链表是这样的:

image-20230409145937550

假设:此时有两个线程,线程1和线程2,他们都同时进入了while循环内。

  • 线程1运行到17行的时候,时间片运行结束,此时e.next将被赋值给 newTable[i],那么如果此时刚好是while刚开始循环,e为A,e.next=B,newTable[i]为空,此时e.next还不是newTable[i]

image-20230410105301983

此时线程1被挂起了

  • 接着线程二开始运行

因为while循环之前,for循环内Entry<K,V> e是已经被拿到了的,所以此时的e链表结构也是这样的:

image-20230409145937550

接着进行扩容操作,虽然现在是这样的,但是线程1被挂起之后,还没刷到主存,所以线程2进行扩容操作,然后被扩容完成,最后的结果是这样的:

image-20230410105545070

结束之后被刷到主存

  • 随后线程1被唤醒,继续运行,此时e.next将指向newTable[i],此时就是这样了

image-20230410110525629

接着B进行头插,这样乍一看好像没问题

image-20230410110458663

但是事实真的如此吗,这可是链表啊,从A指向C就开始出问题了吧,实际上是这样的:

image-20230410110644087

没错,成环了,这就是JDK7成环的原因

下面的代码可以详细的理解JDK7中具体的扩容机制,其中扩容的方法是JDK7的扩容方法,改了些内容,其他的都是参考HashMap的实现,可以自己在扩容赋值的那几步打断点看看整个过程

import java.util.HashMap;

/**
 * @author 我见青山多妩媚
 * @date 2023/4/9 15:13
 * @Description TODO
 */
public class HashMapTest {
    public static void main(String[] args){
        Entry<String,Integer>[] table = new Entry[3];
        int size = table.length;
        //设置链表
        HashMapJDK7<String,Integer> map = new HashMapJDK7<>(table,size);
        //设置三条链表
        Entry<String,Integer> entry1 = new Entry<>("A",1
                ,new Entry<>("B",2
                ,new Entry<>("C",3,null)));
//
//        Entry<String,Integer> entry2 = new Entry<>("O",1
//                ,new Entry<>("P",2
//                ,new Entry<>("Q",3,null)));
//
//        Entry<String,Integer> entry3 = new Entry<>("X",1
//                ,new Entry<>("Y",2
//                ,new Entry<>("Z",3,null)));

        //插入链表内
        map.insert(entry1,0);
//        map.insert(entry2,1);
//        map.insert(entry3,2);

        map.foreach();
        
        System.out.println("=============================");
        //扩容后再遍历一次
        map.foreach();

    }
}

class Entry<K,V>{
    K key;
    V value;
    Entry<K,V> next;

    public Entry(K key,V value, Entry<K,V> next){
        this.key = key;
        this.value = value;
        this.next = next;
    }
}
class HashMapJDK7<K,V>{
    private Entry<K,V>[] table;
    private int size ;

    public HashMapJDK7(){

    }

    public HashMapJDK7(Entry<K,V>[] table,int size){
        this.size = size;
        this.table = new Entry[size];
    }

    //插入链表,在某个位置
    public void insert(Entry<K,V> entry,int index){
        if(index >= size) throw new RuntimeException("下标太大了");
        if(table[index] == null){
            table[index] = entry;
            return;
        }
        Entry<K,V> head = table[index];
        while(head.next != null){
            head = head.next;
        }
        head.next = entry;
    }

    //链表迁移,获取新链表
    public void rehash(){
        Entry<K,V>[] newTable = new Entry[size+1];
        transfer(newTable,false);
        table = newTable;
        size = newTable.length;
    }

    public void foreach(){
        for (int i = 0; i < table.length; i++) {
            Entry<K,V> e = table[i];
            while(null != e){
                System.out.println("index="+i+",key="+ e.key+",value="+ e.value);
                e = e.next;
            }
        }
        System.out.println("size="+size);
    }

    private void transfer(Entry<K,V>[] newTable, boolean rehash) {
        int newCapacity = newTable.length;
        //table 是hashmap存储元素的链表数组,对这个数组进行循环
        int index = 0;
        for (Entry<K,V> e : table) {
            //每次循环拿出链表,链表内的节点不为空的时候
            int i = index;
            while(null != e) {
                //此时next指向的是2
                Entry<K,V> next = e.next;
                //判断是否需要rehash解决冲突
//                if (rehash) {
//                    //如果他的哈希不同就不用rehash,否则rehash
//                    e.hash = null == e.key ? 0 : hash(e.key);
//                }
//                //计算扩容应该存储的新的下标的位置
//                int i = indexFor(e.hash, newCapacity);

                //先将当前的指针指向计算后的位置上的链表,也就是头插
                e.next = newTable[i];
                //然后将当前计算好的位置与当前链表做连接
                newTable[i] = e;
                //进行下次循环
                e = next;
            }
            index++;
        }
    }
}
JDK1.8

那么1.8是什么原因呢?因为1.8实现的方式为数组+链表+红黑树,而且扩容的方法也改了,所以就不会出现成环的情况。jdk8不安全主要是因为数据覆盖问题。

类似我们的put同一个键,不同的线程put不同的键的时候,如果计算的hash相同,都用尾插法添加到链表末尾,肯定会出现覆盖问题,因为两个同时来,之前检查的时候肯定是相同的,此时就会出现替换问题。

安全的ConcurrentHashMap

JDK1.7

JDK1.7实现的方式为分段锁的思想,不像HashTable给所有的put\get加了synchronized,多线程条件下效率很差。

他的分段锁是容器内有多把锁,这样在多线程条件下访问不同数据时就尽量减少锁竞争

类似HashMap,ConcurrentHashMap实现方式为数组+链表,但是他的数组有些不同。

ConcurrentHashMap的结构是Segment数组 + HashEntry数组 + 链表组成,他的原理实际上是一个hash表内存了一个hash表,segment数组的作用是存储HashEntry所在位置,然后在HashEntry查找对应元素所在hash位置。

所以它的结构如下:

image-20230410150821647

  • ConcurrentHashMap内属性
public class ConcurrentHashMap<K,V> extends AbstractMap<K,V>
    implements ConcurrentMap<K,V>, Serializable {
    
    final Segment<K,V>[] segments;
    transient Set<K> keySet;
    transient Set<Map.Entry<K,V>> entrySet;
    
    //...
}
  • Segment
static final class Segment<K,V> extends ReentrantLock implements Serializable {
    private static final long serialVersionUID = 2249069246763182397L;
	
	// 和 HashMap 中的 HashEntry 作用一样,真正存放数据的桶
	transient volatile HashEntry<K,V>[] table;
	
	transient int count;
	transient int modCount;
	transient int threshold;
	final float loadFactor;
	//.....
}

可见,segment是继承了ReetrantLock的

  • HashEntry<K,V>
static final class HashEntry<K,V>{
    final int hash;
    final K key;
    volatile V value;
    volatile HashEntry<K,V> next;
}

其实和Map中的Entry类似,只是放数据,区别是这个拿volatile修饰了些变量

了解了这些,我们接下来直接看对应的操作

  • put
   public V put(K key, V value) {
        Segment<K,V> s;
        //ConcurrentHashMap的key和value都不能为null
        if (value == null)
            throw new NullPointerException();

        //这里对key求hash值,并确定应该放到segment数组的索引位置
        int hash = hash(key);
        //j为索引位置,思路和HashMap的思路一样,这里不再多说
        int j = (hash >>> segmentShift) & segmentMask;
        if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
             (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
            s = ensureSegment(j);
        //这里很关键,找到了对应的Segment,则把元素放到Segment中去
        return s.put(key, hash, value, false);
    }
  • s.put(segment的方法)
 final V put(K key, int hash, V value, boolean onlyIfAbsent) {
            //这里是并发的关键,每一个Segment进行put时,都会加锁
            HashEntry<K,V> node = tryLock() ? null :
                scanAndLockForPut(key, hash, value);  //获取锁失败进行自旋获取锁
            V oldValue;
            try {
                //tab是当前segment所连接的HashEntry数组
                HashEntry<K,V>[] tab = table;
                //确定key的hash值所在HashEntry数组的索引位置
                int index = (tab.length - 1) & hash;
                //取得要放入的HashEntry链的链头
                HashEntry<K,V> first = entryAt(tab, index);
                //遍历当前HashEntry链
                for (HashEntry<K,V> e = first;;) {
                    //如果链头不为null
                    if (e != null) {
                        K k;
                        //如果在该链中找到相同的key,则用新值替换旧值,并退出循环
                        if ((k = e.key) == key ||
                            (e.hash == hash && key.equals(k))) {
                            oldValue = e.value;
                            if (!onlyIfAbsent) {
                                e.value = value;
                                ++modCount;
                            }
                            break;
                        }
                        //如果没有和key相同的,一直遍历到链尾,链尾的next为null,进入到else
                        e = e.next;
                    }
                    else {//如果没有找到key相同的,则把当前Entry插入到链头

                        if (node != null)
                            node.setNext(first);
                        else
                            node = new HashEntry<K,V>(hash, key, value, first);
                        //此时数量+1
                        int c = count + 1;
                        if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                            //如果超出了限制,要进行扩容
                            rehash(node);
                        else
                            setEntryAt(tab, index, node);
                        ++modCount;
                        count = c;
                        oldValue = null;
                        break;
                    }
                }
            } finally {
                //最后释放锁
                unlock();
            }
            return oldValue;
        }
  1. 首先对key进行hash,确定分段锁segment的位置
  2. 然后进入put方法内,首先进行获取锁
  3. 接着获取到segment内的table数组,然后对长度和hash进行&运算,确定链表所在位置
  4. 然后对当前链表进行遍历,找到要插入的位子,如果是重复的key就替换;不是则插入
  5. 锁释放

共hash两次,确定最终位置

  • get
public V get(Object key){
    Segment<K,V> s;
    HashEntry<K,V>[] tab;
    int h = hash(key);
    long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
    if((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segment,u)) != null && (tab = s.table) != null){
        for(HashEntry<K,V> e = (HashEntry<K,V>)UNSAFE.getObjectVolatile(tab,((long)(((tab.length-1)&h)) << TSHIFT)+TABSE);e!=null;e = e.next){
            K k;
            if((k = e.key) == key || (e.hash == h && key.equals(k)))
                return e.value;
        }
    }
    return null;
}
  1. 通过hash定位到具体的segment
  2. 再通过一次Hash定位到具体的元素上
  3. 由于HashEntry中的value属性是用volatile修饰,保存可见性,所以每次获取都是最新的值

因为整个过程没加锁,所以效率还是高的

JDK1.8

到了JDK1.8之后,又有哪些改进呢?

首先可HashMap1.8一样,改为数组+链表+红黑树

与1.7的ConcurrentHashMap的区别是HashEntry改为了Node,并且抛弃了segment分段锁的思想,改为CAS+Synchronized来保证并发安全。

public class ConcurrentHashMap<K,V> extends AbstractMap<K,V>
    implements ConcurrentMap<K,V>, Serializable {
    //存储数据的table
    transient volatile Node<K,V>[] table;
    
    //...
}
static class Node<K,V> implements Map.Entry<K,V> {
        final int hash;
        final K key;
        volatile V val;
        volatile Node<K,V> next;
}

接下来看具体的操作:

  • put
public V put(K key, V value) {
        return putVal(key, value, false);
}
  • putValue
final V putVal(K key, V value, boolean onlyIfAbsent) {
    	//key value不为空
        if (key == null || value == null) throw new NullPointerException();
    	//计算hash值
        int hash = spread(key.hashCode());
    	//统计节点长度,以便判断是否需要转为红黑树
        int binCount = 0;
    	//table
        for (Node<K,V>[] tab = table;;) {
            //f 具体某个节点,i为计算后的tab下标
            //fh为hash值的临时存储
            Node<K,V> f; int n, i, fh;
            //如果当前table为空,初始化这个table
            if (tab == null || (n = tab.length) == 0)
                tab = initTable();
            //如果当前下标的Node为空,说明是第一次插入这个位置
            else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
                //那么尝试CAS获取这个位置,并初始化这个节点,属性为hash,key,value
                if (casTabAt(tab, i, null,
                             new Node<K,V>(hash, key, value, null)))
                    //CAS成功,跳出判断
                    break;                   // no lock when adding to empty bin
            }
            //如果hash值==-1,说明需要扩容
            else if ((fh = f.hash) == MOVED)
                tab = helpTransfer(tab, f);
            else {
                V oldVal = null;
                synchronized (f) {
                    //如果已经被赋值给过f
                    if (tabAt(tab, i) == f) {
                        if (fh >= 0) {
                            binCount = 1;
                            for (Node<K,V> e = f;; ++binCount) {
                                K ek;
                                //如果key冲突就替换
                                if (e.hash == hash &&
                                    ((ek = e.key) == key ||
                                     (ek != null && key.equals(ek)))) {
                                    oldVal = e.val;
                                    if (!onlyIfAbsent)
                                        e.val = value;
                                    break;
                                }
                                Node<K,V> pred = e;
                                //往下继续循环
                                if ((e = e.next) == null) {
                                    pred.next = new Node<K,V>(hash, key,
                                                              value, null);
                                    break;
                                }
                            }
                        }
                        //如果链表是红黑树,那么就执行红黑树的put方法
                        else if (f instanceof TreeBin) {
                            Node<K,V> p;
                            binCount = 2;
                            if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,
                                                           value)) != null) {
                                oldVal = p.val;
                                if (!onlyIfAbsent)
                                    p.val = value;
                            }
                        }
                    }
                }
                //是否需要转化为红黑树
                if (binCount != 0) {
                    if (binCount >= TREEIFY_THRESHOLD) //TREEIFY_THRESHOLD = 8
                        //转红黑树
                        treeifyBin(tab, i);
                    if (oldVal != null)
                        return oldVal;
                    break;
                }
            }
        }
        addCount(1L, binCount);
        return null;
    }
  1. 根据key计算hash,确定位置
  2. 判断table是否为空,是否需要初始化
  3. f为定位出的Node,尝试使用CAS+自旋写入
  4. 如果当前位置的hash==-1,那么需要扩容
  5. 如果都不满足,用synchronized写入数据
  6. 如果数量大于TREEIFY_THRESHOLD(8)转化为红黑树
  • get
public V get(Object key) {
        Node<K,V>[] tab; Node<K,V> e, p; int n, eh; K ek;
    	//计算hash
        int h = spread(key.hashCode());
    	//当前hash所在tab不为空
        if ((tab = table) != null && (n = tab.length) > 0 &&
            //e指向当前hash的Node
            (e = tabAt(tab, (n - 1) & h)) != null) {
            //如果e.hash和h相同,说明当前链表首节点就是
            if ((eh = e.hash) == h) {
                //如果key相同,那么返回
                if ((ek = e.key) == key || (ek != null && key.equals(ek)))
                    return e.val;
            }
            //如果小于0,就是说hash计算的小于0,那么有两种情况
            //1. 是红黑树 hash桶固定为-2,那么调用TreeBin的find,查找
            //2. ConcurrentMap在扩容中,跳转到扩容后的数组中查询
            //find方法TreeBin有,Node节点内部也有,具体用哪个视情况而定,主要看e是什么了
            else if (eh < 0)
                return (p = e.find(h, key)) != null ? p.val : null;
            //都不满足,e开始自己找
            while ((e = e.next) != null) {
                //找到返回
                if (e.hash == h &&
                    ((ek = e.key) == key || (ek != null && key.equals(ek))))
                    return e.val;
            }
        }
    	//都不满足返回null
        return null;
    }
  1. 计算hash,如果在桶上直接返回
  2. 如果是红黑树就按照红黑树获取
  3. 否则链表查找

总结

区别:

  • 1.7用的数组+链表,安全的方式为分段锁

  • 1.8用的数组+链表+红黑树,结构和HashMap一致,抛弃了分段锁,采用CAS+自旋以及Synchronized

为什么key为null的时候会报错?

因为如果key为null,无法分辨是key为null还是key无法找到返回为null,这在多线程下时模糊的

  • 1
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值