ConcurrentHashMap源码阅读

前言

ConcurrentHashMap是HashMap的多线程版本,经常用到,JDK里的实现方式也非常的精妙,值得学习。JDK1.7和1.8的实现方式并不相同,所以这里两个版本都要学习,体会个中的精妙之处。

HashMap的各种多线程方式

ConcurrentHashMap并非是HashMap的唯一多线程方式,它还有其他的多线程方式,为什么这些都被淘汰了,与ConcurrentHashMap的区别又在哪里?

Hashtable

Hashtable是线程安全的,但效率低下,因为它的源码就是直接对所有数据操作都上锁synchronized(直接在方法级别上加synchronized),哪怕是get方法,所以效率比较低下。除此之外,Hashtable与HashMap还有一些不同。

关于null值,HashMap允许键值为null,而Hashtable不允许。这是因为Hashtable使用的是安全失败机制(fail-safe),这种机制会使得你此次读到的数据不一定是最新的数据。如果使用null值,就无法判断对应的key是不存在还是为null,ConcurrentHashMap同理。对于单线程的HashMap,可以通过containsKey来判断key是否存在,因此可以允许null值作为key和value。而对于多线程的Hashtable和ConcurrentHashMap,在get和containsKey两个方法中间,可能map本身就发生了改变,此时containsKey方法就失去了意义,因此是不能判断map到底是存在key为null的数据,还是不存在该key。

实现方式不同,Hashtable继承了Dictionary类,而HashMap继承的是AbstractMap类。

初始化容量不同,HashMap初始容量是16,而Hashtable是11,二者的负载因子都是0.75

扩容机制不同,HashMap扩容是翻倍,而Hashtable是翻倍再+1

迭代器不同,HashMap的Iterator是fail-fast的,而Hashtable的Enumerator是fail-safe的。所以HashMap使用迭代器后,如果修改了HashMap的结构,如增加,删除元素,将会抛出ConcurrentModificationException异常,而Hashtable不会。而HashMap是根据参数modCount来判断是否发现了HashMap的改变,异常抛出的条件是modCount != exceptedmodCount,如果集合发生变化时,modCount刚好修改为exceptedmodCount,那么是不会抛出该异常的。因此不能依赖这个异常是否抛出来进行并发编程,这个异常只建议用来检测并修改的bug。

使用Collections.synchronizedMap(Map)创建线程安全的Map集合

此方法会返回一个SynchronizedMap对象,里面维护了一个普通的Map对象,还有互斥锁Mutex。

image-20200901105706961

我们在调用这个方法的时候就需要传入一个Map,可以看到有两个构造器,如果你传入了mutex参数,则将对象排斥锁赋值为传入的对象。

如果没有,则将对象排斥锁赋值为this,即调用synchronizedMap的对象,就是上面的Map。

创建出synchronizedMap之后,再操作map的时候,就会对方法上锁,如图,全是🔐

image-20200901105804075

所以synchronizedMap的效率也不高。

ConcurrentHashMap

ConcurrentHashMap是并发度比较高的Map类,是多线程环境下常用的多线程版本的HashMap。在JDK1.7和JDK1.8中的实现并不一样,下面将分别讲述。

JDK1.7的ConcurrentHashMap

数据结构

image-20200901144242514

HashMap的底层是由链表数组而组成的。因为链表在哈希的时候被称作bucket(桶),所以下文有时候会把链表写作桶,二者是等价的。在JDK1.7中,ConcurrentHashMap由Segment数组构成,而Segment由HashEntry构成。采用的思想是分段锁,即一个Segment相当于一个HashMap,Segment里的HashEntry就是HashMap里的桶数组table。一个ConcurrentHashMap由一个Segment数组构成,分段锁的思想就是每一个Segment是相互独立的,即对其中一个Segment操作时,并不会锁住其他Segment的数据。因此ConcurrentHashMap的并发度就是Segment数组的长度。

Segment的ConcurrentHashMap的一个内部类,如下:

static final class Segment<K,V> extends ReentrantLock implements Serializable {

  private static final long serialVersionUID = 2249069246763182397L;

  // 和 HashMap 中的 HashEntry 作用一样,真正存放数据的桶
  transient volatile HashEntry<K,V>[] table;

  transient int count;
  // 记得快速失败(fail—fast)么?
  transient int modCount;
  // 大小
  transient int threshold;
  // 负载因子
  final float loadFactor;

}

HashEntry跟HashMap差不多的,但是不同点是,他使用volatile去修饰了它的数据Value以及下一个节点next。

Question:JDK1.7的ConcurrentHashMap为什么并发度高?

Ans:原理上来说,ConcurrentHashMap 采用了分段锁技术,其中 Segment 继承于 ReentrantLock。

在 JDK1.7中,本质上还是采用链表+数组的形式存储键值对的。但是,为了提高并发,把原来的整个 table 划分为 n 个 Segment 。所以,从整体来看,它是一个由 Segment 组成的数组。然后,每个 Segment 里边是由 HashEntry 组成的数组,每个 HashEntry之间又可以形成链表。我们可以把每个 Segment 看成是一个小的 HashMap,其内部结构和 HashMap 是一模一样的。

当对某个 Segment 加锁时,如上图中 Segment2,此时并不会影响到其他 Segment 的读写,每个 Segment 内部自己操作自己的数据,彼此之间互相独立。这样一来,我们要做的就是尽可能的让元素均匀的分布在不同的 Segment中。最理想的状态是,所有执行的线程操作的元素都是不同的 Segment,这样就可以降低锁的竞争。

常用变量

//默认初始化容量,这个和 HashMap中的容量是一个概念,表示的是整个 Map的容量
static final int DEFAULT_INITIAL_CAPACITY = 16;

//默认加载因子
static final float DEFAULT_LOAD_FACTOR = 0.75f;

//默认的并发级别,这个参数决定了 Segment 数组的长度
static final int DEFAULT_CONCURRENCY_LEVEL = 16;

//最大的容量
static final int MAXIMUM_CAPACITY = 1 << 30;

//每个Segment中table数组的最小长度为2,且必须是2的n次幂。
//由于每个Segment是懒加载的,用的时候才会初始化,因此为了避免使用时立即调整大小,设定了最小容量2
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;

//用于限制Segment数量的最大值,必须是2的n次幂
static final int MAX_SEGMENTS = 1 << 16; // slightly conservative

//在size方法和containsValue方法,会优先采用乐观的方式不加锁,直到重试次数达到2,才会对所有Segment加锁
//这个值的设定,是为了避免无限次的重试。后边size方法会详讲怎么实现乐观机制的。
static final int RETRIES_BEFORE_LOCK = 2;

//segment掩码值,用于根据元素的hash值定位所在的 Segment 下标。后边会细讲
final int segmentMask;

//和 segmentMask 配合使用来定位 Segment 的数组下标,后边讲。
final int segmentShift;

// Segment 组成的数组,每一个 Segment 都可以看做是一个特殊的 HashMap
final Segment<K,V>[] segments;

//Segment 对象,继承自 ReentrantLock 可重入锁。
//其内部的属性和方法和 HashMap 神似,只是多了一些拓展功能。
static final class Segment<K,V> extends ReentrantLock implements Serializable {

  //这是在 scanAndLockForPut 方法中用到的一个参数,用于计算最大重试次数
  //获取当前可用的处理器的数量,若大于1,则返回64,否则返回1。
  static final int MAX_SCAN_RETRIES =
    Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;

  //用于表示每个Segment中的 table,是一个用HashEntry组成的数组。
  transient volatile HashEntry<K,V>[] table;

  //Segment中的元素个数,每个Segment单独计数(下边的几个参数同样的都是单独计数)
  transient int count;

  //每次 table 结构修改时,如put,remove等,此变量都会自增
  transient int modCount;

  //当前Segment扩容的阈值,同HashMap计算方法一样也是容量乘以加载因子
  //需要知道的是,每个Segment都是单独处理扩容的,互相之间不会产生影响
  transient int threshold;

  //加载因子
  final float loadFactor;

  //Segment构造函数
  Segment(float lf, int threshold, HashEntry<K,V>[] tab) {
    this.loadFactor = lf;
    this.threshold = threshold;
    this.table = tab;
  }

  ...
    // put(),remove(),rehash() 方法都在此类定义
}

// HashEntry,存在于每个Segment中,它就类似于HashMap中的Node,用于存储键值对的具体数据和维护单向链表的关系
static final class HashEntry<K,V> {
  //每个key通过哈希运算后的结果,用的是 Wang/Jenkins hash 的变种算法,此处不细讲,感兴趣的可自行查阅相关资料
  final int hash;
  final K key;
  //value和next都用 volatile 修饰,用于保证内存可见性和禁止指令重排序
  volatile V value;
  //指向下一个节点
  volatile HashEntry<K,V> next;

  HashEntry(int hash, K key, V value, HashEntry<K,V> next) {
    this.hash = hash;
    this.key = key;
    this.value = value;
    this.next = next;
  }
}

重点:

concurrencyLevel:并发级别,相当于Segment数组的最大长度。

segmentMask:用于定位到Segment数组的下标。

segmentShift:用于定位到具体segment的具体下标(即具体元素的下标)。

Segment继承自ReentrantLock,维护了一个HashEntry数组(桶数组),还有countmodCountthresholdloadFactor等变量,put,remove,rehash等方法。而HashEntry就相当于HashMap中的Node,维护了hash,key,value,next等变量。

构造方法

public ConcurrentHashMap(int initialCapacity,
                         float loadFactor, int concurrencyLevel) {
  //检验参数是否合法。值得说的是,并发级别一定要大于0,否则就没办法实现分段锁了。
  if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
    throw new IllegalArgumentException();
  //并发级别不能超过最大值
  if (concurrencyLevel > MAX_SEGMENTS)
    concurrencyLevel = MAX_SEGMENTS;
  // Find power-of-two sizes best matching arguments
  //偏移量,是为了对hash值做位移操作,计算元素所在的Segment下标,put方法详讲
  int sshift = 0;
  //用于设定最终Segment数组的长度,必须是2的n次幂
  int ssize = 1;
  //这里就是计算 sshift 和 ssize 值的过程  (1) 
  while (ssize < concurrencyLevel) {
    ++sshift;
    ssize <<= 1;
  }
  this.segmentShift = 32 - sshift;			// 最后用高sshift位来记录位移量
  //Segment的掩码
  this.segmentMask = ssize - 1;				// 类似于 n - 1的掩码
  if (initialCapacity > MAXIMUM_CAPACITY)
    initialCapacity = MAXIMUM_CAPACITY;
  //c用于辅助计算cap的值   (2)
  int c = initialCapacity / ssize;
  if (c * ssize < initialCapacity)
    ++c;
  // cap 用于确定某个Segment的容量,即Segment中HashEntry数组的长度
  int cap = MIN_SEGMENT_TABLE_CAPACITY;
  //(3)
  while (cap < c)
    cap <<= 1;
  // create segments and segments[0]
  //这里用 loadFactor做为加载因子,cap乘以加载因子作为扩容阈值,创建长度为cap的HashEntry数组,
  //三个参数,创建一个Segment对象,保存到S0对象中。后边在 ensureSegment 方法会用到S0作为原型对象去创建对应的Segment。
  Segment<K,V> s0 =
    new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                     (HashEntry<K,V>[])new HashEntry[cap]);
  //创建出长度为 ssize 的一个 Segment数组
  Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
  //把S0存到Segment数组中去。在这里,我们就可以发现,此时只是创建了一个Segment数组,
  //但是并没有把数组中的每个Segment对象创建出来,仅仅创建了一个Segment用来作为原型对象。
  UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
  this.segments = ss;
}    

构造方法所做的工作:

①计算sshift和ssize的值。

sshift:hash位移,用于计算元素所在的Segment下标,在后面put方法会用到。

ssize:用于计算最终Segment数组的长度,是一个不小于concurrencyLevel的二次幂。转换过程为代码中的(1)

②(2)和(3),通过ssize确定了HashEntry数组的最终长度(cap),可以看到HashEntry数组的长度也是2次幂。

创建Segment数组,Segment数组的长度为ssize。同时还创建了一个s0作为原型对象,用于后续创建新的Segment,三个参数分别为负载因子loadFactor,扩容阈值(cap * loadFactor),长度为cap的HashEntry数组。将s0存到数组里。

put方法

put 方法的总体流程:

  1. 通过哈希算法计算出当前 key 的 hash 值
  2. 通过这个 hash 值找到它所对应的 Segment 数组的下标(在哪个Segment里)
  3. 再通过 hash 值计算出它在对应 Segment 的 HashEntry数组的下标(在Segment里的具体元素下标)
  4. 找到合适的位置插入元素(具体元素下标里是一个桶/链表,遍历一次,如果有相同的key,说明是替换value,如果到达结尾还没有匹配到相同的key,说明是插入新的key-value)
// Map的put方法
public V put(K key, V value) {
  Segment<K,V> s;
  //不支持value为空
  if (value == null)
    throw new NullPointerException();
  //通过 Wang/Jenkins 算法的一个变种算法,计算出当前key对应的hash值
  int hash = hash(key);
  //上边我们计算出的 segmentShift为28,因此hash值右移28位,说明此时用的是hash的高4位,
  //然后把它和掩码15进行与运算,得到的值一定是一个 0000 ~ 1111 范围内的值,即 0~15 。
  int j = (hash >>> segmentShift) & segmentMask;
  //这里是用Unsafe类的原子操作找到Segment数组中j下标的 Segment 对象
  if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
       (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
    //初始化j下标的Segment
    s = ensureSegment(j);
  //在此Segment中添加元素
  return s.put(key, hash, value, false);		// 最终确定元素位置并插入元素
}

关于代码UNSAFE.getObject (segments, (j << SSHIFT) + SBASE,它是为了通过Unsafe这个类,找到 j 最新的实际值。这个计算( j << SSHIFT ) + SBASE,在后边非常常见,我们只需要知道它代表的是 j 的一个偏移量,通过偏移量,就可以得到 j 的实际值。可以类比,AQS 中的 CAS 操作。Unsafe中的操作,都需要一个偏移量,看下图:

image-20200901151317115

( j << SSHIFT ) + SBASE 就相当于图中的 stateOffset偏移量。只不过图中是 CAS 设置新值,而我们这里是取 j 的最新值,后边还有很多这样的计算方式。接着看 s.put 方法,这才是最终确定元素位置的方法。

// Segment中的 put 方法
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
  // 这里通过tryLock尝试加锁,如果加锁成功,返回null,否则执行 scanAndLockForPut方法
  // tryLock不会阻塞,而是自旋,直到自旋次数达到阈值,才调用lock方法进行阻塞等待
  HashEntry<K,V> node = tryLock() ? null :
  scanAndLockForPut(key, hash, value);
  V oldValue;
  try {
    // 当前Segment的table数组
    HashEntry<K,V>[] tab = table;
    // 这里就是通过hash值,与tab数组长度取模,找到其所在HashEntry数组的下标
    int index = (tab.length - 1) & hash;
    // 当前下标位置的第一个HashEntry节点
    HashEntry<K,V> first = entryAt(tab, index);
    for (HashEntry<K,V> e = first;;) {			// 遍历链表
      // 如果e节点不为空
      if (e != null) {
        K k;
        // 如果e的key相同,替换value值,否则继续向后查找
        if ((k = e.key) == key ||
            (e.hash == hash && key.equals(k))) {
          // 替换旧值
          oldValue = e.value;
          if (!onlyIfAbsent) {
            e.value = value;
            ++modCount;
          }
          break;			// 修改完成,终止循环,因为是替换value,无需修改count值
        }
        e = e.next;
      }
      // 执行到else有两种情况,一种是first为null,即链表为null。
      // 另一种是链表遍历到了结尾也没有找到key相同的节点。
      // 不管是哪一种,直接用头插法把node插在first节点的前面即可。
      else {
        // 还要先判断一下node是否为null,后面会看到scanAndLockForPut不一定完成node的初始化
        if (node != null)				// 如果node不为空,则直接头插
          node.setNext(first);
        //否则,创建一个新的node,并头插
        else
          node = new HashEntry<K,V>(hash, key, value, first);
        int c = count + 1;
        //如果当前Segment中的元素大于阈值,并且tab长度没有超过容量最大值,则扩容
        if (c > threshold && tab.length < MAXIMUM_CAPACITY)
          rehash(node);
        //否则,就把当前node设置为index下标位置新的头结点
        else
          setEntryAt(tab, index, node);
        ++modCount;
        //更新count值
        count = c;
        //这种情况的旧值肯定为空
        oldValue = null;
        break;
      }
    }
  } finally {
    //需要注意ReentrantLock必须手动解锁
    unlock();
  }
  //返回旧值
  return oldValue;
}

Segment中的put方法逻辑:

使用tryLock尝试加锁。因为是在多线程环境下使用,所以要避免多个线程对同一个Segment同时进行put导致更新丢失等等的并发错误。HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value);,如果tryLock成功,那么当前线程成功获取锁,此时node直接赋值为null。如果tryLock失败,说明有其他线程获取了锁,此时调用scanAndLockForPut方法。后面我们会看到,该方法逻辑如其名,会一遍一遍地扫描链表,会进行一些预热的Node生成操作,一直自旋重试,直到重试次数上限才调用lock方法,直接阻塞等待,停止自旋。所以这一句代码可以确保执行后续代码的时候,该线程必定获取了锁。而scanAndLockForPut可能会预生成一个Node,把它赋值给node变量,用于put的后续使用。

②在Map的put方法中已经定位到了具体的Segment,而这里就是在Segment里根据hash值定位到具体的HashEntry位置HashEntry<K,V> first = entryAt(tab, index);。记住这里的HashEntry是一个链表(桶),之后就是遍历链表,如果找到相同key的节点,就进行替换value。如果找不到,那么就新建一个node进行插入。

重点在于遍历链表的逻辑,for(…) { if (not null) … else …},逻辑并不难懂,但网上一些互相cv的博客,一个错就全都错,那些迷之注释我也不确定它们到底有没有看过是什么意思,害得我纠结了很久。实际上就是两种情况:如果匹配到相同的key,那么就是进行value的替换。如果直到链表尾部还是没有找到相同的key,那么就是进行新的key-value的插入。搞清楚整体的逻辑之后,对于这段代码我只有一个疑问:if ((k = e.key) == key || (e.hash == hash && key.equals(k))),为什么判断key相等要写成这么奇怪的形式。答案也很简单,减少equals方法的调用,毕竟调用一个方法肯定还是比直接数值的比较要消耗更多的资源。如果key值相同,那么该逻辑表达式就直接为true,否则就先判断hash,而不是直接调用equals方法。在JDK1.8,这段代码会改写为:if (e.hash == hash && ((k = e.key) == key) || key != null && key.equals(k)),直接先判断hash,如果hash都不相同,那么key肯定就不相同了(重写了hashCode)。如果hash相同,那么就是在同一个链表(桶)里,此时key的值并不一定相等,可能是发生了碰撞,因此还要判断key。先用key的数值去判断,如果返回的是false再用equals去判断。因为key类型不一定是基本类型,如果是String类型,此时(k = e.key) == key就会返回false。总而言之,这些写法都是尽量减少equals方法的调用。

对于1.7方法,采用的是分段锁,同时我们早已定位到了具体的桶位置,所以优先判断key的数值,再判断hash,虽然的确存在不同hash在&操作之后定位到了同一个桶,但概率相对来说比较小。而对于1.8,并没有分段锁的概念,因此优先判断hash来过滤掉大部分不合格的节点。

当然除此之外,put方法还调用了诸如ensureSegment,scanAndLockForPut,entryAt,setEntryAt等方法。当达到扩容阈值的时候还会调用rehash,后面会讲到。

Question:计算Segment数组下标和计算HashEntry数组下标有何不同?

Ans:

计算Segment数组下标: (hash >>> segmentShift) & segmentMask

计算HashEntry数组下标:(tab.length - 1) & hash

Segment使用的是hash的高位和掩码进行与运算,HashEntry直接使用hash和数组长度减1进行与运算。这样做可以避免同时使用低位相同的hash,与运算后的结果容易相同,导致元素扎堆,链表过长的缺点。只要两个hash它不是高位和低位都相同,那么二者计算的下标结果就会不同。(有点类似HashMap的高16位和低16位进行与运算)

ensureSegment方法

Map的put方法,如果对应下标的Segment对象为null,此时会调用ensureSegment方法,初始化一个Segment对象,以确保拿到的的对象一定不为null,然后再调用s.put方法。

//k为 (hash >>> segmentShift) & segmentMask 算法计算出来的值
private Segment<K,V> ensureSegment(int k) {
  final Segment<K,V>[] ss = this.segments;
  //u代表 k 的偏移量,用于通过 UNSAFE 获取主内存最新的实际 K 值
  long u = (k << SSHIFT) + SBASE; // raw offset
  Segment<K,V> seg;
  //从内存中取到最新的下标位置的 Segment 对象,判断是否为空,(1)
  if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
    //之前构造方法里说了,s0是作为一个原型对象,用于创建新的 Segment 对象
    Segment<K,V> proto = ss[0]; // use segment 0 as prototype
    //容量
    int cap = proto.table.length;
    //加载因子
    float lf = proto.loadFactor;
    //扩容阈值
    int threshold = (int)(cap * lf);
    //把 Segment 对应的 HashEntry 数组先创建出来
    HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
    //再次检查 K 下标位置的 Segment 是否为空, (2)
    if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
        == null) { // recheck
      //此处把 Segment 对象创建出来,并赋值给 s,
      Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
      //循环检查 K 下标位置的 Segment 是否为空, (3)
      //若不为空,则说明有其它线程抢先创建成功,并且已经成功同步到主内存中了,
      //则把它取出来,并返回
      while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
             == null) {
        //CAS,若当前下标的Segment对象为空,就把它替换为最新创建出来的 s 对象。
        //若成功,就跳出循环,否则,就一直自旋直到成功,或者 seg 不为空(其他线程成功导致)。
        if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
          break;
      }
    }
  }
  return seg;
}

ensureSegment方法的目标就是创建一个Segment对象,理论上应该很简单,但方法看起来还是写了挺长,原因在于方法里进行了三次判断判断Segment对象是否为null。因为在多线程环境下,不确定什么时候其他线程的CAS操作会成功,有可能发生在上面的任意时刻。所以只要有任意一次检测到Segment对象不为null,说明其他线程已经把Segment对象创建好了,并已经成功地CAS同步到内存了,此时就可以直接返回,无需再重复地创建。

一共有三次判断,第一次是在方法的最开始。如果为null,那么继续创建Segment对象所需要的HashEntry数组对象,创建完之后进行第二次判断。如果依然为null,那么就正式创建Segment对象,传入loadFactor,threshold以及前面创建的HashEntry对象。接着就是进行第三次判断,如果依然为null,说明对象依然没有创建完全并同步到内存,此时继续往后执行CAS操作。最后一步是自旋,等待CAS操作成功,每一次自旋都会判断一次Segment对象是否为null。

scanAndLockForPut方法

Segment里的tryLock失败后,就会调用此方法。

private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
  //根据hash值定位到它对应的HashEntry数组的下标位置,并找到链表的第一个节点
  //注意,这个操作会从主内存中获取到最新的状态,以确保获取到的first是最新值
  HashEntry<K,V> first = entryForHash(this, hash);
  HashEntry<K,V> e = first;
  HashEntry<K,V> node = null;
  //重试次数,初始化为 -1
  int retries = -1; // negative while locating node
  //若抢锁失败,就一直循环,直到成功获取到锁。有三种情况
  while (!tryLock()) {
    HashEntry<K,V> f; // to recheck first below
    //1.若 retries 小于0,
    if (retries < 0) {
      if (e == null) {
        //若 e 节点和 node 都为空,则创建一个 node 节点。这里只是预测性的创建一个node节点
        if (node == null) // speculatively create node
          node = new HashEntry<K,V>(hash, key, value, null);
        retries = 0;
      }
      //如当前遍历到的 e 节点不为空,则判断它的key是否等于传进来的key,若是则把 retries 设为0
      else if (key.equals(e.key))
        retries = 0;
      //否则,继续向后遍历节点
      else
        e = e.next;
    }
    //2.若是重试次数超过了最大尝试次数,则调用lock方法加锁。表明不再重试,我下定决心了一定要获取到锁。
    //要么当前线程可以获取到锁,要么获取不到就去排队等待获取锁。获取成功后,再 break。
    else if (++retries > MAX_SCAN_RETRIES) {
      lock();
      break;
    }
    //3.若 retries 的值为偶数,并且从内存中再次获取到最新的头节点,判断若不等于first
    //则说明有其他线程修改了当前下标位置的头结点,于是需要更新头结点信息。
    else if ((retries & 1) == 0 &&
             (f = entryForHash(this, hash)) != first) {
      //更新头结点信息,并把重试次数重置为 -1,继续下一次循环,从最新的头结点遍历当前链表。
      e = first = f; // re-traverse if entry changed
      retries = -1;
    }
  }
  return node;
}

该方法的核心是自旋,直到tryLock成功,即直到成功获取锁。所以自旋次数达到MAX_SCAN_RETRIES,那么会停止自旋,直接调用lock方法,排队等待获取锁。从整体的逻辑可以看出来,该方法确保返回的时候,当前线程一定是成功获取锁的状态,所以前面put方法在最开始的tryLock就实现了锁的获取,保证线程同步。

自旋里首先判断retries是否小于0,小于0说明此时是第一次自旋,或者因为链表被修改了重新自旋。当进入到这一段逻辑,就是普通的链表遍历。如果e为null,说明链表为空,所以直接插入node节点即可。但因为在while循环体里,说明此时已经发生了并发问题,比如多个线程同时要put到一个空的链表(桶)里。所以此时会把retries置为0,那么就不会继续进入这一段逻辑,而是一直自旋,直到成功获取锁,或者达到最大尝试次数++retries > MAX_SCAN_RETRIES。如果e不为null,那么就继续遍历,找寻key相同的节点。如果存在相同的key,那么后续就是要进行value的更新。如果不存在,那么就继续往后遍历e = e.next。如果遍历到最后还是没有相同的key,那么此时就会执行到e == null处。除了这一段逻辑以及最大尝试次数的判断,还有一个判定条件是:else if ((retries & 1) == 0 && (f = entryForHash(this, hash)) != first),这个主要就是用于实时更新头节点信息。因为put方法使用的是头插法,所以如果头节点更新了,说明其他线程进行了put操作,那么当前线程就要更新头节点信息,并且重新自旋,重新从头开始遍历链表。不然如果另一个线程它新put的key恰好就是我们要找寻的key,此时如果不从头遍历,那么就会出现key冗余的错误。

其实基本的逻辑就是这样,只是第一段逻辑e == null里还有一段比较迷惑的代码:

if (node == null) // speculatively create node
     node = new HashEntry<K,V>(hash, key, value, null);

如果最开始不理解整个方法的逻辑,可以直接先把这段代码删掉,那么就能很清晰地看出来三个if, else if之间的关系和作用了。这里的操作称为预测性地创建节点,后续在put方法里可以用到。实际上这里不创建,put方法里检测到node为null也会自行创建,但这里就“多余”地进行了预测性的创建。为什么?源码里的解释是这样的:Since traversal speed doesn’t matter, we might as well help warm up the associated code and accesses as well.

意思就是,遍历链表的速度并不重要,所以我们可以在可以预先做一些操作(创建node),这样就不用在后续的put方法里再进行创建了。为什么要这样做,我个人的猜测是,这一段代码消耗时间最长的并非遍历链表的时间,而是自旋的时间,所以提升链表遍历的时间意义不大。相反,在这里就预先创建好node,这样能提升put方法的效率,从而节省put方法里获取了锁之后的操作时间。

再提一下最后的判定条件,(retries & 1) == 0是什么意思。可以参考一下此链接的答案:https://stackoverflow.com/questions/25196851/concurrenthashmap-in-jdk7-code-explanation-scanandlockforput

但我觉得并没有给出具体的答案,为什么要偶数次才进行一次检测头节点是否发生了改变。显然,每一次retry都检测肯定是可行的,只是偶数次可以减少检测的次数,提高效率。那么为什么一定要每两次检测一次,为什么不会在中途出错?这个问题,其实我还没搞懂。// TODO!!!

总而言之,scanAndLockForPut方法可以确保当前线程获取到了锁,如果可以的话,会顺便把node也预先创建好。

rehash方法

当 put 方法时,发现元素个数超过了阈值,则会扩容。需要注意的是,每个Segment只管它自己的扩容,互相之间并不影响。换句话说,可以出现这个 Segment的长度为2,另一个Segment的长度为4的情况(都是2的n次幂)。如下代码:(可以先把lastRun相关的代码全部删掉,这样就很好理解了)

//node为创建的新节点
private void rehash(HashEntry<K,V> node) {
  //当前Segment中的旧表
  HashEntry<K,V>[] oldTable = table;
  //旧的容量
  int oldCapacity = oldTable.length;
  //新容量为旧容量的2倍
  int newCapacity = oldCapacity << 1;
  //更新新的阈值
  threshold = (int)(newCapacity * loadFactor);
  //用新的容量创建一个新的 HashEntry 数组
  HashEntry<K,V>[] newTable =
    (HashEntry<K,V>[]) new HashEntry[newCapacity];
  //当前的掩码,用于计算节点在新数组中的下标
  int sizeMask = newCapacity - 1;
  //遍历旧table数组
  for (int i = 0; i < oldCapacity ; i++) {
    HashEntry<K,V> e = oldTable[i];
    // 当前下标的链表/桶不为空
    if (e != null) {
      HashEntry<K,V> next = e.next;
      //计算hash值在新数组中的下标位置
      int idx = e.hash & sizeMask;
      //如果e不为空,且它的下一个节点为空,则说明这条链表只有一个节点,
      //直接把这个节点放到新数组的对应下标位置即可
      if (next == null)   //  Single node on list
        newTable[idx] = e;
      //否则,处理当前链表的节点迁移操作
      else { // Reuse consecutive sequence at same slot
        // lastRun节点表示链表最后映射到同一下标的几个连续节点的第一个
        HashEntry<K,V> lastRun = e;
        // lastRun节点对应的新下标就是lastIdx
        int lastIdx = idx;
        for (HashEntry<K,V> last = next;
             last != null;
             last = last.next) {
          //计算当前遍历到的节点的新下标
          int k = last.hash & sizeMask;
          //若 k 不等于 lastIdx,则说明此次遍历到的节点和上次遍历到的节点不在同一个下标位置
          //需要把 lastRun 和 lastIdx 更新为当前遍历到的节点和下标值。
          //若相同,则不处理,继续下一次 for 循环。
          if (k != lastIdx) {
            lastIdx = k;
            lastRun = last;
          }
        }
        //把和 lastRun 节点的下标位置相同的链表最末尾的几个连续的节点直接放到新数组的对应下标位置
        newTable[lastIdx] = lastRun;
        //再把剩余的节点,复制到新数组
        //从旧数组的头结点开始遍历,直到 lastRun 节点,因为 lastRun节点后边的节点都已经迁移完成了。
        for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
          V v = p.value;
          int h = p.hash;
          int k = h & sizeMask;
          HashEntry<K,V> n = newTable[k];
          //用的是复制节点信息的方式,并不是把原来的节点直接迁移,区别于lastRun处理方式
          newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
        }
      }
    }
  }
  //所有节点都迁移完成之后,再处理传进来的新的node节点,把它头插到对应的下标位置
  int nodeIndex = node.hash & sizeMask; // add the new node
  //头插node节点
  node.setNext(newTable[nodeIndex]);
  newTable[nodeIndex] = node;
  //更新当前Segment的table信息
  table = newTable;
}

看起来代码很长,实际上是因为它进行了优化,添加了lastRun这个概念。如果暂时不考虑lastRun这个节点,那么整体的逻辑实际上就这几行代码:

for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
  V v = p.value;
  int h = p.hash;
  int k = h & sizeMask;
  HashEntry<K,V> n = newTable[k];
  //用的是复制节点信息的方式,并不是把原来的节点直接迁移,区别于lastRun处理方式
  newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
}

方法名虽然叫rehash,但实际上并没有重新计算hash值,只是把table的大小扩大为2倍,与此同时sizeMask也会发生变化,最后改变的只有下标index。而lastRun的含义是:找到链表最后一个新下标连续相同的节点。如图所示:

image-20200902075145192

桶数组里,每一个节点在新数组的下标不一定相同,但最后可能有连续几个是相同的,这时候只要记录这几个连续相同节点的第一个位置,记录为lastRun,此时只需要进行一次赋值操作:newTable[lastIdx] = lastRun;对于lastRun节点,直接赋值一个子链表。而剩余的节点是复制节点信息,从链表头开始,直到lastRun。对于上图的例子,首先会找到lastRun的位置,即倒数第三个元素。然后这时候newTable[lastIdx] = lastRun;就直接一次把lastRun开始,直到链表结尾的元素全部迁移完成了。最后再从头开始,处理前面的元素,即k2,k2,k1。

为什么要这样设计一个lastRun,毫无疑问也是为了优化,如果是普通的实现,直接从头开始一个一个进行newTable[k] = new HashEntry<K,V>(h, p.key, v, n);肯定也是可行的。这里看一下设计者的注释:

 /*
     * Reclassify nodes in each list to new Map.  Because we are
     * using power-of-two expansion, the elements from each bin
     * must either stay at same index, or move to
     * oldCapacity+index. We also eliminate unnecessary node
     * creation by catching cases where old nodes can be reused
     * because their next fields won't change. Statistically, at
     * the default threshhold, only about one-sixth of them need
     * cloning. (The nodes they replace will be garbage
     * collectable as soon as they are no longer referenced by any
     * reader thread that may be in the midst of traversing table
     * right now.)
     */

中文意思就是,因为规定了数组长度是2次幂,所以sizeMask实际上只是增加一个值为1的高位,最后新的下标只有两种情况,要么跟原坐标相等,要么移动到原坐标+旧的数组长度(oldCapacity + index),这实际上在JDK1.8里的HashMap就已经应用到了。lastRun就是为了减少不必要的节点创建,对于那些“next fields won’t change”的节点,即最后那些新坐标都相同的连续节点,直接赋值一次lastRun即可,而不是一个一个地创建。据统计,在默认的阈值下,大约只有1/6的节点需要被克隆。而原本的节点因为没有引用指向,所以很快就会被GC回收,不必担心内存泄漏的问题。当然,在最坏的情况下,lastRun没有任何优化,即lastRun为链表的最后一个元素或者很靠后的元素,但整体上的效率还是提升了,所以就保留了lastRun这个设计。(但这个优化效率有点夸张,为什么能达到这么好的效果,这里我不是很清楚~)

参考链接:

http://gee.cs.oswego.edu/dl/classes/EDU/oswego/cs/dl/util/concurrent/ConcurrentHashMap.java

get方法

get方法的逻辑相比put就简单很多了,先定位到Segment,再定位到HashEntry,over。

public V get(Object key) {
  Segment<K,V> s; // manually integrate access methods to reduce overhead
  HashEntry<K,V>[] tab;
  //计算hash值
  int h = hash(key);
  //同样的先定位到 key 所在的Segment ,然后从主内存中取出最新的节点
  long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
  if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
      (tab = s.table) != null) {
    //若Segment不为空,且链表也不为空,则遍历查找节点
    for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
         (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
         e != null; e = e.next) {
      K k;
      //找到则返回它的 value 值,否则返回 null
      if ((k = e.key) == key || (e.hash == h && key.equals(k)))
        return e.value;
    }
  }
  return null;
}

remove方法

remove方法与put类似,tryLock里调用的是scanAndLock方法,唯一的区别是无需预先创建节点,所以更加简单。

public V remove(Object key) {
  int hash = hash(key);
  //定位到Segment
  Segment<K,V> s = segmentForHash(hash);
  //若 s为空,则返回 null,否则执行 remove
  return s == null ? null : s.remove(key, hash, null);
}

public boolean remove(Object key, Object value) {
  int hash = hash(key);
  Segment<K,V> s;
  return value != null && (s = segmentForHash(hash)) != null &&
    s.remove(key, hash, value) != null;
}

final V remove(Object key, int hash, Object value) {
  //尝试加锁,若失败,则执行 scanAndLock ,此方法和 scanAndLockForPut 方法类似
  if (!tryLock())
    scanAndLock(key, hash);
  V oldValue = null;
  try {
    HashEntry<K,V>[] tab = table;
    int index = (tab.length - 1) & hash;
    //从主内存中获取对应 table 的最新的头结点
    HashEntry<K,V> e = entryAt(tab, index);
    HashEntry<K,V> pred = null;
    while (e != null) {
      K k;
      HashEntry<K,V> next = e.next;
      //匹配到 key
      if ((k = e.key) == key ||
          (e.hash == hash && key.equals(k))) {
        V v = e.value;
        // value 为空,或者 value 也匹配成功
        if (value == null || value == v || value.equals(v)) {
          if (pred == null)
            setEntryAt(tab, index, next);
          else
            pred.setNext(next);
          ++modCount;
          --count;
          oldValue = v;
        }
        break;
      }
      pred = e;
      e = next;
    }
  } finally {
    unlock();
  }
  return oldValue;
}

size方法

在多线程的环境下,统计size的时候可能不同Segment的数组元素也在不断变化,因此size方法和单线程不一样。这里采用的是乐观的做法,不加锁地去测试至多3次,使用一个last参数去记住上一次循环的值。如果在当前循环的时候,获得了sum == last,说明这一次循环和上一次循环中途没有出现并发,因此这个size就是正确的结果,直接返回。否则继续重试,直到Retry次数达到阈值,再给所有Segment加锁,再次统计准确的size。中间还有一个overflow的布尔值,用于记录结果是否int溢出。

public int size() {
  // Try a few times to get accurate count. On failure due to
  // continuous async changes in table, resort to locking.
  //segment数组
  final Segment<K,V>[] segments = this.segments;
  //统计所有Segment中元素的总个数
  int size;
  //如果size大小超过32位,则标记为溢出为true
  boolean overflow; 
  //统计每个Segment中的 modcount 之和
  long sum;         
  //上次记录的 sum 值
  long last = 0L;   
  //重试次数,初始化为 -1
  int retries = -1; 
  try {
    for (;;) {
      //如果超过重试次数,则不再重试,而是把所有Segment都加锁,再统计 size
      if (retries++ == RETRIES_BEFORE_LOCK) {
        for (int j = 0; j < segments.length; ++j)
          //强制加锁
          ensureSegment(j).lock(); // force creation
      }
      sum = 0L;
      size = 0;
      overflow = false;
      //遍历所有Segment
      for (int j = 0; j < segments.length; ++j) {
        Segment<K,V> seg = segmentAt(segments, j);
        //若当前遍历到的Segment不为空,则统计它的 modCount 和 count 元素个数
        if (seg != null) {
          //累加当前Segment的结构修改次数,如put,remove等操作都会影响modCount
          sum += seg.modCount;
          int c = seg.count;
          //若当前Segment的元素个数 c 小于0 或者 size 加上 c 的结果小于0,则认为溢出
          //因为若超过了 int 最大值,就会返回负数
          if (c < 0 || (size += c) < 0)
            overflow = true;
        }
      }
      //当此次尝试,统计的 sum 值和上次统计的值相同,则说明这段时间内,
      //并没有任何一个 Segment 的结构发生改变,就可以返回最后的统计结果
      if (sum == last)
        break;
      //不相等,则说明有 Segment 结构发生了改变,则记录最新的结构变化次数之和 sum,
      //并赋值给 last,用于下次重试的比较。
      last = sum;
    }
  } finally {
    //如果超过了指定重试次数,则说明表中的所有Segment都被加锁了,因此需要把它们都解锁
    if (retries > RETRIES_BEFORE_LOCK) {
      for (int j = 0; j < segments.length; ++j)
        segmentAt(segments, j).unlock();
    }
  }
  //若结果溢出,则返回 int 最大值,否则正常返回 size 值 
  return overflow ? Integer.MAX_VALUE : size;
}

代码if (sum == last) break;,这里的sum和last统计的是modCount,因为只要两次循环的modCount相同,中间肯定就没有任何并发修改导致的数据不一致问题,所以此时可以直接返回size作为最后的结果了。

JDK1.8的ConcurrentHashMap

JDK1.8的ConcurrentHashMap抛弃了分段锁的概念,利用CAS+synchronized来实现线程同步,底层采用和HashMap一样的数据结构,即数组+链表+红黑树。在过去Synchronized一直是重量级锁,但在JDK1.6开始,引入了偏向锁,轻量级锁,锁升级的概念,使得synchronized的效率已经可以媲美Lock锁,甚至是赶超。因此只要在HashMap的基础上使用CAS和synchronized进行改进,就能改造成优秀的多线程版本。

既然synchronized已经得到了优化,为什么Hashtable依然不能效率低下?原因有两个,其一是Hashtable没有红黑树的存在,其二是synchronized直接声明在get和put方法开头,因此即使synchronized效率升级,Hashtable整体的效率依然低下。

数据结构

ConcurrentHashMap有很多属性都与HashMap是相同的,新增加的属性有:

/* ---------------- Fields -------------- */

/**
 * The array of bins. Lazily initialized upon first insertion.
 * Size is always a power of two. Accessed directly by iterators.
 */
transient volatile Node<K,V>[] table;

/**
 * The next table to use; non-null only while resizing.
 */
private transient volatile Node<K,V>[] nextTable;
/**
 * Table initialization and resizing control.  When negative, the
 * table is being initialized or resized: -1 for initialization,
 * else -(1 + the number of active resizing threads).  Otherwise,
 * when table is null, holds the initial table size to use upon
 * creation, or 0 for default. After initialization, holds the
 * next element count value upon which to resize the table.
 */
private transient volatile int sizeCtl;
/**
 * The next table index (plus one) to split while resizing.
 */
private transient volatile int transferIndex;

/**
 * Spinlock (locked via CAS) used when resizing and/or creating CounterCells.
 */
private transient volatile int cellsBusy;

/**
 * Table of counter cells. When non-null, size is a power of 2.
 */
private transient volatile CounterCell[] counterCells;

因为是在多线程环境下使用,所以这些变量都用了volatile修饰,CAS+volatile实现。在HashMap里是没有nextTable这个变量的,因为默认在单线程环境下使用,于是新扩容的table直接在resize里创建。但多线程环境下,要确保多个线程同时resize时,只有一个成功创建并同步到内存里,所以nextTable也要设置为全局的volatile变量。

sizeCtl是一个状态量,用于控制table的初始化和扩容。不同的值有不同的含义,-1表示正在初始化,-N表示有N - 1个活跃的线程正在进行resize。0表示table还没有初始化,初始化完成后,sizeCtl会是一个正数。当数组为null时,sizeCtl代表数组初始化大小,当数组不为null时,代表数组的扩容阈值。后面三个参数会在put方法里具体介绍。

put方法

public V put(K key, V value) {
  return putVal(key, value, false);
}

final V putVal(K key, V value, boolean onlyIfAbsent) {
  //可以看到,在并发情况下,key 和 value 都是不支持为空的。
  if (key == null || value == null) throw new NullPointerException();
  //这里和1.8 HashMap 的hash 方法大同小异,只是多了一个操作,如下
  //( h ^ (h >>> 16)) & HASH_BITS;  HASH_BITS = 0x7fffffff;
  // 0x7fffffff ,二进制为 0111 1111 1111 1111 1111 1111 1111 1111 。
  //所以,hash值除了做了高低位异或运算,还多了一步,保证最高位的 1 个 bit 位总是0。
  //这里,我并没有明白它的意图,仅仅是保证计算出来的hash值不超过 Integer 最大值,且不为负数吗。
  //同 HashMap 的hash 方法对比一下,会发现连源码注释都是相同的,并没有多说明其它的。
  //我个人认为意义不大,因为最后 hash 是为了和 capacity -1 做与运算,而 capacity 最大值为 1<<30,
  //即 0100 0000 0000 0000 0000 0000 0000 0000 ,减1为 0011 1111 1111 1111 1111 1111 1111 1111。
  //即使 hash 最高位为 1(无所谓0),也不影响最后的结果,最高位也总会是0.
  int hash = spread(key.hashCode());
  //用来计算当前链表上的元素个数,
  int binCount = 0;
  for (Node<K,V>[] tab = table;;) {
    Node<K,V> f; int n, i, fh;
    //如果表为空,则说明还未初始化。
    if (tab == null || (n = tab.length) == 0)
      //初始化表,只有一个线程可以初始化成功。
      tab = initTable();
    //若表已经初始化,则找到当前 key 所在的桶,并且判断是否为空
    else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
      //若当前桶为空,则通过 CAS 原子操作,把新节点插入到此位置,
      //这保证了只有一个线程可以 CAS 成功,其它线程都会失败。
      if (casTabAt(tab, i, null, new Node<K,V>(hash, key, value, null)))
        break;                   // no lock when adding to empty bin
    }
    //若所在桶不为空,则判断节点的 hash 值是否为 MOVED(值是-1)
    else if ((fh = f.hash) == MOVED)
      //若为-1,说明当前数组正在进行扩容,则需要当前线程帮忙迁移数据
      tab = helpTransfer(tab, f);
    else {
      V oldVal = null;
      //这里用加同步锁的方式,来保证线程安全,给桶中第一个节点对象加锁
      synchronized (f) {
        //recheck 一下,保证当前桶的第一个节点无变化,后边很多这样类似的操作,不再赘述
        if (tabAt(tab, i) == f) {
          //如果hash值大于等于0,说明是正常的链表结构
          if (fh >= 0) {		// fh (first hash value)
            binCount = 1;		// binCount记录链表的元素个数
            //从头结点开始遍历,每遍历一次,binCount计数加1
            for (Node<K,V> e = f;; ++binCount) {
              K ek;
              //如果找到了和当前 key 相同的节点,则用新值替换旧值
              if (e.hash == hash && ((ek = e.key) == key || (ek != null && key.equals(ek)))) {
                oldVal = e.val;
                if (!onlyIfAbsent)
                  e.val = value;
                break;
              }
              Node<K,V> pred = e;
              //若遍历到了尾结点,则把新节点尾插进去
              if ((e = e.next) == null) {
                pred.next = new Node<K,V>(hash, key,
                                          value, null);
                break;
              }
            }
          }
          //否则判断是否是树节点。这里提一下,TreeBin只是头结点对TreeNode的再封装
          else if (f instanceof TreeBin) {
            Node<K,V> p;
            binCount = 2;
            if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,
                                                  value)) != null) {
              oldVal = p.val;
              if (!onlyIfAbsent)
                p.val = value;
            }
          }
        }
      }
      //注意下,这个判断是在同步锁外部,因为 treeifyBin内部也有同步锁,无需担心线程同步的问题
      if (binCount != 0) {
        //如果节点个数大于等于 8,则转化为红黑树
        if (binCount >= TREEIFY_THRESHOLD)
          treeifyBin(tab, i);
        //把旧节点值返回
        if (oldVal != null)
          return oldVal;
        break;
      }
    }
  }
  //给元素个数加 1,并有可能会触发扩容,比较复杂,稍后细讲,记住第二个参数是链表/树的元素个数
  addCount(1L, binCount);
  return null;
}

逐行来看:

第一行:if (key == null || value == null) throw new NullPointerException();

这也印证了前面的内容,ConcurrentHashMap不允许key或value为null值

第二行:int hash = spread(key.hashCode());

与HashMap是基本一致的,只是spread方法稍微有点不一样,如下:

static final int spread(int h) {
  return (h ^ (h >>> 16)) & HASH_BITS;
}

HASH_BITS的值为0x7fffffff,所以这个与操作只会导致一个结果:返回的hash值第一位必定为0,作用是使得hash值必定不会超过Integer的最大值,并且不为负数。看起来没什么意义,因为后续hash都是要和capacity - 1进行与操作的,而capacity的最大值为1 << 30。但后续我们会看到,ConcurrentHashMap里设置了三个特殊的负数hash值,分别代表不同的含义,因此普通的节点就设置为非负数了。

第三行:for (Node<K,V>[] tab = table;;)

注意,这里根本就不是table的遍历,不要看到for循环就习惯性地觉得,啊这里是对table的遍历。table是一个哈希数组,这里只是一个赋值+无限循环,我们后续会根据hash值定位到具体的下标,也就是具体的Node<K, V>。要明确这里的含义,可以和HashMap的put方法进行对比一下,就清晰了。

第四行:

if (tab == null || (n = tab.length) == 0)
  //初始化表,只有一个线程可以初始化成功。
  tab = initTable();

与HashMap同理,在构造方法里并不会初始化table,而是在第一次put操作的时候,检测到table为null,才会进行对table的初始化(懒汉式加载)。initTable方法先放着,后面再看它的代码。

第五行:

//若表已经初始化,则找到当前 key 所在的桶,并且判断是否为空
else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
  //若当前桶为空,则通过 CAS 原子操作,把新节点插入到此位置,
  //这保证了只有一个线程可以 CAS 成功,其它线程都会失败。
  if (casTabAt(tab, i, null, new Node<K,V>(hash, key, value, null)))
    break;                   // no lock when adding to empty bin
}

代码(f = tabAt(tab, i = (n - 1) & hash)),这里就是根据hash值定位到具体的bucket桶中。这时候就是要把数据put到该桶中,首先判断该桶是否为空,如果空,那么就进行CAS自旋尝试put操作。只要put成功,那么此处就执行break跳出for循环,表明put操作完成。同时CAS也保证了只有一个线程可以CAS成功,避免因为线程同步导致的更新丢失问题。

第六行:

//若所在桶不为空,则判断节点的 hash 值是否为 MOVED(值是-1)
else if ((fh = f.hash) == MOVED)
  //若为-1,说明当前数组正在进行扩容,则需要当前线程帮忙迁移数据
  tab = helpTransfer(tab, f);

执行到这里,说明表已经初始化,而且bucket也不为null,但此时并不是直接开始遍历该bucket/Node/链表,而是要先判断table是否正在进行扩容。如果是,那么调用helpTransfer方法帮助扩容。避免其他线程的put操作引发了resize,此时当前线程也应该去帮忙扩容,继续CAS自旋进行put操作,只会一直失败(其他线程resize会改变头节点,index也会发生改变,从而put操作无法执行完成,或者执行的结果会是错误的)

第七行:

else {
  V oldVal = null;
  // 这里用加同步锁的方式,来保证线程安全,给桶中第一个节点对象加锁
  synchronized (f) {
    // recheck 一下,保证当前桶的第一个节点无变化,后边很多这样类似的操作,不再赘述
    if (tabAt(tab, i) == f) {
      //如果hash值大于等于0,说明是正常的链表结构
      if (fh >= 0) {
        binCount = 1;
        //从头结点开始遍历,每遍历一次,binCount计数加1
        for (Node<K,V> e = f;; ++binCount) {
          K ek;
          //如果找到了和当前 key 相同的节点,则用新值替换旧值
          if (e.hash == hash && ((ek = e.key) == key || (ek != null && key.equals(ek)))) {
            oldVal = e.val;
            if (!onlyIfAbsent)
              e.val = value;
            break;
          }
          Node<K,V> pred = e;
          //若遍历到了尾结点,则把新节点尾插进去
          if ((e = e.next) == null) {
            pred.next = new Node<K,V>(hash, key, value, null);
            break;
          }
        }
      }
      else if (f instanceof TreeBin) {
        // ...
      }
    }
  }

执行到这里,说明table已经初始化,bucket也不会null,也没有处于resize状态,那么就跟HashMap的时候一样,遍历链表(bucket)。如果找到了相同的key节点,则替换掉旧值,返回旧值oldValue。如果遍历到了链表尾部也没有找到相同的key节点,那么就直接把新节点用尾插法进行插入,此时oldValue为null。但在这之前,首先要进行一次recheck,查看bucket是否发生了变化,如果发生了变化,那么此次synchronized加锁就没有意义,因为当前的数据是dirty data。所以会直接跳出synchronized块,然后再次进行一次for循环,获取到了新的f值再尝试用synchronized加锁。当我们确定了bucket没有在中途被修改后,还要先判断一下当前的节点是链表节点还是树节点。

这里我纠结了很久,为什么fh >= 0就是链表了?我想到了spread方法里把第一位置为0,使得hash必定为非负数的操作。那么fh小于0说明是什么呢?如下:

/*
 * Encodings for Node hash fields. See above for explanation.
 */
static final int MOVED     = -1; // hash for forwarding nodes
static final int TREEBIN   = -2; // hash for roots of trees
static final int RESERVED  = -3; // hash for transient reservations
static final int HASH_BITS = 0x7fffffff; // usable bits of normal node hash

我们前面已经看到了MOVED表示要扩容,而-2表示红黑树的根节点,-3表示transient reservations(不知道是啥),所以当fh小于0的时候,它还不一定是红黑树节点,所以仍然要再使用instanceof来判断是否为TreeBin。红黑树的部分依然先pass,后面再补。。。

第八行:

//注意下,这个判断是在同步锁外部,因为 treeifyBin内部也有同步锁,并不影响
if (binCount != 0) {
  //如果节点个数大于等于 8,则转化为红黑树
  if (binCount >= TREEIFY_THRESHOLD)
    treeifyBin(tab, i);
  //把旧节点值返回
  if (oldVal != null)
    return oldVal;
  break;
}

如果binCount为0,说明没有进入synchronized体里,继续自旋。如果不为0,说明已经执行了synchronized内的代码,判断是否需要树化,再判断oldVal是否为null。如果oldVal不为null,说明在bucket里找到了相同的key并且进行了value的替代,此时直接返回oldVal。否则用break跳出for循环,执行到最后一行代码。

第九行:

//给元素个数加 1,并有可能会触发扩容,比较复杂,稍后细讲
addCount(1L, binCount);
return null;

最后就是调用addCount,使得元素的个数加1。在单线程环境下很简单,但在多线程下挺复杂的,后面再讲。

总而言之,这就是put的基本流程:

①如果table还没有初始化,就先调用initTable进行初始化

②如果table已经初始化,定位到相应的Node,查看Node是否为null。如果Node为null,那么就不存在hash冲突,直接CAS自旋。

③如果Node不为null,那么要先根据桶头节点f的hash值判断table是否处于扩容状态,如果是,调用helpTransfer

④如果没有处于扩容状态,根据头节点f的hash值,区分到底是链表节点还是树节点,然后进行遍历。遍历过程中如果找到相同key的节点,则进行value的更新即可。如果到达尾部还没有找到,进行尾插法。同时根据oldValue是否为null,决定是否要修改元素的个数,即是否会到达addCount方法。

initTable方法

private final Node<K,V>[] initTable() {
  Node<K,V>[] tab; int sc;
  //循环判断表是否为空,直到初始化成功为止。
  while ((tab = table) == null || tab.length == 0) {
    //sizeCtl 这个值有很多情况,默认值为0,
    //当为 -1 时,说明有其它线程正在对表进行初始化操作
    //当表初始化成功后,又会把它设置为扩容阈值
    //当为一个小于 -1 的负数,用来表示当前有几个线程正在帮助扩容(后边细讲)
    if ((sc = sizeCtl) < 0)
      //若 sc 小于0,其实在这里就是-1,因为此时表是空的,不会发生扩容
      //因此,当前线程放弃 CPU 时间片,只是自旋。
      Thread.yield(); // lost initialization race; just spin
    //通过 CAS 把 SIZECTL 的值设置为-1,表明当前线程正在进行表的初始化,其它失败的线程就会自旋
    else if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
      try {
        //重新检查一下表是否为空
        if ((tab = table) == null || tab.length == 0) {
          //如果sc大于0,则为sc,否则返回默认容量 16。
          //当调用有参构造创建 Map 时,sc的值是大于0的。
          int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
          @SuppressWarnings("unchecked")
          //创建数组
          Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
          table = tab = nt;
          //n减去 1/4 n ,即为 0.75n ,表示扩容阈值
          sc = n - (n >>> 2);
        }
      } finally {
        //更新 sizeCtl 为扩容阈值
        sizeCtl = sc;
      }
      //若当前线程初始化表成功,则跳出循环。其它自旋的线程因为判断数组不为空,也会停止自旋
      break;
    }
  }
  return tab;
}

initTable的逻辑并不难,主要就是通过sizeCtl来控制哪个线程进行初始化。各个线程进行CAS操作,如果成功执行完CAS,那么此时sizeCtl的值就从0置为-1,其他线程就会自旋(一直在while和yield)。而对于成功执行完CAS的线程,则会进入try语句块。首先还是要进行一次recheck,因为可能其他线程刚创建完,并且已经CAS执行完更新,这时候当前线程才刚刚把SIZECTL修改,进入try块,如果重复创建就浪费资源了。recheck之后如果table依然为null,那么就进行创建:Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];。这里要提一下Unsafe里的compareAndSwapInt方法,它的定义如下:

/**
  * Atomically update Java variable to <tt>x</tt> if it is currently
  * holding <tt>expected</tt>.
  * @return <tt>true</tt> if successful
  */
public final native boolean compareAndSwapInt(Object o, long offset,
                                              int expected,
                                              int x);

如果offset的值为expected,那么把x原子性地赋值给offset,返回true,反之赋值失败,返回false。

所以对于这一行代码else if (U.compareAndSwapInt(this, SIZECTL, sc, -1)),只有一个线程能成功获取到乐观锁,将SIZECTL赋值为-1并进入try块,而其他线程会因为SIZECTL已经被赋值为1,一直yield,也就是自旋。对于进入了try块的线程,此时sc并不为-1,如果当使用有参构造器创建Map的时候,此时sc代表initialCapacity,sc最后会记录table的扩容阈值:sc = n - (n >>> 2);。最后再更新sizeCtl的值,注意这个sizeCtl是Map里的一个volatile变量。之后其他线程会结束自旋,进入try块,但在double-check里发现Map已经构造完成,因此不会重复创建。

(initTable的CAS操作,关键不要混淆sizeCtl和sc,我一开始搞混了,还在纳闷为什么sc会是正数,不是应该稳定-1吗)

addCount方法

addCount的目标很明确,就是要改变整个Map的元素个数,即一次put,使得size加一。对于HashMap来说,单线程环境下直接修改size变量即可。而对于ConcurrentHashMap的多线程环境下,很容易想到,把size修改为volatile,保证可见性,然后使用CAS进行自增。这确实可行,但JDK里并不是这么实现的。

当有多个线程进行了put,此时size变量就会造成很严重的竞争,直接volatile+CAS的效率很低。如果不考虑乐观锁,直接转为悲观锁,那么效率就更低了。JDK里是把每一个竞争的线程分散到不同的对象里,在该对象里单独计算每一个线程的size变化,最后要统计size的时候再一起相加。这个思想有点像1.7分段锁时的做法,但1.8已经放弃了分段锁,所以这里的对象(CounterCell)仅仅是用作计算size。如下:

//线程被分配到的格子
@sun.misc.Contended static final class CounterCell {
  //此格子内记录的 value 值
  volatile long value;
  CounterCell(long x) { value = x; }
}

//用来存储线程和线程生成的随机数的对应关系
static final int getProbe() {
  return UNSAFE.getInt(Thread.currentThread(), PROBE);
}

CounterCell的定义仅仅只有一个volatile变量,只为了计数的小单元(Cell)。同时这个类添加了注解:@sun.misc.Contended,这是一个避免伪造共享的注解,用于替代以前的缓存行填充,在多线程环境下可以提高性能。

而getProbe是给当前线程生成一个随机数,可以简单地理解为生成了一个hash值,后续要用来和数组长度取模,计算它所在CounterCells数组的下标位置。

在addCount方法里,baseCount就是size,而CounterCell,cellsBusy,cellValue等都是辅助变量。

// x为1,check代表链表上的元素个数
private final void addCount(long x, int check) {
 CounterCell[] as; long b, s;
 //此处要进入if有两种情况
 //1.数组不为空,说明数组已经被创建好了。
 //2.若数组为空,说明数组还未创建,很有可能竞争的线程非常少,因此就直接 CAS 操作 baseCount
 //若 CAS 成功,则方法跳转到 (2)处,若失败,则需要考虑给当前线程分配一个格子(指CounterCell对象)
 if ((as = counterCells) != null ||
  !U.compareAndSwapLong(this, BASECOUNT, b = baseCount, s = b + x)) {
  CounterCell a; long v; int m;
  //字面意思,是无竞争,这里先标记为 true,表示还没有产生线程竞争
  boolean uncontended = true;
  //这里有三种情况,会进入 fullAddCount 方法
  //1.若数组为空,进方法 (1)
  //2.ThreadLocalRandom.getProbe() 方法会给当前线程生成一个随机数(可以简单的认为也是一个hash值)
  //然后用随机数与数组长度取模,计算它所在的格子。若当前线程所分配到的格子为空,进方法 (1)。
  //3.若数组不为空,且线程所在格子不为空,则尝试 CAS 修改此格子对应的 value 值加1。
  //若修改成功,则跳转到 (3),若失败,则把 uncontended 值设为 fasle,说明产生了竞争,然后进方法 (1)
  if (as == null || (m = as.length - 1) < 0 ||
   (a = as[ThreadLocalRandom.getProbe() & m]) == null ||
   !(uncontended =
     U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x))) {
   //方法(1), 这个方法的目的是让当前线程一定把 1 加成功。情况更多,更复杂,稍后讲。
   fullAddCount(x, uncontended);
   return;
  }
  // (3)能走到这,说明数组不为空,且修改 baseCount失败,
  // 且线程被分配到的格子不为空,且修改 value 成功。
  // check参数是在putVal里的binCount
  // binCount == 1说明是链表,且替换了head节点的val值
  // 或者是数组对应下标的链表为空,然后添加了新的head节点
  // 无论是哪一种,都不需要扩容
  if (check <= 1)
   return;
  //计算总共的元素个数
  s = sumCount();
 }
 //(2)这里用于检查是否需要扩容(先跳过,先看后面的transfer方法
 if (check >= 0) {
  Node<K,V>[] tab, nt; int n, sc;
  //若元素个数达到扩容阈值(即map添加的节点数大于等于sizeCtl
  //且tab不为空,且tab数组长度小于最大容量
  while (s >= (long)(sc = sizeCtl) && (tab = table) != null &&
      (n = tab.length) < MAXIMUM_CAPACITY) {
   //这里假设数组长度n就为16,这个方法返回的是一个固定值rs,用于当做一个扩容的校验标识
   int rs = resizeStamp(n);
   //若sc小于0,说明正在扩容
   if (sc < 0) {
		// 此处有bug,我们只需要知道:当前桶数组正在扩容,但当前线程无需帮助扩容
    if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
     sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
     transferIndex <= 0)
     break;
    //到这里说明当前线程可以帮助扩容,因此sc值加一,代表扩容的线程数加1
    if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
     transfer(tab, nt);
   }
   // 当sc大于0,说明sc代表扩容阈值,因此第一次扩容之前肯定走这个分支,用于初始化新表 nextTable
   // 此时会把sc赋值为(rs << RESIZE_STAMP_SHIFT) + 2,是一个标记值,表示首个帮助扩容的线程
   else if (U.compareAndSwapInt(this, SIZECTL, sc,
           (rs << RESIZE_STAMP_SHIFT) + 2))
    //扩容,第二个参数代表新表,传入null,则说明是第一次初始化新表(nextTable)
    transfer(tab, null);
   s = sumCount();
  }
 }
}

//计算表中的元素总个数
final long sumCount() {
 CounterCell[] as = counterCells; CounterCell a;
 //baseCount,以这个值作为累加基准
 long sum = baseCount;
 if (as != null) {
  //遍历 counterCells 数组,得到每个对象中的value值
  for (int i = 0; i < as.length; ++i) {
   if ((a = as[i]) != null)
    //累加 value 值
    sum += a.value;
  }
 }
 //此时得到的就是元素总个数
 return sum;
} 

//扩容时的校验标识(先跳过
static final int resizeStamp(int n) {
 return Integer.numberOfLeadingZeros(n) | (1 << (RESIZE_STAMP_BITS - 1));
}

addCount方法判断具体的情况,然后做出相应的应对方法。

①:当数组为null,此时直接进入第一个if体。如果数组不为null,此时数组还没有创建,那么很可能竞争的线程比较少,因此直接CAS尝试操作baseCount。如果CAS成功,那么说明baseCount已经成功自增,接下来就是跳转到检查是否扩容的部分。如果CAS失败,那么也进入第一个if体。第一个if体里依然会分不同的情况做出不同的选择。

②:进入第一个if体之后,如果CounterCells数组为null,说明在①的时候CAS执行失败,即遇到了并发问题,此时会进入fullAddCount方法。如果CounterCells数组不为null,但相应下标的对象为null(a = as[ThreadLocalRandom.getProbe() & m]),此时也会进入fullAddCount方法。如果相应下标的对象不为null,就尝试CAS修改此格子对应的value,如果修改成功,那么跳转到检查是否扩容的部分。如果CAS失败,同样是调用fullAddCount方法。直到这里可以看到,①和②主要是对于并发情况可能不太严重的时候直接进行一次CAS,即使有并发问题,但只要情况不算严重,那么CAS的成本并不算高,可如果CAS成功了就能节省一大笔资源。如果CAS失败,会跳转到fullAddCount方法,此方法就是使用CounterCells数组,单独计算每一个线程对应格子的值,最后再进行相加得到size(baseCount)。

③:扩容部分我们先跳过,查看fullAddCount方法。

fullAddCount方法

fullAddCount方法名的意思,要全力增加计算值,一定要成功。可能有部分已经在前面CAS成功了,但我们要确保全部线程都已经修改为正确的值,因此要对剩下这部分进入了fullAddCount方法对线程单独地修改它的值。

//传过来的参数分别为 1 , false
private final void fullAddCount(long x, boolean wasUncontended) {
  int h;
  //如果当前线程的随机数为0,则强制初始化一个值
  if ((h = ThreadLocalRandom.getProbe()) == 0) {
    ThreadLocalRandom.localInit();      // force initialization
    h = ThreadLocalRandom.getProbe();
    //此时把 wasUncontended 设为true,认为无竞争
    wasUncontended = true;
  }
  //用来表示比 contend(竞争)更严重的碰撞,若为true,表示可能需要扩容,以减少碰撞冲突
  boolean collide = false;                // True if last slot nonempty
  //循环内,外层if判断分三种情况,内层判断又分为六种情况
  for (;;) {
    CounterCell[] as; CounterCell a; int n; long v;
    //1. 若counterCells数组不为空。  建议先看下边的2和3两种情况,再回头看这个。 
    if ((as = counterCells) != null && (n = as.length) > 0) {
      // (1) 若当前线程所在的格子(CounterCell对象)为空
      if ((a = as[(n - 1) & h]) == null) {
        if (cellsBusy == 0) {    
          //若无锁,则乐观的创建一个 CounterCell 对象。
          CounterCell r = new CounterCell(x); 
          //尝试加锁
          if (cellsBusy == 0 &&
              U.compareAndSwapInt(this, CELLSBUSY, 0, 1)) {
            boolean created = false;
            //加锁成功后,再 recheck 一下数组是否不为空,且当前格子为空
            try {               
              CounterCell[] rs; int m, j;
              if ((rs = counterCells) != null &&
                  (m = rs.length) > 0 &&
                  rs[j = (m - 1) & h] == null) {
                //把新创建的对象赋值给当前格子
                rs[j] = r;
                created = true;
              }
            } finally {
              //手动释放锁
              cellsBusy = 0;
            }
            //若当前格子创建成功,且上边的赋值成功,则说明加1成功,退出循环
            if (created)
              break;
            //否则,继续下次循环
            continue;           // Slot is now non-empty
          }
        }
        //若cellsBusy=1,说明有其它线程抢锁成功。或者若抢锁的 CAS 操作失败,都会走到这里,
        //则当前线程需跳转到(9)重新生成随机数,进行下次循环判断。
        collide = false;
      }
      /**
   *后边这几种情况,都是数组和当前随机到的格子都不为空的情况。
   *且注意每种情况,若执行成功,且不break,continue,则都会执行(9),重新生成随机数,进入下次循环判断
   */
      // (2) 到这,说明当前方法在被调用之前已经 CAS 失败过一次,若不明白可回头看下 addCount 方法,
      //为了减少竞争,则跳转到⑨处重新生成随机数,并把 wasUncontended 设置为true ,认为下一次不会产生竞争
      else if (!wasUncontended)       // CAS already known to fail
        wasUncontended = true;      // Continue after rehash
      // (3) 若 wasUncontended 为 true 无竞争,则尝试一次 CAS。若成功,则结束循环,若失败则判断后边的 (4)(5)(6)。
      else if (U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x))
        break;
      // (4) 结合 (6) 一起看,(4)(5)(6)都是 wasUncontended=true,且CAS修改value失败的情况。
      //若数组有变化,或者数组长度大于等于当前CPU的核心数,则把 collide 改为 false
      //因为数组若有变化,说明是由扩容引起的;长度超限,则说明已经无法扩容,只能认为无碰撞。
      //这里很有意思,认真思考一下,当扩容超限后,则会达到一个平衡,即 (4)(5) 反复执行,直到 (3) 中CAS成功,跳出循环。
      else if (counterCells != as || n >= NCPU)
        collide = false;            // At max size or stale
      // (5) 若数组无变化,且数组长度小于CPU核心数时,且 collide 为 false,就把它改为 true,说明下次循环可能需要扩容
      else if (!collide)
        collide = true;
      // (6) 若数组无变化,且数组长度小于CPU核心数时,且 collide 为 true,说明冲突比较严重,需要扩容了。
      else if (cellsBusy == 0 &&
               U.compareAndSwapInt(this, CELLSBUSY, 0, 1)) {
        try {
          //recheck
          if (counterCells == as) {// Expand table unless stale
            //创建一个容量为原来两倍的数组
            CounterCell[] rs = new CounterCell[n << 1];
            //转移旧数组的值
            for (int i = 0; i < n; ++i)
              rs[i] = as[i];
            //更新数组
            counterCells = rs;
          }
        } finally {
          cellsBusy = 0;
        }
        //认为扩容后,下次不会产生冲突了,和(4)处逻辑照应
        collide = false;
        //当次扩容后,就不需要重新生成随机数了
        continue;                   // Retry with expanded table
      }
      // (9),重新生成一个随机数,进行下一次循环判断
      h = ThreadLocalRandom.advanceProbe(h);
    }
    //2.这里的 cellsBusy 参数非常有意思,是一个volatile的 int值,用来表示自旋锁的标志,
    //可以类比 AQS 中的 state 参数,用来控制锁之间的竞争,并且是独占模式。简化版的AQS。
    //cellsBusy 若为0,说明无锁,线程都可以抢锁,若为1,表示已经有线程拿到了锁,则其它线程不能抢锁。
    else if (cellsBusy == 0 && counterCells == as &&
             U.compareAndSwapInt(this, CELLSBUSY, 0, 1)) {
      boolean init = false;
      try {    
        //这里再重新检测下 counterCells 数组引用是否有变化
        if (counterCells == as) {
          //初始化一个长度为 2 的CounterCell数组
          CounterCell[] rs = new CounterCell[2];
          //根据当前线程的随机数值,计算下标,只有两个结果 0 或 1,并初始化对象
          rs[h & 1] = new CounterCell(x);
          //更新数组引用
          counterCells = rs;
          //初始化成功的标志
          init = true;
        }
      } finally {
        //别忘了,需要手动解锁。
        cellsBusy = 0;
      }
      //若初始化成功,则说明当前加1的操作也已经完成了,则退出整个循环。
      if (init)
        break;
    }
    //3.到这,说明数组为空,且 2 抢锁失败,则尝试直接去修改 baseCount 的值,
    //若成功,也说明加1操作成功,则退出循环。
    else if (U.compareAndSwapLong(this, BASECOUNT, v = baseCount, v + x))
      break;                          // Fall back on using base
  }
}

首先开头判断线程getProbe返回的随机数是否为0,如果是0,那么要调用localInit再一次进行初始化。因为在ThreadLocalRandom类中,0代表特殊的含义,表示未初始化

之后进入循环体for(;;)里,根据不同情况做出不同选择。主要用到了4个变量:

cellsBusy:一个volatile的int值,用于表示自旋锁的标志,可以类比为AQS中的state参数,用来控制锁之间的竞争,并且是独占模式,简化版的AQS。cellsBusy如果为0,表示无锁,线程可以进行抢占。如果为1,表示已经有线程拿到了锁,其他线程进入自旋。

cellValue:cellValue顾名思义,就是CounterCell里的值,所以fullAddCount方法最终需要做的就是修改cellValue的值。

wasUncontended:一个boolean值,表示是否发生了竞争,默认是true,表示无竞争。竞争是指CounterCell对象为null,但CAS创建的时候发生冲突

collide:一个boolean值,表示是否发生了碰撞,情况比contend要严重。如果为true,表示可能需要扩容,以减少碰撞冲突。碰撞时指CounterCell对象不为null,但CAS修改时发生冲突。需要区分wasUncontended和collide。

第一个if体里的6种情况:

①如果CounterCells数组不为null,并且线程随机到的下标格子(CounterCell对象)为null。

此时直接创建了一个CounterCell对象,然后尝试加锁,把新创建的CounterCell对象赋值到对应的格子。这里是先创建,再尝试加锁,是一种比较乐观的状态。如果加锁成功,那么格子赋值成功,已经到达了+1的效果,此时可以直接break退出循环了(原本是0,创建CounterCell的时候传入int值,表示要修改的值即可,一般是1),最后再在finally里释放锁即可。如果加锁失败,那么会回到循环的开始,但值得注意的是,collide依然会赋值为false,因为CounterCell对象为null,说明不存在冲突问题,此时只是CAS自旋而已,collide状态量的官方注释是:True if last slot nonempty

后面的几种情况,都是CounterCells数组不为null,且线程随机到的下标格子也不为null

else if (!wasUncontended),表明当前方法在被调用之前已经CAS失败过一次。为了减少竞争,此时会跳转到最后重新生成随机数,并且把wasUncontended设置为true(都已经非空了,wasUncontended是CAS创建时冲突,因此wasUncontended已经不会再是限制)

else if (U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x)),执行到这里,说明wasUncontended为true,无竞争,那么会尝试一次CAS(执行完②之后,①跟②的情况都不会再出现)。如果成功,break结束循环,否则继续后面的④⑤⑥几种情况,④⑤⑥都表示wasUncontended为true,且在③中都CAS修改失败的情况。

else if (counterCells != as || n >= NCPU),如果数组发生了变化,或者数组的长度大于等于当前CPU的核心数,则把collide改为false。数组发生了变化,说明此时的数据是stale的,那么就直接把状态量reset之后重试。如果数组长度大于等于CPU核心数,那么此时的数组长度已经到达了最大值,就不能继续扩容了。无论是哪种情况,collide都应该设置为false,表示无冲突。一种是数据已经改变,无法判断是否冲突,另一种是数据已经没有办法再扩容了,而collide就是用于扩容的,只能设置为false禁止继续扩容。

else if (!collide),如果数组无变化,且数组长度小于CPU核心数时,且collide为false,把它改为true。和④完全相反,此时表明CounterCells的数据没有发生变化,是实时的可用数据,并且n小于NCPU,但③的CAS依然失败,此时就是发生了碰撞。从②开始,CounterCell对象就不为null,并且CAS失败,而且也不是情况④中的特殊情况,因此此时发生碰撞就直接修改collide值,以求扩容

else if (cellsBusy == 0 && U.compareAndSwapInt(this, CELLSBUSY, 0, 1)),如果数组没有变化,且数组长度小于CPU核心数,且collide为true,说明冲突比较严重,需要扩容了。

④⑤⑥是联合起来的三种情况。如果发生碰撞,那么就调用⑥来扩容。如果已经扩容到达上限仍然一直碰撞,那只能④和⑤反复执行,直到③的CAS成功,跳出循环。还有一些细节是④和⑤无需continue,此时它会执行到h = ThreadLocalRandom.advanceProbe(h);,重新生成一个随机数,而情况⑥扩容之后,就不需要重新生成随机数了。

第一个if体的6种情况的逻辑大概就是这样,看起来很复杂其实也是循序渐进。情况①先处理CounterCell对象为null的情况,那么直接CAS创建。情况②表示当前方法已经用CAS失败过一次(要么是addCount方法里,要么是情况①时候的失败),但此时终究已经不符合①,因此认为不会再发生竞争(不会再有CAS创建的冲突),但会发生碰撞。情况③就是CAS自旋,尝试修改cellValue。情况④⑤⑥就是考虑发生碰撞后是否需要扩容。(记住wasUncontended和collide的区别

至此,第一个if体的6种情况已经整理完毕,接下来是剩下的两个else if的情况。这两个相比而言就简单很多,因为前面都是代表CounterCells数组非空的情况,所以剩下的两个判定,一个用于抢占锁,初始化CounterCells数组,并且就和前面的情况①一样,可以在赋值的时候就设置好value。

else if (cellsBusy == 0 && counterCells == as &&
         U.compareAndSwapInt(this, CELLSBUSY, 0, 1)) {
  boolean init = false;
  try {    
    //这里再重新检测下 counterCells 数组引用是否有变化
    if (counterCells == as) {
      //初始化一个长度为 2 的数组
      CounterCell[] rs = new CounterCell[2];
      //根据当前线程的随机数值,计算下标,只有两个结果 0 或 1,并初始化对象
      rs[h & 1] = new CounterCell(x);
      //更新数组引用
      counterCells = rs;
      //初始化成功的标志
      init = true;
    }
  } finally {
    //别忘了,需要手动解锁。
    cellsBusy = 0;
  }
  //若初始化成功,则说明当前加1的操作也已经完成了,则退出整个循环。
  if (init)
    break;
}

如果抢占锁失败,创建CounterCells失败,那么就直接尝试去修改baseCount的值,一般我们自己设计可能就直接用这一步了,然后高并发的情况下竞争会导致效率极其低下。但JDK里是在前面设置好了很多种情况,到达这里只是尝试一下,“最后的倔强”:

else if (U.compareAndSwapLong(this, BASECOUNT, v = baseCount, v + x))
  break;                          // Fall back on using base

transfer方法

前面的putVal方法和addCount方法,我们看到中间可能会调用两个方法,transfer和helpTransfer,目标是用于扩容,其实更准确的说法应该是帮助迁移元素。创建一个双倍长度的table很简单,更关键的是把原本table的数据全部迁移到新的nextTable里。在单线程环境下的HashMap是如何做的?直接在需要扩容的时候调用resize方法即可。但多线程环境下的ConcurrentHashMap,会有多个线程同时put,然后新put的数据应该如何添加到扩容后的nextTable里呢?最简单的做法依然是直接加锁,或者volatile+CAS自旋,但显然二者都效率低下。JDK里的做法是,一个线程正在处于put,为什么要把这些新的数据留到后面再一起转移呢,可以每一个线程单独地加入到扩容方法里,帮助数据转移(helpTransfer),这样既实现了多个线程快速进行扩容,同时无需重新创建线程,直接利用现有的线程进行put+transfer。helpTransfer实际上也是判定是否可以进入transfer帮助转移数据,所以先来看transfer方法。

数据迁移的示意图:

image-20200904102237529

每一个新加入的线程都会从length - 1开始帮助迁移数据,也就是从数组的最后。每一个线程负责的数据也是有范围的,由变量stride指定,即一个线程在一次transfer中只负责stride个数据(这里一个数据就是数组上的一个Node,即一个链表/桶)。这样可以使得每一个线程都负责定量的数据,避免单个线程承担太多。对于已经迁移完成的数据,会标记为ForwardingNode,表示该数据(数组中的某个下标位置)中的元素(链表里的所有节点)已经全部迁移完毕,此时新加入的线程就不会迁移这些数据,而是继续向前推进(advance),寻找其他可以迁移的数据。如上图所示,我们假设stride为2(实际上默认是16,但这里便于演示)。A线程是第一个加入transfer的,那么它负责的就是index为7的数据和index为6的数据,一共负责stride(2)个数据。B线程随后加入transfer,发现index7的数据已经有其他线程在帮助迁移,那么就会继续向前寻找,而且也不是i–,而是i -= stride。因此B线程负责的是index为5和index为4的数据。可以看到,当一个位置的数据已经被处理完毕,或者正在被其他线程帮忙处理,此时新加入的线程都直接向前推进,寻找下一个需要帮助迁移的数据。

如果A线程把index6和index7的数据迁移完毕,它会继续向前帮忙,而它检测到index4和index5有其他线程在帮忙,因此它会帮忙index为2和index为3的数据。即每个线程迁移完它负责范围内的数据,都会继续向前推进。终止条件用到了一个全局变量transferIndex,来表示所有线程总共推进到的元素下标位置。对于上图的例子,此时transferIndex为2。直到transferIndex为0,说明每一个数据都已经有线程在帮忙处理了。但transferIndex为0只能说明所有数据都有线程在迁移,但不能说明所有数据都已经迁移完毕,所以后续还有其他校验来判断是否所有数据迁移完毕。先看代码:

//这个类是一个标志,用来代表当前桶(数组中的某个下标位置)的元素已经全部迁移完成
static final class ForwardingNode<K,V> extends Node<K,V> {
  final Node<K,V>[] nextTable;
  ForwardingNode(Node<K,V>[] tab) {
    //把当前桶的头结点的 hash 值设置为 -1,表明已经迁移完成,
    //这个节点中并不存储有效的数据
    super(MOVED, null, null, null);
    this.nextTable = tab;
  }
}

//迁移数据
private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
  int n = tab.length, stride;
  //根据当前CPU核心数,确定每次推进的步长,最小值为16.(为了方便我们以2为例)
  if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
    stride = MIN_TRANSFER_STRIDE; // subdivide range	 此处默认是16
  //从 addCount 方法,只会有一个线程跳转到这里,初始化新数组
  if (nextTab == null) {            // initiating
    try {
      @SuppressWarnings("unchecked")
      //新数组长度为原数组的两倍
      Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1];
      nextTab = nt;
    } catch (Throwable ex) {      // try to cope with OOME
      sizeCtl = Integer.MAX_VALUE;
      return;
    }
    //用 nextTable 指代新数组
    nextTable = nextTab;
    //这里就把推进的下标值初始化为原数组长度(以16为例)
    transferIndex = n;
  }
  //新数组长度
  int nextn = nextTab.length;
  //创建一个标志类
  ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);
  //是否向前推进的标志
  boolean advance = true;
  //是否所有线程都全部迁移完成的标志
  boolean finishing = false; // to ensure sweep before committing nextTab
  //i 代表当前线程正在迁移的桶的下标,bound代表它本次可以迁移的范围下限
  for (int i = 0, bound = 0;;) {
    Node<K,V> f; int fh;
    //需要向前推进
    while (advance) {
      int nextIndex, nextBound;
      //(1) 先看 (3) 。i每次自减 1,直到 bound。若超过bound范围,或者finishing标志为true,则不用向前推进。
      //若未全部完成迁移,且 i 并未走到 bound,则跳转到 (7),处理当前桶的元素迁移。
      if (--i >= bound || finishing)
        advance = false;
      //(2) 每次执行,都会把 transferIndex 最新的值同步给 nextIndex
      //若 transferIndex小于等于0,则说明原数组中的每个桶位置,都有线程在处理迁移了,
      //于是,需要跳出while循环,并把 i设为 -1,以跳转到④判断在处理的线程是否已经全部完成。
      else if ((nextIndex = transferIndex) <= 0) {
        i = -1;
        advance = false;
      }
      //(3) 第一个线程会先走到这里,确定它的数据迁移范围。(2)处会更新 nextIndex为 transferIndex 的最新值
      //因此第一次 nextIndex=n=16,nextBound代表当次迁移的数据范围下限,减去步长即可,
      //所以,第一次时,nextIndex=16,nextBound=16-2=14。后续,每次都会间隔一个步长。
      else if (U.compareAndSwapInt
               (this, TRANSFERINDEX, nextIndex,
                nextBound = (nextIndex > stride ?
                             nextIndex - stride : 0))) {
        //bound代表当次数据迁移下限
        bound = nextBound;
        //第一次的i为15,因为长度16的数组,最后一个元素的下标为15
        i = nextIndex - 1;
        //表明不需要向前推进,只有当把当前范围内的数据全部迁移完成后,才可以向前推进
        advance = false;
      }
    }
    //(4)
    if (i < 0 || i >= n || i + n >= nextn) {
      int sc;
      //若全部线程迁移完成
      if (finishing) {
        nextTable = null;
        //更新table为新表
        table = nextTab;
        //扩容阈值改为原来数组长度的 3/2 ,即新长度的 3/4,也就是新数组长度的0.75倍
        sizeCtl = (n << 1) - (n >>> 1);
        return;
      }
      //到这,说明当前线程已经完成了自己的所有迁移(无论参与了几次迁移),
      //则把 sc 减1,表明参与扩容的线程数减少 1。
      if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
        //在 addCount 方法最后,我们强调,迁移开始时,会设置 sc=(rs << RESIZE_STAMP_SHIFT) + 2
        //每当有一个线程参与迁移,sc 就会加 1,每当有一个线程完成迁移,sc 就会减 1。
        //因此,这里就是去校验当前 sc 是否和初始值是否相等。相等,则说明全部线程迁移完成。
        if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
          return;
        //只有此处,才会把finishing 设置为true。
        finishing = advance = true;
        //这里非常有意思,会把 i 从 -1 修改为16,
        //目的就是,让 i 再从后向前扫描一遍数组,检查是否所有的桶都已被迁移完成,参看 (6)
        i = n; // recheck before commit
      }
    }
    //(5) 若i的位置元素为空,则说明当前桶的元素已经被迁移完成,就把头结点设置为fwd标志。
    else if ((f = tabAt(tab, i)) == null)
      advance = casTabAt(tab, i, null, fwd);
    //(6) 若当前桶的头结点是 ForwardingNode ,说明迁移完成,则向前推进 
    else if ((fh = f.hash) == MOVED)
      advance = true; // already processed
    //(7) 处理当前桶的数据迁移。
    else {
      synchronized (f) {  //给头结点加锁
        if (tabAt(tab, i) == f) {
          Node<K,V> ln, hn;
          //若hash值大于等于0,则说明是普通链表节点
          if (fh >= 0) {
            int runBit = fh & n;
            //这里是 1.7 的 CHM 的 rehash 方法和 1.8 HashMap的 resize 方法的结合体。
            //会分成两条链表,一条链表和原来的下标相同,另一条链表是原来的下标加数组长度的位置
            //然后找到 lastRun 节点,从它到尾结点整体迁移。
            //lastRun前边的节点则单个迁移,但是需要注意的是,这里是头插法。
            //另外还有一点和1.7不同,1.7 lastRun前边的节点是复制过去的,而这里是直接迁移的,没有复制操作。
            //所以,最后会有两条链表,一条链表从 lastRun到尾结点是正序的,而lastRun之前的元素是倒序的,
            //另外一条链表,从头结点开始就是倒叙的。看下图。
            Node<K,V> lastRun = f;
            for (Node<K,V> p = f.next; p != null; p = p.next) {
              int b = p.hash & n;
              if (b != runBit) {
                runBit = b;
                lastRun = p;
              }
            }
            if (runBit == 0) {
              ln = lastRun;
              hn = null;
            }
            else {
              hn = lastRun;
              ln = null;
            }
            for (Node<K,V> p = f; p != lastRun; p = p.next) {
              int ph = p.hash; K pk = p.key; V pv = p.val;
              if ((ph & n) == 0)
                ln = new Node<K,V>(ph, pk, pv, ln);
              else
                hn = new Node<K,V>(ph, pk, pv, hn);
            }
            setTabAt(nextTab, i, ln);
            setTabAt(nextTab, i + n, hn);
            setTabAt(tab, i, fwd);
            advance = true;
          }
          //树节点
          else if (f instanceof TreeBin) {
            TreeBin<K,V> t = (TreeBin<K,V>)f;
            TreeNode<K,V> lo = null, loTail = null;
            TreeNode<K,V> hi = null, hiTail = null;
            int lc = 0, hc = 0;
            for (Node<K,V> e = t.first; e != null; e = e.next) {
              int h = e.hash;
              TreeNode<K,V> p = new TreeNode<K,V>
                (h, e.key, e.val, null, null);
              if ((h & n) == 0) {
                if ((p.prev = loTail) == null)
                  lo = p;
                else
                  loTail.next = p;
                loTail = p;
                ++lc;
              }
              else {
                if ((p.prev = hiTail) == null)
                  hi = p;
                else
                  hiTail.next = p;
                hiTail = p;
                ++hc;
              }
            }
            ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) :
            (hc != 0) ? new TreeBin<K,V>(lo) : t;
            hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) :
            (lc != 0) ? new TreeBin<K,V>(hi) : t;
            setTabAt(nextTab, i, ln);
            setTabAt(nextTab, i + n, hn);
            setTabAt(tab, i, fwd);
            advance = true;
          }
        }
      }
    }
  }
}

代码很长,但我们会逐步击破!首先是ForwardingNode,可以看到只有一个成员变量,它会设置nextTable的头节点为MOVED,用于表示该位置的元素已经迁移完毕

接着就是transfer方法,先是声明了一些变量,初始化stride的值,而且对于第一个进入transfer方法的线程,还会初始化新的数组(nextTab),而后续的线程直接使用此数组。如果回到addCount方法,会看到帮助扩容的线程代码是transfer(tab, nt),但有一处是transfer(tab, null),表明是第一次初始化新表nextTable。

接着是几个重要的变量,fwdadvancefinishing

//创建一个标志类
ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);
//是否向前推进的标志
boolean advance = true;
//是否所有线程都全部迁移完成的标志
boolean finishing = false; // to ensure sweep before committing nextTab

接着是一个for循环,是一个死循环,直到所有数据迁移完毕:for (int i = 0, bound = 0;;)。接着跳入while循环,说白了就是领任务。可能是第一次领,也可能是做完了再次跳入while循环领任务。如果无需领任务,或者领完了,都会把advance设置为false,跳出while循环,继续后面的迁移/检验操作

//需要向前推进
while (advance) {
  int nextIndex, nextBound;
  // --i >= bound表示已经超出了边界,或者finishing表示已经全部完成
  // 那么当然就不会继续推进了。一个表示线程已经领好了任务,跳出while开始工作
  // 另一个表示全部数据都迁移完了(可能在领任务中途其他线程全部做完了)
  // 无论是哪一种,都会跳出while循环,设置advance为false
  if (--i >= bound || finishing)
    advance = false;
  // 如果transferIndex已经小于0,说明所有数据都已经有线程在负责迁移
  // 当前线程也无需领任务了,跳出while,等待全部finishing吧
  else if ((nextIndex = transferIndex) <= 0) {
    i = -1;
    advance = false;
  }
  // 如果执行到这里,就是线程要领任务了,用CAS更新了transferIndex
  // 更新成功说明线程领取任务成功,设置好bound和i,跳出while循环
  else if (U.compareAndSwapInt
           (this, TRANSFERINDEX, nextIndex,
            nextBound = (nextIndex > stride ?
                         nextIndex - stride : 0))) {
    //bound代表当次数据迁移下限
    bound = nextBound;
    i = nextIndex - 1;
    advance = false;
  }
}

当一个线程跳出了while循环,有两种情况:

①它刚刚完成了任务,接着跳入while循环领取任务,发现没有任务需要领取了

②它刚刚从while里领取了任务

无论是哪一种情况,此时都应该先判断所有的数据是否都已经迁移完毕(任务全部都完成了)。如果全部完成了,那么就直接return。如果没有,对于情况①,它就会自旋等待,等待其他线程把任务全部做完,或者当某个线程阻塞,这时候它CAS成功将会获取这一份任务。对于情况②,那么它就继续完成自己的任务。判断所有数据迁移完成的代码如下:

if (i < 0 || i >= n || i + n >= nextn) {
  int sc;
  //若全部线程迁移完成
  if (finishing) {
    nextTable = null;
    //更新table为新表
    table = nextTab;
    //扩容阈值改为原来数组长度的 3/2 ,即新长度的 3/4,也就是新数组长度的0.75倍
    sizeCtl = (n << 1) - (n >>> 1);
    return;
  }
  //到这,说明当前线程已经完成了自己的所有迁移(无论参与了几次迁移),
  //则把 sc 减1,表明参与扩容的线程数减少 1。
  if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
    //在 addCount 方法最后,我们强调,迁移开始时,会设置 sc=(rs << RESIZE_STAMP_SHIFT) + 2
    //每当有一个线程参与迁移,sc 就会加 1,每当有一个线程完成迁移,sc 就会减 1。
    //因此,这里就是去校验当前 sc 是否和初始值是否相等。相等,则说明全部线程迁移完成。
    if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
      return;
    //只有此处,才会把finishing 设置为true。
    finishing = advance = true;
    //这里非常有意思,会把 i 从 -1 修改为16,
    //目的就是,让 i 再从后向前扫描一遍数组,检查是否所有的桶都已被迁移完成
    i = n; // recheck before commit
  }
}

里面用finishing判断是否全部线程迁移完成,因为数组已经创建完,数据也迁移完成,所以只有几行简单的赋值代码。虽然finishing依然为false,但只要进入到这里,说明当前线程已经完成了自己的迁移任务,于是把sc减1,之后再判断一次是否全部线程迁移完成,但在此处是用CAS判断的(可能当前线程就是最后一个线程)。如果确实全部都迁移完成了,那么会执行:finishing = advance = true为什么advance也要设置为true?因为后面还会把 i 的值赋值为n,使得i从后往前再扫描一次数组,检查是否所有的元素都迁移完成

那么为什么这里if的判定条件是:if (i < 0 || i >= n || i + n >= nextn),在后面我们会看到,数据迁移需要把原本的数据复制到新数组里,下标就和HashMap的resize逻辑一样,无需再哈希,下标要么是原本的下标,要么是原下标+原数组长度。上面的三种情况,都超出了newTable的范围,都说明线程已经完成了它当前要迁移的任务,而且无需继续advance了,因此它不再需要去领取任务,可以直接去跳出transfer方法。同时判断是否执行完毕,如果执行完毕就和前面所说的一样,赋值,再扫描一次检查等等。如果还没执行完毕,但它的任务也已经结束,无需再做其他事了。

直到这里,我们已经写好了线程如何判断要向前推进(advance),什么时候表明线程全部迁移完成(finishing)。后面就是线程如何进行具体的迁移数据操作,以及如何判断线程已经迁移完成:

// 若i的位置元素为空,则说明当前桶的元素已经被迁移完成,就把头结点设置为fwd标志。
else if ((f = tabAt(tab, i)) == null)
  advance = casTabAt(tab, i, null, fwd);
// 若当前桶的头结点是 ForwardingNode ,说明迁移完成,则向前推进 
else if ((fh = f.hash) == MOVED)
  advance = true; // already processed

如果桶的元素为空,那么就无需迁移了,因为根本就没有数据,但还是要给头节点设置为fwd。如果头节点是fwd,说明迁移完成,向前推进。此时会回到while循环,但如果在while循环里找不到新的任务,此时就会进入到if (i < 0 || i >= n || i + n >= nextn),判断是否所有数据都迁移完成。

把终止条件都考虑好了,剩下的就是具体的迁移数据操作(依然把Tree操作省去了,以后再补充):

else {
  synchronized (f) {			// 给头节点加锁
    if (tabAt(tab, i) == f) {
      Node<K,V> ln, hn;
      if (fh >= 0) {			// 普通链表节点
        int runBit = fh & n;		// 扩容后的第n位的值
        Node<K,V> lastRun = f;
        for (Node<K,V> p = f.next; p != null; p = p.next) {
          int b = p.hash & n;
          if (b != runBit) {
            runBit = b;
            lastRun = p;
          }
        }
        if (runBit == 0) {
          ln = lastRun;
          hn = null;
        }
        else {
          hn = lastRun;
          ln = null;
        }
        for (Node<K,V> p = f; p != lastRun; p = p.next) {
          int ph = p.hash; K pk = p.key; V pv = p.val;
          if ((ph & n) == 0)
            ln = new Node<K,V>(ph, pk, pv, ln);
          else
            hn = new Node<K,V>(ph, pk, pv, hn);
        }
        setTabAt(nextTab, i, ln);
        setTabAt(nextTab, i + n, hn);
        setTabAt(tab, i, fwd);
        advance = true;
      }
      else if (f instanceof TreeBin) {
        // ... 
      }
    }
  }
}

前面已经把多线程环境下的线程进入,推进,终止已经考虑好了,因此这部分代码实际上跟HashMap非常像,只需要在开头给头节点加锁。这里结合了JDK1.7中ConcurrentHashMap的lastRun节点,以及JDK1.8中HashMap的双链表进行迁移。首先依然是找到lastRun节点,然后计算好lastRun节点的runBit,决定到底是在hn链表还是ln链表(high index newTable,low index newTable)。处理完lastRun后半段的子链表,从头节点开始遍历到lastRun节点,这部分和JDK1.8一样,使用复制,同样地根据第n位的值为0还是1,决定加入到ln还是hn。如下图例子:

image-20200905014655267

步骤:先找到lastRun节点,把lastRun后半段子链表添加到ln或者hn链表,接着从头节点开始遍历节点,直到lastRun终止,采用头插法。lastRun部分会是顺序,而其他会是倒序。

Step1:找到lastRun节点,为65。计算得到runBit为0,于是ln链表此时为「65➡️97」。

Step2:从头开始遍历节点,直到lastRun终止。对于节点1,计算得到bit为0,采用头插法,此时ln链表为「1➡️65➡️97」。

Step3:对于节点17,bit为1,于是hn链表为「17」

Step4:对于节点33,bit为0,于是ln链表为「33➡️1➡️65➡️97」

Step5:对于节点49,bit为1,于是hn链表为「49➡️17」

所以,lastRun后面的节点是顺序,而因为lastRun前面的节点采用头插法,因而会是倒序的情况。至此,transfer方法到此结束。

helpTransfer方法

final Node<K,V>[] helpTransfer(Node<K,V>[] tab, Node<K,V> f) {
  Node<K,V>[] nextTab; int sc;
  //头结点为 ForwardingNode ,并且新数组已经初始化
  if (tab != null && (f instanceof ForwardingNode) &&
      (nextTab = ((ForwardingNode<K,V>)f).nextTable) != null) {
    int rs = resizeStamp(tab.length);
    while (nextTab == nextTable && table == tab &&
           (sc = sizeCtl) < 0) {
      //若校验标识失败,或者已经扩容完成,或推进下标到头,则退出
      if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
          sc == rs + MAX_RESIZERS || transferIndex <= 0)
        break;
      //当前线程需要帮助迁移,sc值加1
      if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1)) {
        transfer(tab, nextTab);
        break;
      }
    }
    return nextTab;
  }
  return table;
}

线程执行putVal方法时,当发现map正处于扩容,会调用helpTransfer考虑是否帮助扩容。helpTransfer主要是判定是否需要帮助扩容,如果需要,直接调用transfer方法,反之直接break。如上面的代码逻辑:

①当ConcurrentHashMap尝试插入的时候,发现节点是forward类型,说明数组已经初始化,那么才会去考虑帮忙扩容。如果数组还没初始化,说明是第一个线程,应该调用transfer(tab, null);

②每加入一个线程都会将sizeCtl的低16位加1,同时会校验高16位的标志符。

③扩容最大的帮助线程数位65535,这是低16位的最大值限制。如果线程已经达到最上限,就不要去helpTransfer了。

④当线程确实需要helpTransfer,就CAS修改sc的值,进行加1操作,然后返回

(PS:此处有bug,错误的方式和addCount里的扩容判断一模一样。)

addCount里的扩容判断

回过头来看addCount里的后半段关于扩容的代码。

// x为1,check代表链表上的元素个数
private final void addCount(long x, int check) {
 // ...
 // 执行到这里说明需要扩容
 if (check >= 0) {
  Node<K,V>[] tab, nt; int n, sc;
  //若元素个数达到扩容阈值(即map添加的节点数大于等于sizeCtl
  //且tab不为空,且tab数组长度小于最大容量
  while (s >= (long)(sc = sizeCtl) && (tab = table) != null &&
      (n = tab.length) < MAXIMUM_CAPACITY) {
   //这里假设数组长度n就为16,这个方法返回的是一个固定值rs,用于当做一个扩容的校验标识
   int rs = resizeStamp(n);
   //若sc小于0,说明正在扩容
   if (sc < 0) {
		// 此处有bug,我们只需要知道:当前桶数组正在扩容,但当前线程无需帮助扩容
    if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
     sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
     transferIndex <= 0)
     break;
    //到这里说明当前线程可以帮助扩容,因此sc值加一,代表扩容的线程数加1
    if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
     transfer(tab, nt);
   }
   // 当sc大于0,说明sc代表扩容阈值,因此第一次扩容之前肯定走这个分支,用于初始化新表 nextTable
   // 此时会把sc赋值为(rs << RESIZE_STAMP_SHIFT) + 2,是一个标记值,表示首个帮助扩容的线程
   else if (U.compareAndSwapInt(this, SIZECTL, sc,
           (rs << RESIZE_STAMP_SHIFT) + 2))
    //扩容,第二个参数代表新表,传入null,则说明是第一次初始化新表(nextTable)
    transfer(tab, null);
   s = sumCount();
  }
 }
}

//扩容时的校验标识
static final int resizeStamp(int n) {
 return Integer.numberOfLeadingZeros(n) | (1 << (RESIZE_STAMP_BITS - 1));
}

首先要再次清晰sizeCtl的作用,它虽然只是一个变量,但它承担了5个角色的功能:

①当它为负数,且为-1时,表示正在初始化。

②当它为负数,且为-N时,它的后16位表示有 N - 1个线程在帮助扩容,而高16位表示在帮助哪个容量进行扩容

③当它为0时,这是默认值

④当它为正数时,且桶数组为null,它代表桶数组的初始化大小。

⑤当它为正数时,且桶数组不为null,它代表桶数组的扩容阈值。

对于上面的代码,while (s >= (long)(sc = sizeCtl) && (tab = table) != null && (n = tab.length) < MAXIMUM_CAPACITY),这就是扩容的条件:当容量大于等于sizeCtl,执行到此处的sizeCtl为正数,且桶数组不为null,并且桶数组的长度还没有达到MAX,根据上面的第②或第⑤点,这就是桶数组的扩容阈值,因此此时桶数组的容量大于等于sizeCtl,那么就会进行扩容。

这时候先判断sc到底是大于0(情况⑤),还是小于0(情况②)。如果是小于0,说明现在正处于扩容状态,此处需要判断当前线程是否应该帮助扩容。那么什么时候不需要帮助扩容呢,有三点:

(sc >>> RESIZE_STAMP_SHIFT) != rs

这里我们要先看rs的计算方法resizeStamp方法是如何计算的,含义是什么:

/**
  * Returns the stamp bits for resizing a table of size n.
  * Must be negative when shifted left by RESIZE_STAMP_SHIFT.
  */
static final int resizeStamp(int n) {
  return Integer.numberOfLeadingZeros(n) | (1 << (RESIZE_STAMP_BITS - 1));
}

返回一个关于桶数组长度n的一个stamp bits(戳),也就是根据n唯一对应的一个标记值。numberOfLeadingZeros返回n转换为二进制之后,前面有多少个前缀0,后面是指2的15次方(RESIZE_STAMP_BITS为16)。于是我们得到的值形式会是:

0000 0000 0000 0000 1xxx xxxx xxxx xxxx

Must be negative when shifted left by RESIZE_STAMP_SHIFT,当此数左移16位之后,它必须是一个负数,因为第17位为1,左移之后第一位为1,为负数。至于为什么必须是负数,后面再考虑,我们现在只需要知道,rs是一个对应于n的标记值,也就是n改变的时候,rs也会跟着改变。前面的第②个功能说明了sc的高16位就是这个标记值rs,如果当前sc的高16位与rs不相等,说明在这个并发的过程中,其他线程已经扩容完毕。现在虽然正处于扩容状态,但并不是原本的桶数组长度对应的扩容。比如原本是长度为16扩容为32,某个线程A参与到帮助扩容的任务当中。但它在某一个时刻发现高16位标记值与rs不相等,可能是其他线程已经扩容完毕了,现在是长度32扩容为64,那么此时线程A就能退出当前的“过期”任务了。这就是为什么当 (sc >>> RESIZE_STAMP_SHIFT) != rs,直接break,跳出循环,不再帮助扩容。

② sc == rs + 1 (此处有Bug)

还是前面的第②个功能,-N表示有N - 1个线程在扩容,而此时为1,说明没有线程在扩容。最关键的是,最开始的时候,我们会把sc设置为(rs << RESIZE_STAMP_SHIFT) + 2),表明有1个线程触发了扩容,此时有1个线程在帮助扩容。但现在连触发扩容的线程都已经退出,说明扩容已经结束,那么当前线程自然也是直接break,跳出循环。

③sc == rs + MAX_RESIZER

RESIZER是指帮助扩容的线程数,那么MAX_RESIZER显然就是最大帮助扩容线程数。所以这里的逻辑也呼之欲出:已经达到了最大帮助扩容线程数,当前线程直接退出。

那么如果不是上面3种情况,说明当前线程需要帮助扩容,于是通过CAS把sc的值加1,表示帮助扩容的线程数+1:

if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
  transfer(tab, nt);

如果sc的值为正数,此时就是第⑤条功能,会执行以下逻辑:

// 当sc大于0,说明sc代表扩容阈值,因此第一次扩容之前肯定走这个分支,用于初始化新表 nextTable
// 此时会把sc赋值为(rs << RESIZE_STAMP_SHIFT) + 2,是一个标记值,表示首个帮助扩容的线程
else if (U.compareAndSwapInt(this, SIZECTL, sc,
                             (rs << RESIZE_STAMP_SHIFT) + 2))
  //扩容,第二个参数代表新表,传入null,则说明是第一次初始化新表(nextTable)
  transfer(tab, null);

sc为正数,那么sc为扩容阈值,即第一个线程导致了扩容,此时通过CAS把sc的值初始化为(rs << RESIZE_STAMP_SHIFT) + 2),transfer方法的第二个参数为null也说明了确实是第一次扩容。

上面可能会有一些遗留问题,如下:

Question1:rs是对应于n的一个标记值,但为什么它要Must be negative when shifted left by RESIZE_STAMP_SHIFT?

Ans:sizeCtl为负数的时候表示它正在扩容,而sizeCtl在首次扩容的时候会初始化为(rs << RESIZE_STAMP_SHIFT) + 2)。此时可以使得sizeCtl成为:高16位就是所谓的stamp(戳),是桶数组长度的一个标记值,用于检测当前扩容线程是否对应同一个长度,如果长度已经改变,当前线程就无需继续扩容了(已经扩容完毕,现在其他线程正在执行其他扩容任务!)。而后16位表示有多少个线程在帮助扩容,-2,因此是1个线程帮助扩容。

Question2:所以为什么要初始化为 (rs << RESIZE_STAMP_SHIFT) + 2)?

Ans:准确地说,当数组正在扩容的时候,sizeCtl的值为resizeStamp + 1 + 正在帮助扩容的线程数量。

resizeStamp是一个很大的负值,其高16位为stamp,低16位为0(1xxx xxxx xxxx xxxx 0000 0000 0000 0000)。

因为设定低16位的值为k时,有k - 1个线程在帮助线程。初始的时候就是只有1个线程正在扩容,所以低16位为2!

而且即使线程数到达MAX_RESIZER的时候,sc也不会发生符号改变,因为MAX_RESIZER的最大值为65535,+1之后也只是使得第17位进1。而因为桶的数组长度是2次幂,因而第1位必定是0,移位后第17位也一定是0。因此即使在线程数最大的时候,sc依然是负数。

Question3:sc == rs + 1和 sc == rs + MAX_RESIZER有什么错误?

Ans:虽然ConcurrentHashMap是设计非常精妙的类,但它还是出现了一些小bug。这里的sc为负数,而rs是一个正数,它们是永远不可能相等的。sc是后16位与rs进行比较,这就是这两个情况的错误,它没有进行移位计算。

在JDK12中,此bug得到了修改,rs变量在赋值的时候就先进行移位操作:参考链接

if (check >= 0) {
  Node<K,V>[] tab, nt; int n, sc;
  while (s >= (long)(sc = sizeCtl) && (tab = table) != null &&
         (n = tab.length) < MAXIMUM_CAPACITY) {
    int rs = resizeStamp(n) << RESIZE_STAMP_SHIFT;
    if (sc < 0) {
      if (sc == rs + MAX_RESIZERS || sc == rs + 1 ||
          (nt = nextTable) == null || transferIndex <= 0)
        break;
      if (U.compareAndSetInt(this, SIZECTL, sc, sc + 1))
        transfer(tab, nt);
    }
    else if (U.compareAndSetInt(this, SIZECTL, sc, rs + 2))
      transfer(tab, null);
    s = sumCount();
  }
}

同理,可以看一下JDK12里的helpTransfer方法,也是做出了一模一样的修改方案:

/**
  * Helps transfer if a resize is in progress.
  */
final Node<K,V>[] helpTransfer(Node<K,V>[] tab, Node<K,V> f) {
  Node<K,V>[] nextTab; int sc;
  if (tab != null && (f instanceof ForwardingNode) &&
      (nextTab = ((ForwardingNode<K,V>)f).nextTable) != null) {
    int rs = resizeStamp(tab.length) << RESIZE_STAMP_SHIFT;	// 移位操作放到赋值这里进行操作了
    while (nextTab == nextTable && table == tab &&
           (sc = sizeCtl) < 0) {
      if (sc == rs + MAX_RESIZERS || sc == rs + 1 ||
          transferIndex <= 0)
        break;
      if (U.compareAndSetInt(this, SIZECTL, sc, sc + 1)) {
        transfer(tab, nextTab);
        break;
      }
    }
    return nextTab;
  }
  return table;
}

总结

ConcurrentHashMap的学习时间线拖得比较长,但收获还是很大。首先对JDK1.7中分段锁的概念有了更深入的了解。而JDK1.8在synchronized优化之后,放弃了分段锁,选用CAS+synchronized。但个中还是加入了很多优化,因此理论虽简单,代码实现却比较复杂。1.7中的分段锁主要是每一个Segment各司其职,只要Hash函数足够好,元素尽量地均匀分布到每一个Segment中,那么就可以近似地认为ConcurrentHashMap的并发度是Segment数组的长度,而编码的关键也是处理好各个Segment之间的锁问题。而1.8没有了分段锁的概念,却加入了sizeCtl变量,一值多用,很好地提高了效率,多个put的线程可以选择帮助扩容也使得效率得到大大的提升,但在计算size的时候还是用到了CounterCells数组,个人觉得这个概念还是和分段锁比较相似。

总结,1.7使用分段锁,起初觉得概念复杂, 但实际上比较简单,代码实现也不难。而1.8使用CAS+synchronized在HashMap的基础上改进,起初觉得概念简单,但代码的实现却比较复杂。ConcurrentHashMap里还有很多值得学习的地方,哪怕是上面介绍的方法也还是有很多值得深究的地方,但目前就暂时到这里了~

已标记关键词 清除标记
©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页