ConcurrentHashMap是Java中使用非常普遍的一个Map,ConcurrentHashMap较HashMap而言极大提高了并发操作速度,我们知道HashMap是线程不安全的,在多线程环境下容易出现死锁,线程安全的HashTable(包括SynchronizedMap)由于对每次操作都加锁导致效率非常低,我们都知道Map一般都是数组+链表结构(JDK1.8可能为数组+红黑树),ConcurrentHashMap避免了对全局加锁改成了局部加锁操作,这样就极大地提高了并发环境下的操作速度,ConcurrentHashMap在JDK1.7和1.8中的实现非常不同,接下来我们会分别对1.7和1.8中的实现原理做分析。
ConcurrentHashMap(JDK1.7)
在JDK1.7中ConcurrentHashMap采用了数组+Segment+分段锁的方式实现,Segment是ConcurrentHashMap(后续此ConcurrentHashMap都泛指JDK1.7中实现)一个非常重要的部分,Segment继承了ReentrantLock因此拥有了锁的功能,Segment中维护了HashEntry数组table,HashEntry本质是一个K-V存储结构,内部存储了目标对象的Key和Value,HashEntry同时也是一个链式结构内部维护了下一个HashEntry变量next,即Segment是一个数组链表结构,而整个ConcurrentHashMap中维护了一个Segment数组segments,因此ConcurrentHashMap的整体结构如下:
其中HashEntry基层了可重入锁,因此对每个segments中的元素进行操作都会加锁。接下来我们分析下ConcurrentHashMap的主要方法的源码进行分析。
put操作
public V put(K key, V value) {
Segment<K,V> s;
if (value == null)
throw new NullPointerException();
//计算hash值
int hash = hash(key);
//获取segments数组位置
int j = (hash >>> segmentShift) & segmentMask;
if ((s = (Segment<K,V>)UNSAFE.getObject // nonvolatile; recheck
(segments, (j << SSHIFT) + SBASE)) == null) // in ensureSegment
s = ensureSegment(j);
//插入数据
return s.put(key, hash, value, false);
}
ensureSegment方法的主要功能是找出的segmengs数组位置位于k的Segment,如果不存在将会创建。
private Segment<K,V> ensureSegment(int k) {
//获取当前segmegts数组
final Segment<K,V>[] ss = this.segments;
long u = (k << SSHIFT) + SBASE; // raw offset
Segment<K,V> seg;
//CAS获取当前数组是否存在,存在立即返回
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
Segment<K,V> proto = ss[0]; // use segment 0 as prototype
int cap = proto.table.length;
float lf = proto.loadFactor;
int threshold = (int)(cap * lf);
//实例化一个HashEntry数组
HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
//再次获segments中取位置位于k的值
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) { // recheck
//不存在则新建一个Segments并CAS设置位于k处
Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) {
if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
break;
}
}
}
return seg;
}
插入数据:
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
//这句代码主要是获取Segment的锁,
HashEntry<K,V> node = tryLock() ? null :
scanAndLockForPut(key, hash, value);
V oldValue;
try {
//获取HashEntry数组
HashEntry<K,V>[] tab = table;
//定位元素位于HashEntry数组中的位置
int index = (tab.length - 1) & hash;
//获取HashEntry中数组位于index的位置,也就是链表的第一个节点
HashEntry<K,V> first = entryAt(tab, index);
//循环遍历链表
for (HashEntry<K,V> e = first;;) {
//链表头节点不为null
if (e != null) {
K k;
//判断当前节点是否与需要插入的数据是否一样,一样的话修改value直接退出循环
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
oldValue = e.value;
if (!onlyIfAbsent) {
e.value = value;
++modCount;
}
break;
}
e = e.next;
}
else {
//获取锁时没有初始化节点就直接初始化节点
if (node != null)
node.setNext(first);
else
node = new HashEntry<K,V>(hash, key, value, first);
int c = count + 1;
//数量大于阈值且长度小于最大容量时对HashEntry数组进行扩容处理
if (c > threshold && tab.length < MAXIMUM_CAPACITY)
rehash(node);
else
setEntryAt(tab, index, node);
++modCount;
count = c;
oldValue = null;
break;
}
}
} finally {
//释放锁
unlock();
}
return oldValue;
}
上面这段代码比较简单,基本就是跟HashMap类似,只是多了个获取锁的过程。这里HashEntry数组数量大于某个阈值时会进行扩容处理,我们看下扩容方法rehash;
private void rehash(HashEntry<K,V> node) {
//获取老的数组
HashEntry<K,V>[] oldTable = table;
int oldCapacity = oldTable.length;
//设置新的数组容量为老的数组容量的两倍,JDK中大量使用了移位操作,移位比乘除效率高,值得推荐。
int newCapacity = oldCapacity << 1;
threshold = (int)(newCapacity * loadFactor);
//实例化新的HashEntry数组
HashEntry<K,V>[] newTable =
(HashEntry<K,V>[]) new HashEntry[newCapacity];
int sizeMask = newCapacity - 1;
//循环遍历数组
for (int i = 0; i < oldCapacity ; i++) {
//获取老数组中位置处于i的节点,即链表的头节点
HashEntry<K,V> e = oldTable[i];
if (e != null) {
HashEntry<K,V> next = e.next;
int idx = e.hash & sizeMask;
//如果链表只有一个节点直接设置需要插入的值为数组值
if (next == null) // Single node on list
newTable[idx] = e;
else { // Reuse consecutive sequence at same slot
HashEntry<K,V> lastRun = e;
int lastIdx = idx;
//寻找一个lastRun节点,这个节点之后的所有节点都要放在某个数组下面
for (HashEntry<K,V> last = next;
last != null;
last = last.next) {
int k = last.hash & sizeMask;
if (k != lastIdx) {
lastIdx = k;
lastRun = last;
}
}
//将lastRun后面的所有节点均防止在lastIds这个位置
newTable[lastIdx] = lastRun;
// Clone remaining nodes
//迁移lastRun之前的节点
for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
V v = p.value;
int h = p.hash;
int k = h & sizeMask;
HashEntry<K,V> n = newTable[k];
newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
}
}
}
}
//采用头插入法将链表插入到头部
int nodeIndex = node.hash & sizeMask; // add the new node
node.setNext(newTable[nodeIndex]);
newTable[nodeIndex] = node;
table = newTable;
}
get操作
public V get(Object key) {
Segment<K,V> s; // manually integrate access methods to reduce overhead
HashEntry<K,V>[] tab;
int h = hash(key);
long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
(tab = s.table) != null) {
for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
e != null; e = e.next) {
K k;
if ((k = e.key) == key || (e.hash == h && key.equals(k)))
return e.value;
}
}
return null;
}
get操作就非常简单了,也不需要加速操作,首先定位到segments数组,再定位到HashEntry数组,如果数组位置就是需要获取的值直接返回否则遍历链表。
size操作
size()方法比普通的集合获取要复杂点,因为在获取size的时候可能有其它线程正在增加或删除元素,JDK7实现为首先两次不加锁去获取这个值大小,每次都是将所有segments中值相加,若两次相加的结果一样可以认为这个结果是可信的,若两次相加的结构不一致则进一步对所有segments数组加锁处理。
public int size() {
// Try a few times to get accurate count. On failure due to
// continuous async changes in table, resort to locking.
final Segment<K,V>[] segments = this.segments;
int size;
boolean overflow; // true if size overflows 32 bits
long sum; // sum of modCounts
long last = 0L; // previous sum
int retries = -1; // first iteration isn't retry
try {
for (;;) {
//判断是否需要加锁
if (retries++ == RETRIES_BEFORE_LOCK) {
for (int j = 0; j < segments.length; ++j)
ensureSegment(j).lock(); // force creation
}
sum = 0L;
size = 0;
overflow = false;
//统计每个segments数组中元素大小
for (int j = 0; j < segments.length; ++j) {
Segment<K,V> seg = segmentAt(segments, j);
if (seg != null) {
sum += seg.modCount;
int c = seg.count;
if (c < 0 || (size += c) < 0)
overflow = true;
}
}
//若两次相等则退出循环
if (sum == last)
break;
last = sum;
}
} finally {
if (retries > RETRIES_BEFORE_LOCK) {
for (int j = 0; j < segments.length; ++j)
segmentAt(segments, j).unlock();
}
}
return overflow ? Integer.MAX_VALUE : size;
}
JDK7中ConcurrentHashMap实现还算比较简单也容易看到,其主要思想是采用分段锁的设计思想,对于并发请求写入只锁住集合的一部分这样极大地提高了高并发环境下的性能。
ConcurrentHashMap(JDK1.8)
JDK8中ConcurrentHashMap参考了JDK8 HashMap的实现,采用了数组+链表+红黑树的实现方式来设计,内部大量采用CAS操作,JDK8中彻底放弃了Segment转而采用的是Node,其设计思想也不再是JDK1.7中的分段锁思想了。下面将看下JDK8中ConcurrentHashMap的结构。由于引入了红黑树,使得ConcurrentHashMap的实现非常复杂,我们都知道,红黑树是一种性能非常好的二叉查找树,其查找性能为O(logN),但是其实现过程也非常复杂,而且可读性也非常差,Doug Lea的思维能力确实不是一般人能比的,早期完全采用链表结构时Map的查找时间复杂度为O(N),JDK8中ConcurrentHashMap在链表的长度大于某个阈值的时候会将链表转换成红黑树进一步提高其查找性能。这里我们主要关系Put、Get以及Size方法。
put方法
public V put(K key, V value) {
return putVal(key, value, false);
}
/** Implementation for put and putIfAbsent */
final V putVal(K key, V value, boolean onlyIfAbsent) {
if (key == null || value == null) throw new NullPointerException();
int hash = spread(key.hashCode());
int binCount = 0;
for (Node<K,V>[] tab = table;;) {
Node<K,V> f; int n, i, fh;
if (tab == null || (n = tab.length) == 0)
tab = initTable();
else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
if (casTabAt(tab, i, null,
new Node<K,V>(hash, key, value, null)))
break; // no lock when adding to empty bin
}
else if ((fh = f.hash) == MOVED)
tab = helpTransfer(tab, f);
else {
V oldVal = null;
synchronized (f) {
if (tabAt(tab, i) == f) {
if (fh >= 0) {
binCount = 1;
for (Node<K,V> e = f;; ++binCount) {
K ek;
if (e.hash == hash &&
((ek = e.key) == key ||
(ek != null && key.equals(ek)))) {
oldVal = e.val;
if (!onlyIfAbsent)
e.val = value;
break;
}
Node<K,V> pred = e;
if ((e = e.next) == null) {
pred.next = new Node<K,V>(hash, key,
value, null);
break;
}
}
}
else if (f instanceof TreeBin) {
Node<K,V> p;
binCount = 2;
if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,
value)) != null) {
oldVal = p.val;
if (!onlyIfAbsent)
p.val = value;
}
}
}
}
if (binCount != 0) {
if (binCount >= TREEIFY_THRESHOLD)
treeifyBin(tab, i);
if (oldVal != null)
return oldVal;
break;
}
}
}
addCount(1L, binCount);
return null;
}
未完待续