前言
HashMap
是一个非常优秀的类,使用也非常频繁。唯一的遗憾就是HashMap
不是线程安全的。
前置阅读:
理解了HashMap
,再来看ConcurrentHashMap
会有事半功倍的效果,因为ConcurrentHashMap
底层数据结构、核心方法几乎和HashMap
一模一样,只是在多线程环境下做了很多保证线程安全的操作。
JDK早期提供了线程安全的HashMap
类,那就是Hashtable
,底层几乎把所有的方法都加上了锁,导致效率太低。JDK1.5开始,JUC包中提供了一个更高效的、线程安全的HashMap
类,那就是ConcurrentHashMap
。
本篇主要讲解JDK1.8中ConcurrentHashMap
的底层结构,实现原理,核心方法等。
ConcurrentHashMap
首先看下ConcurrentHashMap
的类体系结构,整体上把握ConcurrentHashMap
可以看到ConcurrentHashMap
主要实现了Map
和Serializable
接口。
内部结构
想要读懂容器类的源码,必须先了解它的数据结构。所以先看看ConcurrentHashMap
的内部结构,重点关注以下几个中的属性即可
public class ConcurrentHashMap<K,V> extends AbstractMap<K,V>
implements ConcurrentMap<K,V>, Serializable {
private static final long serialVersionUID = 7249069246763182397L;
// 最大容量
private static final int MAXIMUM_CAPACITY = 1 << 30;
// 默认容量
private static final int DEFAULT_CAPACITY = 16;
// 最大可能的数组大小
static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;
// 默认的并发级别(不使用,为了兼容之前的版本)
private static final int DEFAULT_CONCURRENCY_LEVEL = 16;
// 默认加载因子
private static final float LOAD_FACTOR = 0.75f;
// 链表转红黑树阈值
static final int TREEIFY_THRESHOLD = 8;
// 红黑树退化成链表的阈值
static final int UNTREEIFY_THRESHOLD = 6;
// (红黑)树化时,table数组最小值
// 至少是4倍的TREEIFY_THRESHOLD
static final int MIN_TREEIFY_CAPACITY = 64;
// 第一次新增元素时初始化,始终是2的幂
transient volatile Node<K,V>[] table;
// 扩容时用,代表扩容后的数组
private transient volatile Node<K,V>[] nextTable;
// 节点hash的特殊值
static final int MOVED = -1; // 转移节点的hash值
static final int TREEBIN = -2; // (红黑)树根节点的hash值
static final int RESERVED = -3; // 临时保留的hash值
static final int HASH_BITS = 0x7fffffff; // 普通节点hash的可用位
// 控制table初始化和扩容的字段
// -1 初始化中
// -n 表示n-1个线程正在扩容中
// 0 使用默认容量进行初始化
// >0 使用多少容量
private transient volatile int sizeCtl;
}
构造方法
ConcurrentHashMap
提供了5个构造方法,主要关注3个
public ConcurrentHashMap(int initialCapacity) {
if (initialCapacity < 0)
// 传入的初始化容量不能小于0
throw new IllegalArgumentException();
// 根据传入的capacity计算合理的capacity
int cap = ((initialCapacity >= (MAXIMUM_CAPACITY >>> 1)) ?
MAXIMUM_CAPACITY :
tableSizeFor(initialCapacity + (initialCapacity >>> 1) + 1));
this.sizeCtl = cap;
}
public ConcurrentHashMap(int initialCapacity, float loadFactor) {
// concurrencyLevel传入了1
this(initialCapacity, loadFactor, 1);
}
public ConcurrentHashMap(int initialCapacity,
float loadFactor, int concurrencyLevel) {
if (!(loadFactor > 0.0f) || initialCapacity < 0 || concurrencyLevel <= 0)
// loadFactor、initialCapacity、concurrencyLevel都不能小于0
throw new IllegalArgumentException();
if (initialCapacity < concurrencyLevel) // Use at least as many bins
initialCapacity = concurrencyLevel; // as estimated threads
long size = (long)(1.0 + (long)initialCapacity / loadFactor);
// 根据传入的capacity和loadFactor计算合理的capacity
int cap = (size >= (long)MAXIMUM_CAPACITY) ?
MAXIMUM_CAPACITY : tableSizeFor((int)size);
this.sizeCtl = cap;
}
构造方法基本只做了参数校验,计算合理的capacity值,并没有初始化数组table
核心方法
对于ConcurrentHashMap
而言,核心方法毫无疑问就是put
和get
。所以先来看看put
方法的整体逻辑。
put方法
put
方法用于往map中添加一个键值对K、V。方法实现如下:
public V put(K key, V value) {
// 调用自身putVal()方法
// 第三个参数传false,表示map中有相同的key时(equals相等),直接覆盖其value值
return putVal(key, value, false);
}
/** Implementation for put and putIfAbsent */
final V putVal(K key, V value, boolean onlyIfAbsent) {
// key、value都不能为null
if (key == null || value == null) throw new NullPointerException();
// 计算hash值
int hash = spread(key.hashCode());
int binCount = 0;
for (Node<K,V>[] tab = table;;) {
// 自旋
Node<K,V> f; int n, i, fh;
if (tab == null || (n = tab.length) == 0)
// 数组尚未初始化,进行初始化
tab = initTable();
else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
// 当前槽为null(没有数据)
if (casTabAt(tab, i, null,
new Node<K,V>(hash, key, value, null)))
// CAS的方法把k、v包装成Node节点,放在这个槽上
// 成功后就结束自旋,无需加锁
// 不成功继续自旋
break; // no lock when adding to empty bin
}
else if ((fh = f.hash) == MOVED)
// 当前槽上的节点正在转移(扩容)
tab = helpTransfer(tab, f);
else {
// 当前槽上有值,并且不处于转移状态
V oldVal = null;
synchronized (f) {
// 锁住当前槽
// 因为只锁住了一个槽(链表头节点、红黑树根节点),也就是数组的一项,所以比JDK1.7中锁住一段(分段锁)的效率更高
if (tabAt(tab, i) == f) {
// 当前槽上的节点没有被修改过,double-check
if (fh >= 0) {
// 该槽上是单链表
binCount = 1;
for (Node<K,V> e = f;; ++binCount) {
// 遍历当前槽
K ek;
if (e.hash == hash &&
((ek = e.key) == key ||
(ek != null && key.equals(ek)))) {
// 单链表上找到了相同的key,覆盖其value值
oldVal = e.val;
if (!onlyIfAbsent)
e.val = value;
break;
}
Node<K,V> pred = e;
if ((e = e.next) == null) {
// 【尾插法】把新增节点插入链表,退出自旋
pred.next = new Node<K,V>(hash, key,
value, null);
break;
}
}
}
else if (f instanceof TreeBin) {
// 当前槽上存的是红黑树,按照红黑树的方式插入元素
Node<K,V> p;
binCount = 2;
if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,
value)) != null) {
oldVal = p.val;
if (!onlyIfAbsent)
p.val = value;
}
}
}
}
if (binCount != 0) {
if (binCount >= TREEIFY_THRESHOLD)
// 链表需要转换成红黑树
treeifyBin(tab, i);
if (oldVal != null)
return oldVal;
break;
}
}
}
// 检检查是否需要扩容,如果需要就扩容
addCount(1L, binCount);
return null;
}
整个put
方法大致分为以下几步:
1、校验K、V,并计算has值
2、进入自旋,判断table
是否已经初始化,如果否,则进行初始化;如果是,执行3
3、判断当前槽上是否为null,如果是,通过CAS的方式新增节点;如果否,执行4
4、判断当前槽上的节点是否正在转移(扩容过程),如果是,辅助扩容;如果否,执行5
5、锁住当前槽,如果当前槽上是单链表,按照单链表的方式新增节点,如果是红黑树,按照红黑树的方式新增节点
6、判断链表是否需要转换成红黑树,如果是,转换成红黑树
7、新增节点完成后,检测是否需要扩容,如果需要,就扩容
从源码来看,ConcurrentHashMap
中的put
方法和HashMap
的put
方法执行的逻辑相差无几。只是利用了自旋 + CAS、Synchronized等来保证线程安全。
接下来深入put
方法中使用的一些内部方法:initTable
、addCount
等
initTable
initTable
用于初始化table
数组,其实现如下:
private final Node<K,V>[] initTable() {
Node<K,V>[] tab; int sc;
while ((tab = table) == null || tab.length == 0) {
// 自旋
// 外层putVal方法已经判断过这个条件,double-check
if ((sc = sizeCtl) < 0)
// 有其他的线程正在初始table数组
Thread.yield(); // lost initialization race; just spin
else if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
// CAS的方式抢到锁
try {
if ((tab = table) == null || tab.length == 0) {
// 再次double-check
// 执行初始
int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
@SuppressWarnings("unchecked")
Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
table = tab = nt;
sc = n - (n >>> 2);
}
} finally {
sizeCtl = sc;
}
break;
}
}
return tab;
}
初始化table
数组的核心逻辑只有一行new
操作,但是为了保证线程安全和高效,采用了double-check + 自旋 + CAS的方式,这也是多线程并发编程的常见手段。
addCount
addCount
方法用来检测是否需要扩容,如果需要就扩容。
private final void addCount(long x, int check) {
CounterCell[] as; long b, s;
if ((as = counterCells) != null ||
!U.compareAndSwapLong(this, BASECOUNT, b = baseCount, s = b + x)) {
CounterCell a; long v; int m;
boolean uncontended = true;
if (as == null || (m = as.length - 1) < 0 ||
(a = as[ThreadLocalRandom.getProbe() & m]) == null ||
!(uncontended =
U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x))) {
fullAddCount(x, uncontended);
return;
}
if (check <= 1)
return;
s = sumCount();
}
if (check >= 0) {
Node<K,V>[] tab, nt; int n, sc;
while (s >= (long)(sc = sizeCtl) && (tab = table) != null &&
(n = tab.length) < MAXIMUM_CAPACITY) {
int rs = resizeStamp(n);
if (sc < 0) {
if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
transferIndex <= 0)
break;
if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
// 有别的线程正在扩容
transfer(tab, nt);
}
else if (U.compareAndSwapInt(this, SIZECTL, sc,
(rs << RESIZE_STAMP_SHIFT) + 2))
// 没有别的线程正在扩容
transfer(tab, null);
s = sumCount();
}
}
}
对于这个方法而言,就是判断要不要扩容,而真正的扩容方法是transfer
,所以具体看下transfer
方法的实现逻辑
/**
* tab表示扩容前的数组
* nextTab表示扩容后的新数组(如果为null,表示并没有别的线程在扩容)
*/
private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
int n = tab.length, stride;
if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
stride = MIN_TRANSFER_STRIDE; // subdivide range
if (nextTab == null) { // initiating
try {
// 初始化nextTab,大小为原数组的2倍
@SuppressWarnings("unchecked")
Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1];
nextTab = nt;
} catch (Throwable ex) { // try to cope with OOME
sizeCtl = Integer.MAX_VALUE;
return;
}
nextTable = nextTab;
transferIndex = n;
}
// 获取新数组的长度
int nextn = nextTab.length;
// 如果元素组槽上是转移节点,表示该槽上的节点正在转移
ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);
boolean advance = true;
boolean finishing = false; // to ensure sweep before committing nextTab
for (int i = 0, bound = 0;;) {
// 自旋
Node<K,V> f; int fh;
while (advance) {
int nextIndex, nextBound;
if (--i >= bound || finishing)
advance = false;
else if ((nextIndex = transferIndex) <= 0) {
// 拷贝已经完成
i = -1;
advance = false;
}
else if (U.compareAndSwapInt
(this, TRANSFERINDEX, nextIndex,
nextBound = (nextIndex > stride ?
nextIndex - stride : 0))) {
bound = nextBound;
i = nextIndex - 1;
advance = false;
}
}
if (i < 0 || i >= n || i + n >= nextn) {
// 拷贝结束
int sc;
if (finishing) {
nextTable = null;
table = nextTab;
sizeCtl = (n << 1) - (n >>> 1);
return;
}
if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
return;
finishing = advance = true;
i = n; // recheck before commit
}
}
else if ((f = tabAt(tab, i)) == null)
advance = casTabAt(tab, i, null, fwd);
else if ((fh = f.hash) == MOVED)
advance = true; // already processed
else {
synchronized (f) {
// 加锁,进行节点拷贝
if (tabAt(tab, i) == f) {
// 低位链表和高位链表(前文HashMap中讲过)
Node<K,V> ln, hn;
if (fh >= 0) {
int runBit = fh & n;
Node<K,V> lastRun = f;
for (Node<K,V> p = f.next; p != null; p = p.next) {
int b = p.hash & n;
if (b != runBit) {
runBit = b;
lastRun = p;
}
}
if (runBit == 0) {
ln = lastRun;
hn = null;
}
else {
hn = lastRun;
ln = null;
}
for (Node<K,V> p = f; p != lastRun; p = p.next) {
// 循环拷贝链表节点
// 原数组中的链表会被分成高、低位两个链表,放在新数组中不同的槽
int ph = p.hash; K pk = p.key; V pv = p.val;
if ((ph & n) == 0)
ln = new Node<K,V>(ph, pk, pv, ln);
else
hn = new Node<K,V>(ph, pk, pv, hn);
}
// 链表设置到新数组
setTabAt(nextTab, i, ln);
setTabAt(nextTab, i + n, hn);
// 旧数组上设置转移节点
// 其他线程发下槽上是转移节点后就会等待
setTabAt(tab, i, fwd);
advance = true;
}
else if (f instanceof TreeBin) {
// 拷贝红黑树的节点
TreeBin<K,V> t = (TreeBin<K,V>)f;
TreeNode<K,V> lo = null, loTail = null;
TreeNode<K,V> hi = null, hiTail = null;
int lc = 0, hc = 0;
for (Node<K,V> e = t.first; e != null; e = e.next) {
int h = e.hash;
TreeNode<K,V> p = new TreeNode<K,V>
(h, e.key, e.val, null, null);
if ((h & n) == 0) {
if ((p.prev = loTail) == null)
lo = p;
else
loTail.next = p;
loTail = p;
++lc;
}
else {
if ((p.prev = hiTail) == null)
hi = p;
else
hiTail.next = p;
hiTail = p;
++hc;
}
}
ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) :
(hc != 0) ? new TreeBin<K,V>(lo) : t;
hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) :
(lc != 0) ? new TreeBin<K,V>(hi) : t;
setTabAt(nextTab, i, ln);
setTabAt(nextTab, i + n, hn);
setTabAt(tab, i, fwd);
advance = true;
}
}
}
}
}
}
可以看到ConcurrentHashMap
中的扩容也是采用了自旋 + CAS + Synchronized来保证线程安全的,除此之外,还添加了转移节点,表示该槽上的节点正在被转移,此时别的线程不要往这个槽写数据。
get方法
get
方法用于从map中根据key
取value
。其实现相比于put
方法要简单得多
public V get(Object key) {
Node<K,V>[] tab; Node<K,V> e, p; int n, eh; K ek;
// 根据key计算hash值
int h = spread(key.hashCode());
if ((tab = table) != null && (n = tab.length) > 0 &&
(e = tabAt(tab, (n - 1) & h)) != null) {
// table不为空,且当前槽上有数据
if ((eh = e.hash) == h) {
if ((ek = e.key) == key || (ek != null && key.equals(ek)))
// 槽上第一个节点就是要取的节点,直接返回value
return e.val;
}
else if (eh < 0)
// 槽上的第一个节点红黑树的根节点或者转移节点,调用其find方法查找
return (p = e.find(h, key)) != null ? p.val : null;
while ((e = e.next) != null) {
// 槽上是单链表,遍历单查找
if (e.hash == h &&
((ek = e.key) == key || (ek != null && key.equals(ek))))
// 找到需要的节点,返回value
return e.val;
}
}
// table没有初始
// 当前槽上没有数据
// 红黑树/转移节点/单链表上未找到
return null;
}
size方法
size
方法用于返回map的节点个数,在HashMap
中非常简单,因为定义了一个变量来维护size
,但是ConcurrentHashMap
并没有定义这样的变量,先来看下其size
方法的实现
public int size() {
// 调用内部sumCount方法
long n = sumCount();
return ((n < 0L) ? 0 :
(n > (long)Integer.MAX_VALUE) ? Integer.MAX_VALUE :
(int)n);
}
final long sumCount() {
CounterCell[] as = counterCells; CounterCell a;
long sum = baseCount;
if (as != null) {
for (int i = 0; i < as.length; ++i) {
if ((a = as[i]) != null)
sum += a.value;
}
}
return sum;
}
可以看到最后返回size
的值就是baseCount
的值 + counterCells
数组中的所有值之和。
counterCells
数组中存的实际上就是table
数组中每个槽上的节点个数。
baseCount
相当于counterCells
的优化,在没有竞争的时候使用。
实际上也就是分段求和,再汇总的思想。
看到方法内部并没有加锁,说明size
方法返回的并不是一个准确值,而是一个近似值,因为在汇总的过程中,有可能map中新增或者删除了元素。
与JDK1.7的区别
一张图就可以直观的感受到ConcurrentHashMap
在JDK1.7和JDK1.8的区别
对JDK1.7中的ConcurrentHashMap
而言
内部主要是一个Segment
数组,而数组的每一项又是一个HashEntry
数组,元素都存在HashEntry
数组里。因为每次锁定的是Segment
对象,也就是整个HashEntry
数组,所以又叫分段锁。
对JDK1.8中的ConcurrentHashMap
而言
舍弃了分段锁的实现方式,元素都存在Node
数组中,每次锁住的是一个Node
对象,而不是某一段数组,所以支持的写的并发度更高。
再者它引入了红黑树,在hash冲突严重时,读操作的效率更高。
这两点便是JDK1.8对ConcurrentHashMap
所做的主要优化。
总结
ConcurrentHashMap
类的实现,可以说是并发容器中,经典中的经典。深入的理解这个类,需要数据结构的基础,以及并发编程的基础。