前言
众所周知,HashMap并非线程安全的,这使得我们在实际使用时,尤其是多线程环境下,会面临诸多不便。我们可以通过调用Collections.synchronizedMap方法获取一个线程安全的HashMap,但是查看该方法的源代码发现,具体逻辑就是用synchronized关键字修饰原先HashMap的成员方法。此举虽能达到线程安全的目的,但是访问效率会大打折扣。因此本文将介绍一种线程安全又不失访问效率的集合类——ConcurrentHashMap。
ConcurrentHashMap
伴随着JDK版本的升级,ConcurrentHashMap也经历了很多次变更,但是核心的东西不会变,本文以JDK1.6为基准,JDK1.8中的优化部分会明确指出。
基本结构
ConcurrentHashMap相对于HashMap而言,它是线程安全的,相对于线程安全的HashTable而言,引入了“分段锁”的概念,大大提升了并发性。ConcurrentHashMap的基本结构如下:
//默认的初始化容量
static final int DEFAULT_INITIAL_CAPACITY = 16;
//默认负载因子
static final float DEFAULT_LOAD_FACTOR = 0.75f;
//默认的并发等级
static final int DEFAULT_CONCURRENCY_LEVEL = 16;
//最大容量
static final int MAXIMUM_CAPACITY = 1 << 30;
//Segment中Table数组最小长度为2
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
//Segment的最大数
static final int MAX_SEGMENTS = 1 << 16; // slightly conservative
final int segmentMask;
final int segmentShift;
//segment数组
final ConcurrentHashMap.Segment<K, V>[] segments;
ConcurrentHashMap的一些基础属性跟HashMap基本一致,再来看一下Segment结构:
static final class Segment<K, V> extends ReentrantLock implements Serializable {
private static final long serialVersionUID = 2249069246763182397L;
transient volatile int count;
transient int modCount;
transient int threshold;
transient volatile ConcurrentHashMap.HashEntry<K, V>[] table;
final float loadFactor;
……
}
再来看一下ConcurrentHashMap的初始化过程,ConcurrentHashMap提供了5种构造方法,但最终还是通过调用最基础的一个构造方法,代码如下:
public ConcurrentHashMap(int initialCapacity,
float loadFactor, int concurrencyLevel) {
if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
throw new IllegalArgumentException(); //参数合法性校验
if (concurrencyLevel > MAX_SEGMENTS)
concurrencyLevel = MAX_SEGMENTS;
//比如我输入的concurrencyLevel=12,那么sshift = 4,ssize =16,所以sshift是意思就是1左移了几次比concurrencyLevel大,ssize就是那个大于等于concurrencyLevel的最小2的幂次方的数
int sshift = 0;
int ssize = 1;
while (ssize < concurrencyLevel) {
++sshift;
ssize <<= 1; //ssize = ssize << 1 , ssize = ssize * 2
}
segmentShift = 32 - sshift;
segmentMask = ssize - 1; //segmentMask的二进制是一个全是1的数
this.segments = Segment.newArray(ssize); //segment个数是ssize,也就是上图黄色方块数,默认16个
if (initialCapacity > MAXIMUM_CAPACITY)
initialCapacity = MAXIMUM_CAPACITY;
int c = initialCapacity / ssize;
if (c * ssize < initialCapacity)
++c;
int cap = 1;
while (cap < c)
cap <<= 1;
for (int i = 0; i < this.segments.length; ++i)
this.segments[i] = new Segment<K,V>(cap, loadFactor);
//cap是一个2的幂次方的数,默认是1,
//也就是每个segment下都构造了cap大小的table数组
}
Segment(int initialCapacity, float lf) {
loadFactor = lf;
setTable(HashEntry.<K,V>newArray(initialCapacity));//构造了一个initialCapacity大小的table
}
可以看出,ConcurrentHashMap内部包含一个Segment数组,每个Segment数组元素为一个ConcurrentHashMap.HashEntry数组,每个HashEntry数组的元素是一个由HashEntry类型节点组成的链表。
关键方法
get方法
public V get(Object key) {
//双hash,和HashMap一样,也是为了更好的离散化;
//先寻找segment的下标,然后再get操作。
int hash = hash(key.hashCode());
return segmentFor(hash).get(key, hash);
}
final Segment<K,V> segmentFor(int hash) {
//寻找segment的下标;
//前面说了segmentMask是一个2进制全是1的数,
//那么&segmentMask就保证了下标小于等于segmentMask,与HashMap的寻下标相似。
return segments[(hash >>> segmentShift) & segmentMask];
}
V get(Object key, int hash) {
// count是一个volatile变量,count非常巧妙,
//每次put和remove之后的最后一步都要更新count,
//就是为了get的时候不让编译器对代码进行重排序,来保证
if (count != 0) {
//寻找table的下标,也就是链表的表头
HashEntry<K,V> e = getFirst(hash);
while (e != null) {
if (e.hash == hash && key.equals(e.key)) {
V v = e.value;
if (v != null)
return v;
// recheck 加锁读,这个加锁读不用重新计算位置,而是直接拿e的值
return readValueUnderLock(e);
}
e = e.next;
}
}
return null;
}
HashEntry<K,V> getFirst(int hash) {
HashEntry<K,V>[] tab = table;
return tab[hash & (tab.length - 1)];
}
put方法
public V put(K key, V value) {
//明确指定value不能为null
if (value == null)
throw new NullPointerException();
int hash = hash(key.hashCode());
//segmentFor如下,先定位segment下标,然后执行put操作
return segmentFor(hash).put(key, hash, value, false);
}
V put(K key, int hash, V value, boolean onlyIfAbsent) {
//先加锁,可以看到,put操作是在segment里面加锁的,
//也就是每个segment都可以加一把锁
lock();
try {
int c = count;
//判断容量,如果不够了就扩容
if (c++ > threshold)
rehash();
HashEntry<K,V>[] tab = table;
//寻找table的下标
int index = hash & (tab.length - 1);
HashEntry<K,V> first = tab[index];
HashEntry<K,V> e = first;
//遍历单链表,找到key相同的为止,如果没找到,e指向链表尾
while (e != null && (e.hash != hash || !key.equals(e.key)))
e = e.next;
V oldValue;
if (e != null) {
//如果有相同的key,那么直接替换
oldValue = e.value;
if (!onlyIfAbsent)
e.value = value;
}
else {
//否则在链表表头插入新的结点
oldValue = null;
++modCount;
tab[index] = new HashEntry<K,V>(key, hash, first, value);
count = c; // write-volatile
}
return oldValue;
} finally {
unlock();
}
}
remove方法
V remove(Object key, int hash, Object value) {
//段内先获得一把锁
lock();
try {
int c = count - 1;
HashEntry<K,V>[] tab = table;
int index = hash & (tab.length - 1);
HashEntry<K,V> first = tab[index];
HashEntry<K,V> e = first;
while (e != null && (e.hash != hash || !key.equals(e.key)))
e = e.next;
V oldValue = null;
if (e != null) {
//如果找到该key
V v = e.value;
if (value == null || value.equals(v)) {
oldValue = v;
// All entries following removed node can stay
// in list, but all preceding ones need to be
// cloned.
++modCount;
//newFirst此时为要删除结点的next
HashEntry<K,V> newFirst = e.next;
for (HashEntry<K,V> p = first; p != e; p = p.next)
//从头遍历链表将要删除结点的前面所有结点复制一份插入到newFirst之前
newFirst = new HashEntry<K,V>(p.key,p.hash,newFirst, p.value);
tab[index] = newFirst;
count = c; // write-volatile
}
}
return oldValue;
} finally {
unlock();
}
}
size方法
public int size() {
final Segment<K,V>[] segments = this.segments;
long sum = 0;
long check = 0;
int[] mc = new int[segments.length];
// Try a few times to get accurate count. On failure due to
// continuous async changes in table, resort to locking.
for (int k = 0; k < RETRIES_BEFORE_LOCK; ++k) {
check = 0;
sum = 0;
int mcsum = 0;
for (int i = 0; i < segments.length; ++i) {
//循环相加每个段内数据的个数
sum += segments[i].count;
//循环相加每个段内的modCount
mcsum += mc[i] = segments[i].modCount;
}
if (mcsum != 0) {
//如果是0,代表根本没有过数据更改,也就是size是0
for (int i = 0; i < segments.length; ++i) {
//再次循环相加每个段内数据的个数
check += segments[i].count;
if (mc[i] != segments[i].modCount) {
//如果modCount和之前统计的不一致了,check直接赋值-1,重新再来
check = -1; // force retry
break;
}
}
}
if (check == sum)
break;
}
if (check != sum) { // Resort to locking all segments
sum = 0;
//循环获取所有segment的锁
for (int i = 0; i < segments.length; ++i)
segments[i].lock();
//在持有所有段的锁的时候进行count的相加
for (int i = 0; i < segments.length; ++i)
sum += segments[i].count;
//循环释放所有段的锁
for (int i = 0; i < segments.length; ++i)
segments[i].unlock();
}
if (sum > Integer.MAX_VALUE) //return
return Integer.MAX_VALUE;
else
return (int)sum;
}
JDK8中ConcurrentHashMap改动
ConcurrentHashMap在JDK8中进行了巨大改动,它摒弃了Segment(分锁段)的概念,而是启用了一种全新的方式实现,利用CAS算法。它沿用了与它同时期的HashMap版本的思想,底层依然由“数组”+链表+红黑树的方式思想(JDK7与JDK8中HashMap的实现),但是为了做到并发,又增加了很多辅助的类,例如TreeBin,Traverser等对象内部类。
jdk8 中完全重写了concurrentHashmap,代码量从原来的1000多行变成了 6000多 行,实现上也和原来的分段式存储有很大的区别。
主要设计上的变化有以下几点:
- 不采用segment而采用node,锁住node来实现减小锁粒度。
- 设计了MOVED状态 当resize的中过程中 线程2还在put数据,线程2会帮助resize。
- 使用3个CAS操作来确保node的一些操作的原子性,这种方式代替了锁。
- sizeCtl的不同值来代表不同含义,起到了控制的作用。
//hash表初始化或扩容时的一个控制位标识量。
//负数代表正在进行初始化或扩容操作
//-1代表正在初始化
//-N 表示有N-1个线程正在进行扩容操作
//正数或0代表hash表还没有被初始化,这个数值表示初始化或下一次进行扩容的大小
private transient volatile int sizeCtl;
// 以下两个是用来控制扩容的时候 单线程进入的变量
private static int RESIZE_STAMP_BITS = 16;
private static final int RESIZE_STAMP_SHIFT = 32 - RESIZE_STAMP_BITS;
// hash值是-1,表示这是一个forwardNode节点
static final int MOVED = -1;
// hash值是-2 表示这时一个TreeBin节点
static final int TREEBIN = -2;
put方法
public V put(K key, V value) {
return putVal(key, value, false);
}
/** Implementation for put and putIfAbsent */
final V putVal(K key, V value, boolean onlyIfAbsent) {
//key和value均不能为null
if (key == null || value == null)
throw new NullPointerException();
int hash = spread(key.hashCode());
int binCount = 0;
//循环执行
for (Node<K,V>[] tab = table;;) {
Node<K,V> f; int n, i, fh;
//如果table为空,执行初始化操作
if (tab == null || (n = tab.length) == 0)
tab = initTable();
else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
//table[i]为空,用CAS在table[i]头结点直接插入,退出插入操作;
//如果CAS失败,则有其他节点已经插入,继续下一步
if (casTabAt(tab, i, null, new Node<K,V>(hash, key, value, null)))
break; // no lock when adding to empty bin
}
else if ((fh = f.hash) == MOVED)
//如果table[i]不为空,且table[i]的hash值为-1,
//则有其他线程在执行扩容操作,帮助他们一起扩容,提高性能
tab = helpTransfer(tab, f);
else {
V oldVal = null;
//只锁住了链表的头结点
synchronized (f) {
if (tabAt(tab, i) == f) {
//fh(table[i])的hash>=0,则此时table[i]为链表结构,找到合适位置插入
if (fh >= 0) {
binCount = 1;
for (Node<K,V> e = f;; ++binCount) {
K ek;
if (e.hash == hash &&
((ek = e.key) == key ||
(ek != null && key.equals(ek)))) {
oldVal = e.val;
if (!onlyIfAbsent)
e.val = value;
break;
}
Node<K,V> pred = e;
if ((e = e.next) == null) {
pred.next = new Node<K,V>(hash, key, value, null);
break;
}
}
}
//fh(table[i])的hash<0,table[i]为红黑树结构,这个过程采用同步内置锁实现并发
else if (f instanceof TreeBin) {
Node<K,V> p;
binCount = 2;
if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,value)) != null) {
oldVal = p.val;
if (!onlyIfAbsent)
p.val = value;
}
}
}
}
//到此时,已将键值对插入到了合适的位置,
//检查链表长度是否超过阈值,若是,则转变为红黑树结构
if (binCount != 0) {
if (binCount >= TREEIFY_THRESHOLD)
treeifyBin(tab, i);
if (oldVal != null)
return oldVal;
break;
}
}
}
addCount(1L, binCount);
return null;
}
get方法
public V get(Object key) {
Node<K,V>[] tab; Node<K,V> e, p; int n, eh; K ek;
//计算hash值
int h = spread(key.hashCode());
//根据hash值确定节点位置
if ((tab = table) != null && (n = tab.length) > 0 &&
(e = tabAt(tab, (n - 1) & h)) != null) {
//如果搜索到的节点key与传入的key相同且不为null,直接返回这个节点
if ((eh = e.hash) == h) {
if ((ek = e.key) == key || (ek != null && key.equals(ek)))
return e.val;
}
//如果eh<0 说明这个节点在树上 直接寻找
else if (eh < 0)
return (p = e.find(h, key)) != null ? p.val : null;
//否则遍历链表 找到对应的值并返回
while ((e = e.next) != null) {
if (e.hash == h &&
((ek = e.key) == key || (ek != null && key.equals(ek))))
return e.val;
}
}
return null;
}
总结
本文主要介绍了ConcurrentHashMap的结构、内部实现以及随着JDK版本升级所做的变更内容。总得来说,ConcurrentHashMap是Doug Lea大师的作品,大师之作并非我辈三言两语所能道明的,希望能在实际使用中体会到设计的精妙之处。