ConcurrentHashMap是一个线程安全的并发容器,内部主要基于哈希结构。ConcurrentHashMap在JDK1.7到JDK1.8时内部结构有较大的变化,本文主要对JDK1.8中ConcurrentHashMap的源代码进行讨论。
本篇文章在ConcurrentHashMap源代码的基础上,对主要的查找、插入和扩容等操作进行分析,并增加了详细的代码注释,帮助理解ConcurrentHashMap的源代码。
1. ConcurrentHashMap基本结构
ConcurrentHashMap内部结构主要基于哈希结构,内部元素存储主要基于一个数组table,table的每一个位置代表一个hash位置。在每个位置上,hash值相同的结点构成链表或红黑树。
1.1 链表结点 Node
链表结点 Node 的定义如下:包含key,value,next以及该结点的hash值。
static class Node<K,V> implements Map.Entry<K,V> {
final int hash;
final K key;
volatile V val;
volatile Node<K,V> next;
Node(int hash, K key, V val, Node<K,V> next) {
this.hash = hash;
this.key = key;
this.val = val;
this.next = next;
}
}
1.2 内部的常量与变量
ConcurrentHashMap源码中一部分常用的常量与变量。table即为ConcurrentHashMap中的内部数组。
// 最大容量为 2 ^ 30, 作者解释Java数组的寻址空间有限制,最高两位用来控制hash过程,所以只能用30位。
private static final int MAXIMUM_CAPACITY = 1 << 30;
// 默认容量为 16
private static final int DEFAULT_CAPACITY = 16;
// 默认负载因子 0.75
private static final float LOAD_FACTOR = 0.75f;
// ConcurrentHashMap的内部数组 table
transient volatile Node<K, V>[] table;
// 仅在扩容时用到的新内部数组
private transient volatile Node<K, V>[] nextTable;
// 用来表示内部数组初始化以及扩容时的状态
// 负值代表内部数组正在初始化或扩容
// 当为-1时,表示正在初始化
// 小于-1时,表示正在扩容,其值为 -(1 + 扩容线程数量)
// 0 为初始状态
// 大于0时,其值表示下次触发扩容的容量大小 (0.75 * capacity)
private transient volatile int sizeCtl;
// 扩容时的指示下标
private transient volatile int transferIndex;
1.3 ConcurrentHashMap构造函数
public ConcurrentHashMap(int initialCapacity) {
// 非法参数处理
if (initialCapacity < 0)
throw new IllegalArgumentException();
// 计算数组的大小,使其成为刚好超过initialCapacity的 2的整幂次。
// 例如:initialCapacity = 3, 则将数组大小设定为 4(2^2 且 4 > 3);
// initialCapacity = 13, 则将数组大小设定为16(2^4 且 16 > 13)。
// 当initialCapacity > MAXIMUM_CAPACITY / 2 时,
// 将 initialCapacity 设定为 MAXIMUM_CAPACITY。(为了满足上述的规则)
// 当initialCapacity <= MAXIMUM_CAPACITY / 2时,使用tableSizeFor()计算出数组大小。
int cap = ((initialCapacity >= (MAXIMUM_CAPACITY >>> 1)) ?
MAXIMUM_CAPACITY :
// 用位运算计算出刚好超过initialCapacity的2的整幂次。
tableSizeFor(initialCapacity + (initialCapacity >>> 1) + 1));
this.sizeCtl = cap;
}
// 还有以下几个构造方法,loadFactor是负载因子,
// 默认值为 0.75, 实际空间 = initialCapacity / loadFactor
public ConcurrentHashMap(Map<? extends K, ? extends V> m)
public ConcurrentHashMap(int initialCapacity, float loadFactor)
public ConcurrentHashMap(int initialCapacity, float loadFactor, int concurrencyLevel)
1.4 hash方法以及获取数组位置的方法
ConcurrentHashMap中根据元素key的hashcode和spread方法得到元素的hash值,然后基于内部数组大小为2的幂次的特点,使用 (n - 1) & hash 即可快速得到元素在数组中的位置。
// 根据 key 的 hashcode,然后spread(),得到 hash 值。
int hash = spread(key.hashCode());
// n 是内部数组长度,前面提到内部数组的长度均为2的幂次,
// 所以 (n - 1) 的低位应该全部为 1。
// 例如 n = 16 (10000),n - 1 = 15 (01111)。
// 此时再与 hash 进行与(&)运算,相当于进行了取模运算,
// 所得到的结果就是该元素在内部数组中的位置。
tabAt(tab, (n - 1) & hash)
// spread方法使 hash 值分布更为平均
static final int HASH_BITS = 0x7fffffff;
static final int spread(int h) {
// 让 h 逻辑右移 16 位。目的是保证高位也参与了运算,
// 如果不右移,取模运算时只有低位参与运算。
// 并且使得到的 hash 更为平均。
// HASH_BITS保证最高位为 0,确保得到一个正数。
return (h ^ (h >>> 16)) & HASH_BITS;
}
1.5 Node结点的 hash 值
Node结点的数据结构定义中除了包含key,value,next,还包含一个 hash 值。通常情况下,Node结点的hash值即为该结点key的hash值,在以下三种特殊情况下为负值:
// Node的 hash 值为 -1,表示ForwardingNode,仅在扩容时使用
static final int MOVED = -1;
// Node的 hash 值为 -2,表示红黑树根节点
static final int TREEBIN = -2;
// Node的 hash 值为 -3,表示ReservationNode,在compute() 和 computeIfAbsent() 中使用。
static final int RESERVED = -3;
2. get
get方法,从ConcurrentHashMap中获取指定key的value。主要操作为包含计算目标key的hash值,取得该hash值对应的数组位置,在数组位置上进行查找。
// table为 ConcurrentHashMap 内部数组。懒加载策略,在构造时不会初始化,在插入时才会初始化。
transient volatile Node<K,V>[] table;
public V get(Object key) {
Node<K,V>[] tab; Node<K,V> e, p; int n, eh; K ek;
// 计算目标 key 的 hash 值。
int h = spread(key.hashCode());
// 判断内部数组是否已经初始化,数组大小是否大于0,
// 以及目标key元素对应的hash位置是否存在元素。
if ((tab = table) != null && (n = tab.length) > 0 &&
(e = tabAt(tab, (n - 1) & h)) != null) {
// 如果 hash 值大于0,且与目标 key的 hash 值相等
if ((eh = e.hash) == h) {
// 如果对应hash位置的第一个元素就是待查找目标key,则返回val。
if ((ek = e.key) == key || (ek != null && key.equals(ek)))
return e.val;
}
// hash 值小于0,说明是特殊结点,调用find方法查找
// find方法在每一类特殊节点中进行了重写来支持不同特殊结点的查找。
else if (eh < 0)
return (p = e.find(h, key)) != null ? p.val : null;
// 沿着hash位置的链表继续向后遍历寻找目标Key。
while ((e = e.next) != null) {
if (e.hash == h &&
((ek = e.key) == key || (ek != null && key.equals(ek))))
return e.val;
}
}
return null;
}
3. put
put方法将元素插入ConcurrentHashMap,在插入时,使用synchronized对当前数组位置的头结点进行加锁,确保线程安全。put操作完成后会判断是否将链表转换成红黑树,然后对元素总计数进行修改。
final V putVal(K key, V value, boolean onlyIfAbsent) {
// 异常输入处理
if (key == null || value == null) throw new NullPointerException();
// 计算 hash 值。
int hash = spread(key.hashCode());
int binCount = 0;
for (Node<K,V>[] tab = table;;) {
Node<K,V> f; int n, i, fh;
// HashMap内部数组懒加载,如果还没有数组则初始化数组。
if (tab == null || (n = tab.length) == 0)
tab = initTable();
// 如果当前 hash 位置没有元素,则直接 CAS 添加一个元素,
// 并设定 key 和 value。
else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
if (casTabAt(tab, i, null,
new Node<K,V>(hash, key, value, null)))
break; // no lock when adding to empty bin
}
// 如果当前结点的 hash 值为 -1(ForwardingNode),
// 则说明正在扩容,当前线程也加入协助扩容。
else if ((fh = f.hash) == MOVED)
tab = helpTransfer(tab, f);
// 当前 hash 位置元素不为空,需要遍历链表结点寻找目标key值是否已存在
else {
V oldVal = null;
// 对头节点加锁,防止其他线程在同一位置进行修改
synchronized (f) {
// 确定头结点依然是刚刚获取的头结点 f
if (tabAt(tab, i) == f) {
// 结点 f 的 hash 值大于 0,属于普通结点
if (fh >= 0) {
// 当前 hash 位置的元素数量计数。
binCount = 1;
// 遍历链表,同时 binCount 对链表中元素个数进行计数。
for (Node<K,V> e = f;; ++binCount) {
K ek;
// 如果找到了目标 key值结点,且 onlyIfAbsent 为 false,
// 则直接使用新 value 覆盖原有的 value。
if (e.hash == hash &&
((ek = e.key) == key ||
(ek != null && key.equals(ek)))) {
oldVal = e.val;
if (!onlyIfAbsent)
e.val = value;
break;
}
// 如果到了链表尾部还没有发现目标key值,则在链表尾部新增加一个结点,
// 并填充 key 和 value。
Node<K,V> pred = e;
if ((e = e.next) == null) {
pred.next = new Node<K,V>(hash, key,
value, null);
break;
}
}
}
// 如果头结点是红黑树结点,则在红黑树上进行上述操作。
else if (f instanceof TreeBin) {
Node<K,V> p;
binCount = 2;
if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,
value)) != null) {
oldVal = p.val;
if (!onlyIfAbsent)
p.val = value;
}
}
}
}
if (binCount != 0) {
// 如果遍历链表过程中 binCount 计数超过 8(TREEIFY_THRESHOLD为 8),
// 则将当前hash位置的存储结构由链表变成红黑树
if (binCount >= TREEIFY_THRESHOLD)
treeifyBin(tab, i);
if (oldVal != null)
return oldVal;
break;
}
}
}
// 增加元素计数
addCount(1L, binCount);
return null;
}
4. addCount
addCount修改元素计数,在这里会检查是否需要扩容。
private final void addCount(long x, int check) {
// CounterCell为计数桶,ConcurrentHashMap会根据是否存在竞争
// 首先尝试 CAS 增加元素数量,如果失败则使用计数桶 CounterCell。
CounterCell[] as; long b, s;
if ((as = counterCells) != null ||
!U.compareAndSwapLong(this, BASECOUNT, b = baseCount, s = b + x)) {
CounterCell a; long v; int m;
boolean uncontended = true;
if (as == null || (m = as.length - 1) < 0 ||
(a = as[ThreadLocalRandom.getProbe() & m]) == null ||
!(uncontended =
U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x))) {
fullAddCount(x, uncontended);
return;
}
if (check <= 1)
return;
// 计数桶有多个位置,计算所有位置上的总和
// 得到实际的元素总数
s = sumCount();
}
// 一个hash位置上的元素数量大于等于0,检查是否需要扩容
if (check >= 0) {
Node<K,V>[] tab, nt; int n, sc;
// 触发扩容的条件:元素总数 >= sizeCtl,并且不超过 MAXIMUM_CAPACITY
while (s >= (long)(sc = sizeCtl) && (tab = table) != null &&
(n = tab.length) < MAXIMUM_CAPACITY) {
// 生成扩容时的 Stamp,其值与内部数组的长度有关
// 所以在每一次扩容前,它的值是不变的
// ConcurrentHashMapp依靠这个stamp与sizeCtl来确定同一次扩容时的线程数量
// 以下操作则是在尝试扩容
int rs = resizeStamp(n);
// sc < 0,说明当前有线程正在扩容
if (sc < 0) {
if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
transferIndex <= 0)
break;
if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
// 扩容逻辑
transfer(tab, nt);
}
// sc >= 0, 说明当前没有线程在扩容
else if (U.compareAndSwapInt(this, SIZECTL, sc,
(rs << RESIZE_STAMP_SHIFT) + 2))
transfer(tab, null);
s = sumCount();
}
}
}
5. 扩容transfer
ConcurrentHashMap的扩容的主要过程包括创建新数组,重新计算元素在新数组中的位置,将元素迁移到新数组中。同时ConcurrentHashMap允许多个线程同时参加扩容的过程,提高了扩容的效率。
5.1 扩容前的准备工作
private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
int n = tab.length, stride;
// NCPU 为 cpu 核数,通过判断确定扩容时的每个线程负责的区间步长 stride。
if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
stride = MIN_TRANSFER_STRIDE; // subdivide range
// nextTab为扩容后的新内部数组,如果为 null,则创建。
if (nextTab == null) { // initiating
try {
@SuppressWarnings("unchecked")
// n << 1 表示左移一位,即乘 2。新数组大小为原数组的两倍。
Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1];
nextTab = nt;
} catch (Throwable ex) {
// 处理 OOM 异常
sizeCtl = Integer.MAX_VALUE;
return;
}
nextTable = nextTab;
// transferIndex的初值为原数组长度 n。
// 实际扩容时将根据 transferIndex 的值倒序从数组最后向前进行数据迁移
transferIndex = n;
}
int nextn = nextTab.length;
// ForwardingNode为一个特殊结点,内部包含一个nextTable的引用。
// 用这个特殊结点对已经完成迁移的位置进行标记。
ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);
// 是否继续推进
boolean advance = true;
// 是否全部完成
boolean finishing = false; // to ensure sweep before committing nextTab
5.2 扩容过程
多个线程同时进行扩容,整个内部数组 table 被划分为多个区间,每个线程领取一部分区间,负责对该区间内部的hash位置进行迁移。前半部分为分配区间操作和终止条件,后半部分为实际的迁移操作。
for (int i = 0, bound = 0;;) {
Node<K,V> f; int fh;
// 分配区间操作和终止条件
while (advance) {
int nextIndex, nextBound;
// i 是当前正在进行迁移的数组位置,bound是区间下界
// 如果当前线程仍在扩容中或所有线程已经全部完成,
// 则将advance设定成false,跳出while循环
if (--i >= bound || finishing)
advance = false;
// 当前线程获取到的区间上界 nextIndex 如果小于 0,
// 则说明所有区间已经全部分配完成。
else if ((nextIndex = transferIndex) <= 0) {
// 将 i 设定为-1,使其进入下面 if (i < 0 || i >= n || i + n >= nextn)的检查
i = -1;
advance = false;
}
// CAS修改 transferIndex,相当于将区间(transferIndex - stride,transferIndex)
// 分配给当前线程进行扩容时的迁移
else if (U.compareAndSwapInt
(this, TRANSFERINDEX, nextIndex,
nextBound = (nextIndex > stride ?
nextIndex - stride : 0))) {
// bound是区间下界
bound = nextBound;
// i 是当前正在进行迁移的数组位置,
// 从区间最后一个位置开始,向前循环
i = nextIndex - 1;
advance = false;
}
}
数据迁移过程:
// i < 0 代表当前线程处理的是最后一段区间,并且已经完成,
// 或者所有区间全部分配完成,上面的操作将 i 设定为-1。
if (i < 0 || i >= n || i + n >= nextn) {
int sc;
// 如果已经全部完成迁移
if (finishing) {
// 将nextTable设定为null
nextTable = null;
// 内部数组table指向新的数组
table = nextTab;
// sizeCtl 调整为新数组大小(原数组两倍)的 0.75(负载因子)
sizeCtl = (n << 1) - (n >>> 1);
return;
}
// 一个线程完成所负责的区间后,会先进入上面的while循环,尝试继续领取区间
// 当所有区间已经全部分配后,i被设定为-1,进入这里
// 对sizeCtl值进行操作,减去当前线程
// 然后判断该线程是不是在扩容的最后一个线程
if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
// 如果当前线程不是正在扩容的最后一个线程则 return
if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
return;
// 所有扩容线程都完成工作了,开始收尾工作(finishing=true)
finishing = advance = true;
// 再次进入循环,对每一个位置进行检查,确保全部完成迁移
i = n; // recheck before commit
}
}
// 数组位置 i 全部迁移完成,则在 i 位置放置一个 ForwardingNode。
// advance为true,进入到分配区间的那个while循环, --i,继续推进下一个位置
else if ((f = tabAt(tab, i)) == null)
advance = casTabAt(tab, i, null, fwd);
// 如果数组位置 i 已经存在 ForwardingNode,则直接跳过
// advance为true,进入到分配区间的那个while循环, --i,继续推进下一个位置
else if ((fh = f.hash) == MOVED)
advance = true; // already processed
else {
// 头结点 f 加锁
synchronized (f) {
if (tabAt(tab, i) == f) {
Node<K,V> ln, hn;
// hash 大于 0,为普通结点
if (fh >= 0) {
// 新数组为原数组大小的两倍,且数组大小均为 2 的幂次,
// 所以一个 hash 位置只可能在新数组的**原位置**或新数组**原位置+n**的位置。
// 决定因素是hash值在代表数组大小 n 的那一位的是 0 还是 1。
// 例如:hash值为 xxxx110101,原数组为16(010000),新数组为32(100000)
// hash算法 hash & (n - 1): 在原数组结果为 110101 & 001111 = 000101
// 在新数组结果为 110101 & 011111 = 010101
// 决定因素取决于第五位, 即 hash & n 的结果
// runBit就是在计算这一位
int runBit = fh & n;
Node<K,V> lastRun = f;
// 一直循环整个链表,每次 runBit 变化,就记录该结点和新的 runBit
// 相当于最后找到了队尾一段 runBit 相同的子链表的头结点,
// 然后将这整段 runBit 相同的子链表直接移动到新数组。
for (Node<K,V> p = f.next; p != null; p = p.next) {
int b = p.hash & n;
if (b != runBit) {
runBit = b;
lastRun = p;
}
}
// 如果 runBit为 0,将最后队尾一段 runBit 相同的子链表加入ln。
if (runBit == 0) {
ln = lastRun;
hn = null;
}
// 如果 runBit为 1,将最后队尾一段 runBit 相同的子链表加入hn。
else {
hn = lastRun;
ln = null;
}
// 从头开始一个一个结点遍历链表,将结点用头插法加入 ln 或 hn 链表。
for (Node<K,V> p = f; p != lastRun; p = p.next) {
int ph = p.hash; K pk = p.key; V pv = p.val;
if ((ph & n) == 0)
ln = new Node<K,V>(ph, pk, pv, ln);
else
hn = new Node<K,V>(ph, pk, pv, hn);
}
// ln链表放入新数组的**原位置**
setTabAt(nextTab, i, ln);
// hn链表放入新数组的**原位置 + n**的位置
setTabAt(nextTab, i + n, hn);
// 在原数组设定 ForwardingNode,表示该位置已经完成迁移
setTabAt(tab, i, fwd);
advance = true;
}
// 红黑树处理逻辑
else if (f instanceof TreeBin) {
TreeBin<K,V> t = (TreeBin<K,V>)f;
TreeNode<K,V> lo = null, loTail = null;
TreeNode<K,V> hi = null, hiTail = null;
int lc = 0, hc = 0;
for (Node<K,V> e = t.first; e != null; e = e.next) {
int h = e.hash;
TreeNode<K,V> p = new TreeNode<K,V>
(h, e.key, e.val, null, null);
if ((h & n) == 0) {
if ((p.prev = loTail) == null)
lo = p;
else
loTail.next = p;
loTail = p;
++lc;
}
else {
if ((p.prev = hiTail) == null)
hi = p;
else
hiTail.next = p;
hiTail = p;
++hc;
}
}
ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) :
(hc != 0) ? new TreeBin<K,V>(lo) : t;
hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) :
(lc != 0) ? new TreeBin<K,V>(hi) : t;
setTabAt(nextTab, i, ln);
setTabAt(nextTab, i + n, hn);
setTabAt(tab, i, fwd);
advance = true;
}
}
}
}
}
}
6. remove
remove方法主要是通过 replaceNode() 来实现的。作者基于replaceNode()方法实现了多个remove和replace方法,当replaceNode()方法的第三个入参 cv 为null时进行remove,否则进行replace。
public V remove(Object key) {
return replaceNode(key, null, null);
}
final V replaceNode(Object key, V value, Object cv) {
int hash = spread(key.hashCode());
for (Node<K,V>[] tab = table;;) {
Node<K,V> f; int n, i, fh;
if (tab == null || (n = tab.length) == 0 ||
(f = tabAt(tab, i = (n - 1) & hash)) == null)
break;
// hash 值为 -1,正在扩容。
else if ((fh = f.hash) == MOVED)
// 当前线程加入扩容,协助完成扩容
tab = helpTransfer(tab, f);
else {
V oldVal = null;
boolean validated = false;
synchronized (f) {
if (tabAt(tab, i) == f) {
// hash 值大于 0, 普通结点。
if (fh >= 0) {
validated = true;
// 遍历链表
for (Node<K,V> e = f, pred = null;;) {
K ek;
// 找到了目标 key 值的结点
if (e.hash == hash &&
((ek = e.key) == key ||
(ek != null && key.equals(ek)))) {
V ev = e.val;
// remove时方法入参 cv 为 null
if (cv == null || cv == ev ||
(ev != null && cv.equals(ev))) {
oldVal = ev;
// remove时方法入参 value 为 null
if (value != null)
e.val = value;
// 链表删除操作,前驱结点指向下一个结点
else if (pred != null)
pred.next = e.next;
// 前驱结点为null,直接将下一个结点设定为头结点
else
setTabAt(tab, i, e.next);
}
break;
}
pred = e;
if ((e = e.next) == null)
break;
}
}
// 红黑树的遍历
else if (f instanceof TreeBin) {
validated = true;
TreeBin<K,V> t = (TreeBin<K,V>)f;
TreeNode<K,V> r, p;
if ((r = t.root) != null &&
(p = r.findTreeNode(hash, key, null)) != null) {
V pv = p.val;
if (cv == null || cv == pv ||
(pv != null && cv.equals(pv))) {
oldVal = pv;
if (value != null)
p.val = value;
else if (t.removeTreeNode(p))
setTabAt(tab, i, untreeify(t.first));
}
}
}
}
}
if (validated) {
if (oldVal != null) {
if (value == null)
// 计数减一
addCount(-1L, -1);
return oldVal;
}
break;
}
}
}
return null;
}