This analysis is based on JDK 1.7.
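First, a brief overview of how ConcurrentHashMap works:
1. Internally, ConcurrentHashMap is divided into segments. Each segment is a Segment object with its own HashMap-like structure.
2. A write locks only the one Segment that the key maps to, so writes that land in different segments can proceed in parallel. The map stays thread-safe, and threads contend less than they would on a single map-wide lock.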
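The following is a minimal lock-striping sketch that makes the idea concrete. StripedMap is a hypothetical class, not the JDK implementation. Each stripe is a plain java.util.HashMap guarded by its own ReentrantLock, and the key's hash picks the stripe. JDK 1.7 reads without locking, but this sketch also locks in get to keep it simple.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical StripedMap (not the JDK class): lock striping in miniature.
class StripedMap<K, V> {
    private final ReentrantLock[] locks; // one lock per stripe
    private final Map<K, V>[] stripes;   // one plain, unsynchronized HashMap per stripe

    @SuppressWarnings("unchecked")
    StripedMap(int stripeCount) {
        locks = new ReentrantLock[stripeCount];
        stripes = (Map<K, V>[]) new Map[stripeCount];
        for (int i = 0; i < stripeCount; i++) {
            locks[i] = new ReentrantLock();
            stripes[i] = new HashMap<>();
        }
    }

    // The key's hash chooses the stripe; floorMod keeps the index non-negative.
    private int stripeFor(Object key) {
        return Math.floorMod(key.hashCode(), stripes.length);
    }

    V put(K key, V value) {
        int i = stripeFor(key);
        locks[i].lock(); // only callers hitting the same stripe wait here
        try {
            return stripes[i].put(key, value);
        } finally {
            locks[i].unlock();
        }
    }

    V get(Object key) {
        int i = stripeFor(key);
        locks[i].lock(); // simplification: JDK 1.7's get takes no lock
        try {
            return stripes[i].get(key);
        } finally {
            locks[i].unlock();
        }
    }
}
```

Two threads writing keys that fall into different stripes never wait on each other. They contend only when their keys hash to the same stripe.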
The main internal structure of ConcurrentHashMap:
final Segment<K,V>[] segments;

static final class Segment<K,V> extends ReentrantLock implements Serializable {
    transient volatile HashEntry<K,V>[] table;
    transient int count;
    transient int modCount;
    transient int threshold;
    final float loadFactor;
}
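ConcurrentHashMap does not operate on a single entry array directly. It operates on the segments array, and each Segment holds its own HashEntry array (table). In effect, one ConcurrentHashMap contains several small HashMaps, each guarded by its own lock, since Segment extends ReentrantLock.
Now the constructor: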
// concurrencyLevel is the estimated number of concurrently updating threads; it determines the size of the Segment array
public ConcurrentHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) {
    if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
        throw new IllegalArgumentException();
    if (concurrencyLevel > MAX_SEGMENTS)
        concurrencyLevel = MAX_SEGMENTS;
    int sshift = 0;
    int ssize = 1; // size of the Segment array
    // Find the smallest power of two >= concurrencyLevel,
    // e.g. concurrencyLevel 10, 11, 12 or 13 all yield 16
    while (ssize < concurrencyLevel) {
        ++sshift;
        ssize <<= 1;
    }
    this.segmentShift = 32 - sshift;
    this.segmentMask = ssize - 1;
    if (initialCapacity > MAXIMUM_CAPACITY)
        initialCapacity = MAXIMUM_CAPACITY;
    int c = initialCapacity / ssize; // initial entries per segment, rounded up below
    if (c * ssize < initialCapacity)
        ++c;
    int cap = MIN_SEGMENT_TABLE_CAPACITY; // length of each Segment's HashEntry array
    while (cap < c) // double cap until it is >= c, so it stays a power of two
        cap <<= 1;
    // Create only segments[0]; the other segments are created lazily in ensureSegment,
    // which uses segments[0] as the prototype for table size and load factor
    Segment<K, V> s0 = new Segment<K, V>(loadFactor, (int) (cap * loadFactor), (HashEntry<K, V>[]) new HashEntry[cap]);
    Segment<K, V>[] ss = (Segment<K, V>[]) new Segment[ssize];
    UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
    this.segments = ss;
}
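The sizing arithmetic can be checked by hand. The hypothetical SegmentSizing class below replays the constructor's computation in isolation. Its constants (MAX_SEGMENTS = 1 << 16, MAXIMUM_CAPACITY = 1 << 30, MIN_SEGMENT_TABLE_CAPACITY = 2) are taken to match the JDK 1.7 source. They are assumed here, not read from it.

```java
// Hypothetical helper (not JDK code) that replays the constructor's sizing logic.
// Constant values are assumed to match the JDK 1.7 source.
class SegmentSizing {
    static final int MAX_SEGMENTS = 1 << 16;
    static final int MAXIMUM_CAPACITY = 1 << 30;
    static final int MIN_SEGMENT_TABLE_CAPACITY = 2;

    final int ssize, segmentShift, segmentMask, cap, threshold;

    SegmentSizing(int initialCapacity, float loadFactor, int concurrencyLevel) {
        if (concurrencyLevel > MAX_SEGMENTS) concurrencyLevel = MAX_SEGMENTS;
        int sshift = 0, size = 1;
        while (size < concurrencyLevel) { // smallest power of two >= concurrencyLevel
            ++sshift;
            size <<= 1;
        }
        ssize = size;
        segmentShift = 32 - sshift;       // the top sshift bits of a hash pick the segment
        segmentMask = size - 1;
        if (initialCapacity > MAXIMUM_CAPACITY) initialCapacity = MAXIMUM_CAPACITY;
        int c = initialCapacity / size;   // entries per segment, rounded up
        if (c * size < initialCapacity) ++c;
        int t = MIN_SEGMENT_TABLE_CAPACITY;
        while (t < c) t <<= 1;            // per-segment table length, a power of two
        cap = t;
        threshold = (int) (cap * loadFactor);
    }
}
```

With (16, 0.75f, 16), the JDK 1.7 defaults, this gives 16 segments, segmentShift = 28 and segmentMask = 15, with a per-segment table of length 2 and threshold 1. With (33, 0.75f, 10), concurrencyLevel rounds up to 16 segments. Each table then gets length 4: 33 / 16 rounds up to 3, which rounds up to the next power of two.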
Next, the put and get methods:
public V put(K key, V value) {
    Segment<K,V> s;
    if (value == null)
        throw new NullPointerException();
    // Work out which segment the entry belongs to: hash() spreads the key's bits,
    // and the high sshift bits of the result select the segment
    int hash = hash(key);
    int j = (hash >>> segmentShift) & segmentMask;
    if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
         (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
        s = ensureSegment(j); // segment not created yet: create it now
    // Delegate the put to that segment
    return s.put(key, hash, value, false);
}
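put() uses the hash twice. The high bits pick the segment, and inside Segment.put the low bits pick the bucket. The sketch below illustrates that split together with lazy segment creation. LazySegments is hypothetical. AtomicReferenceArray.get and compareAndSet stand in for the UNSAFE volatile read and CAS that the real ensureSegment uses. The sketch does not copy the table size and load factor from segments[0].

```java
import java.util.concurrent.atomic.AtomicReferenceArray;

// Hypothetical sketch (not JDK code): high hash bits pick a segment, low bits pick
// a bucket, and a missing segment is installed with a CAS, as ensureSegment does.
class LazySegments {
    static final class Segment {
        // table, count, threshold ... omitted in this sketch
    }

    final AtomicReferenceArray<Segment> segments;
    final int segmentShift;
    final int segmentMask;

    // ssize must equal 1 << sshift, as computed by the constructor
    LazySegments(int ssize, int sshift) {
        segments = new AtomicReferenceArray<>(ssize);
        segmentShift = 32 - sshift;
        segmentMask = ssize - 1;
    }

    // Segment index: the top sshift bits of the hash.
    int segmentIndex(int hash) {
        return (hash >>> segmentShift) & segmentMask;
    }

    // Bucket index inside a segment's table: the low bits of the hash.
    static int bucketIndex(int hash, int tableLength) {
        return (tableLength - 1) & hash;
    }

    // Return segment j, creating it on first use; racing threads end up sharing one instance.
    Segment ensureSegment(int j) {
        Segment s = segments.get(j); // volatile read
        while (s == null) {
            Segment candidate = new Segment();
            if (segments.compareAndSet(j, null, candidate))
                return candidate;    // this thread installed it
            s = segments.get(j);     // another thread won the race: use its segment
        }
        return s;
    }
}
```

Take hash 0xABCD1234 with 16 segments (sshift = 4). segmentIndex returns 0xA = 10, while a 16-slot table inside that segment uses bucket 0x4 = 4. The two indices come from disjoint bits of the hash.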
Segment's put method:
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
    // If tryLock() succeeds, node is null. Segment extends ReentrantLock, and tryLock()
    // is a non-fair acquire built on the AQS synchronizer
    HashEntry<K,V> node = tryLock() ? null :
        scanAndLockForPut(key, hash, value); // lock not acquired: spin, then block
    V oldValue;
    try {
        HashEntry<K,V>[] tab = table;
        // Index of the key's bucket in this segment's table
        int index = (tab.length - 1) & hash;
        HashEntry<K,V> first = entryAt(tab, index);
        for (HashEntry<K,V> e = first;;) {
            if (e != null) { // non-null: either a hash collision or the key is already present
                K k;
                if ((k = e.key) == key ||
                    (e.hash == hash && key.equals(k))) {
                    oldValue = e.value;
                    if (!onlyIfAbsent) { // key exists: overwrite unless onlyIfAbsent is true
                        e.value = value;
                        ++modCount;
                    }
                    break;
                }
                e = e.next;
            }
            else { // end of the chain reached without finding the key
                if (node != null) // reuse the node pre-created by scanAndLockForPut
                    node.setNext(first);
                else
                    node = new HashEntry<K,V>(hash, key, value, first);
                int c = count + 1;
                // If the new count exceeds the threshold and the table is still below
                // MAXIMUM_CAPACITY, rehash: the table doubles and the new node is added to it
                if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                    rehash(node);
                else // otherwise store the new node at the head of its bucket
                    setEntryAt(tab, index, node);
                ++modCount;
                count = c;
                oldValue = null;
                break;
            }
        }
    } finally {
        unlock();
    }
    return oldValue;
}
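The locked part of Segment.put can be modeled in a single class. MiniSegment below is a hypothetical, simplified sketch. It calls lock() directly instead of tryLock()/scanAndLockForPut, and it drops modCount and the ordered/volatile writes that lock-free readers need. Its rehash copies every node into a table twice as large. The JDK 1.7 rehash instead reuses a trailing run of nodes and inserts the new node itself.

```java
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical, simplified model of Segment.put (not JDK code): plain lock(), no modCount,
// no visibility guarantees for lock-free readers, and a rehash that copies every node.
class MiniSegment<K, V> extends ReentrantLock {
    static final int MAXIMUM_CAPACITY = 1 << 30;

    static final class Node<K, V> {
        final int hash;
        final K key;
        V value;
        Node<K, V> next;
        Node(int hash, K key, V value, Node<K, V> next) {
            this.hash = hash; this.key = key; this.value = value; this.next = next;
        }
    }

    private final float loadFactor;
    private Node<K, V>[] table;
    private int count;
    private int threshold;

    @SuppressWarnings("unchecked")
    MiniSegment(int capacity, float loadFactor) { // capacity must be a power of two
        this.loadFactor = loadFactor;
        this.table = (Node<K, V>[]) new Node[capacity];
        this.threshold = (int) (capacity * loadFactor);
    }

    V put(K key, int hash, V value, boolean onlyIfAbsent) {
        lock();
        try {
            Node<K, V>[] tab = table;
            int index = (tab.length - 1) & hash;
            for (Node<K, V> e = tab[index]; e != null; e = e.next) {
                if (e.key == key || (e.hash == hash && key.equals(e.key))) {
                    V old = e.value;
                    if (!onlyIfAbsent) e.value = value; // existing key: overwrite unless onlyIfAbsent
                    return old;
                }
            }
            int c = count + 1;
            if (c > threshold && tab.length < MAXIMUM_CAPACITY) {
                tab = rehash();                  // double the table first
                index = (tab.length - 1) & hash; // then recompute the bucket
            }
            tab[index] = new Node<>(hash, key, value, tab[index]); // head insertion
            count = c;
            return null;
        } finally {
            unlock();
        }
    }

    V get(Object key, int hash) {
        lock(); // needed here only because this sketch has no volatile fields
        try {
            Node<K, V>[] tab = table;
            for (Node<K, V> e = tab[(tab.length - 1) & hash]; e != null; e = e.next)
                if (e.key == key || (e.hash == hash && key.equals(e.key))) return e.value;
            return null;
        } finally {
            unlock();
        }
    }

    // Doubling adds one bit to the mask: each node keeps its index or moves to index + old length.
    @SuppressWarnings("unchecked")
    private Node<K, V>[] rehash() {
        Node<K, V>[] old = table;
        Node<K, V>[] fresh = (Node<K, V>[]) new Node[old.length << 1];
        for (Node<K, V> head : old)
            for (Node<K, V> e = head; e != null; e = e.next) {
                int i = (fresh.length - 1) & e.hash;
                fresh[i] = new Node<>(e.hash, e.key, e.value, fresh[i]); // copy the node
            }
        table = fresh;
        threshold = (int) (fresh.length * loadFactor);
        return fresh;
    }

    int capacity() { return table.length; }
}
```

With capacity 2 and load factor 0.75f the threshold is 1, so the second distinct key triggers a rehash to length 4. Both keys stay reachable afterwards.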
The first step of put is to try to take the segment's lock. If tryLock() succeeds, node is null and execution continues. If it fails, scanAndLockForPut takes over:
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
    HashEntry<K,V> first = entryForHash(this, hash); // current head of the key's bucket
    HashEntry<K,V> e = first;
    HashEntry<K,V> node = null;
    int retries = -1; // -1 while still scanning the chain; afterwards counts retries
    while (!tryLock()) { // keep trying to acquire the lock
        HashEntry<K,V> f;
        if (retries < 0) { // still scanning the chain for the key
            if (e == null) {
                if (node == null) // key not present: speculatively create the new node
                    node = new HashEntry<K,V>(hash, key, value, null);
                retries = 0;
            }
            else if (key.equals(e.key)) // key found: stop scanning
                retries = 0;
            else
                e = e.next;
        }
        // Once retries exceeds MAX_SCAN_RETRIES, call ReentrantLock.lock(),
        // which blocks until the lock is acquired
        else if (++retries > MAX_SCAN_RETRIES) {
            lock();
            break;
        }
        else if ((retries & 1) == 0 &&
                 (f = entryForHash(this, hash)) != first) {
            e = first = f; // on every other retry: if the bucket head changed, restart the scan from the new head
            retries = -1;
        }
    }
    return node;
}
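So a failed first attempt does not block right away. The thread keeps calling tryLock() and falls back to the blocking lock() only once retries exceeds MAX_SCAN_RETRIES. While it spins, it also walks the bucket chain and may create the new node in advance, so part of put's work is done before the lock is held.
The get method: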
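The spin-then-block strategy can be isolated into a few lines. SpinThenLock.acquire is a hypothetical helper. It retries tryLock() up to maxRetries times and then calls the blocking lock(). It leaves out what scanAndLockForPut does while spinning: walking the chain, pre-creating the node, and restarting when the head changes. In the JDK 1.7 source, MAX_SCAN_RETRIES is 64 on multiprocessors and 1 on a single core.

```java
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical helper (not JDK code) showing the spin-then-block idea behind
// scanAndLockForPut. The chain walk, node pre-creation and head-change restart
// that the real method performs while spinning are omitted.
class SpinThenLock {
    // Acquire the lock and return how many tryLock() attempts failed first.
    static int acquire(ReentrantLock lock, int maxRetries) {
        int retries = 0;
        while (!lock.tryLock()) {
            if (++retries > maxRetries) {
                lock.lock(); // stop spinning: block until the lock is free
                break;
            }
        }
        return retries;
    }
}
```

Spinning pays off when a lock is held only briefly, as in Segment.put, because a short busy wait is cheaper than parking the thread and waking it later. On a single core the lock holder cannot run while another thread spins, which is presumably why the limit drops to 1 there.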
public V get(Object key) {
    Segment<K,V> s;
    HashEntry<K,V>[] tab;
    int h = hash(key);
    // Byte offset of segments[j], where j comes from the high bits of h
    long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
    // Volatile read of the segment; get never takes a lock
    if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
        (tab = s.table) != null) {
        // Volatile read of the bucket head, then walk the chain
        for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
                 (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
             e != null; e = e.next) {
            K k;
            if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                return e.value;
        }
    }
    return null;
}
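The offset arithmetic in get looks complicated, but the idea is simple: locate the segment, locate the bucket in its table, then walk the chain. Unlike put, get takes no lock. It relies on volatile reads: UNSAFE.getObjectVolatile for the segment and the bucket head, and the volatile value and next fields of HashEntry.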
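A small model shows why get can skip the lock. LockFreeReadTable is hypothetical. AtomicReferenceArray.get plays the role of UNSAFE.getObjectVolatile, and lazySet plays the role of the ordered write in setEntryAt. Writers still serialize on a ReentrantLock. Nodes are prepended with fixed next links, so a reader walking a chain always sees a well-formed list. Removal and rehashing are left out.

```java
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical model (not JDK code) of a table whose readers take no lock.
// AtomicReferenceArray.get stands in for UNSAFE.getObjectVolatile and lazySet
// for the ordered write in setEntryAt. No removal, no rehash.
class LockFreeReadTable<K, V> {
    static final class Node<K, V> {
        final int hash;
        final K key;
        volatile V value;      // overwrites by writers become visible to readers
        final Node<K, V> next; // simplification: links never change after publication

        Node(int hash, K key, V value, Node<K, V> next) {
            this.hash = hash;
            this.key = key;
            this.value = value;
            this.next = next;
        }
    }

    private final AtomicReferenceArray<Node<K, V>> table;
    private final ReentrantLock lock = new ReentrantLock(); // serializes writers only

    LockFreeReadTable(int capacity) { // capacity must be a power of two
        table = new AtomicReferenceArray<>(capacity);
    }

    void put(K key, int hash, V value) {
        lock.lock();
        try {
            int i = (table.length() - 1) & hash;
            Node<K, V> first = table.get(i);
            for (Node<K, V> e = first; e != null; e = e.next) {
                if (e.hash == hash && key.equals(e.key)) {
                    e.value = value; // volatile write
                    return;
                }
            }
            // Publish a fully constructed node as the new bucket head.
            table.lazySet(i, new Node<>(hash, key, value, first));
        } finally {
            lock.unlock();
        }
    }

    // No lock: volatile read of the bucket head, then walk the chain.
    V get(Object key, int hash) {
        for (Node<K, V> e = table.get((table.length() - 1) & hash); e != null; e = e.next) {
            K k;
            if ((k = e.key) == key || (e.hash == hash && key.equals(k)))
                return e.value;
        }
        return null;
    }
}
```

The trade-off is weak consistency: a get running concurrently with a put may or may not see the new entry. The ConcurrentHashMap Javadoc phrases this as retrievals reflecting the results of the most recently completed update operations.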
The remaining methods are not analyzed here; they follow much the same principles.