1. 原理结构
原理采用分段锁来实现线程安全,首先将数据分为一段一段的存储,然后给每一段数据配一把锁,当一个线程占用锁访问其中一个段数据时,其他段的数据也能被其他线程访问。
ConcurrentHashMap 是由 Segment 数组结构和 HashEntry 数组结构组成。
Segment 实现了 ReentrantLock,所以 Segment 是一种可重入锁,扮演锁的角色。HashEntry 用于存储键值对数据。
static class Segment<K,V> extends ReentrantLock implements Serializable {
}
一个 ConcurrentHashMap 里包含一个 Segment 数组。Segment 的结构和HashMap类似,是一种数组和链表结构,一个 Segment 包含一个 HashEntry 数组,每个 HashEntry 是一个链表结构的元素,每个 Segment 守护着一个HashEntry数组里的元素,当对 HashEntry 数组的数据进行修改时,必须首先获得对应的 Segment的锁。
2. 源码分析
2.1 属性说明
// 整个ConcurrentHashMap的数组大小
static final int DEFAULT_INITIAL_CAPACITY = 16;
// 默认的负载因子
static final float DEFAULT_LOAD_FACTOR = 0.75f;
// 默认的并发等级(分为多少段)
static final int DEFAULT_CONCURRENCY_LEVEL = 16;
// 最大容积
static final int MAXIMUM_CAPACITY = 1 << 30;
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
static final int MAX_SEGMENTS = 1 << 16; // slightly conservative
static final int RETRIES_BEFORE_LOCK = 2;
// 计算存储Segment[]中的位置(length - 1)
final int segmentMask;
// Segment对象内部偏移量(UNSAFE用)
final int segmentShift;
// Segment数组
final Segment<K,V>[] segments;
transient Set<K> keySet;
transient Set<Map.Entry<K,V>> entrySet;
transient Collection<V> values;
2.2 构造函数
- 通过并发等级得出需要初始化Segment数组的大小(大于并发级别的最小2的次方数)
- 通过数组容量与并发等级得出需要初始化Segment对象中HashEntry数组的大小(大于并发级别的最小2的次方数)
- 初始化Segment数组下标为0处的Segment相关信息,其他位置直接复制这个信息就好
public ConcurrentHashMap(int initialCapacity,
float loadFactor, int concurrencyLevel) {
if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
throw new IllegalArgumentException();
if (concurrencyLevel > MAX_SEGMENTS)
concurrencyLevel = MAX_SEGMENTS;
// Find power-of-two sizes best matching arguments
int sshift = 0;
int ssize = 1;
// 计算初始化ConcurrentHashMap中Segment[]的大小
while (ssize < concurrencyLevel) {
++sshift;
ssize <<= 1;
}
this.segmentShift = 32 - sshift;
this.segmentMask = ssize - 1;
if (initialCapacity > MAXIMUM_CAPACITY)
initialCapacity = MAXIMUM_CAPACITY;
// 计算初始化Segment对象中HashEntry[]的大小
int c = initialCapacity / ssize;
if (c * ssize < initialCapacity)
++c;
int cap = MIN_SEGMENT_TABLE_CAPACITY;
while (cap < c)
cap <<= 1;
// create segments and segments[0]
// 初始化Segment数组下标为0处的Segment相关信息,其他位置直接复制这个信息就好
Segment<K,V> s0 =
new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
(HashEntry<K,V>[])new HashEntry[cap]);
Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
this.segments = ss;
}
public ConcurrentHashMap(int initialCapacity, float loadFactor) {
this(initialCapacity, loadFactor, DEFAULT_CONCURRENCY_LEVEL);
}
public ConcurrentHashMap(int initialCapacity) {
this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}
public ConcurrentHashMap() {
this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}
public ConcurrentHashMap(Map<? extends K, ? extends V> m) {
this(Math.max((int) (m.size() / DEFAULT_LOAD_FACTOR) + 1,
DEFAULT_INITIAL_CAPACITY),
DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
putAll(m);
}
2.3 put方法
- 在构造方法的时候已经分配了Segment[]内存及Segment[0]处的Segment对象,但是没有初始化其他的Segment对象,所以这里判断并初始化Segment对象
- 取出Segment[0]处的相关信息进行复制
- 判断此处偏移量Segment对象是否为null
- new出Segment对象继续判断此处偏移量Segment对象是否为null
- 循环通过UNSAFE的CAS操作对Segment进行初始化赋值
- Segment中的put方法,通过Segment的可重入锁进行操作,添加与hashmap基本一致
public V put(K key, V value) {
Segment<K,V> s;
if (value == null)
throw new NullPointerException();
// 计算key的hash值
int hash = hash(key);
// 计算数组下标(segmentMask就是数组length-1)
int j = (hash >>> segmentShift) & segmentMask;
// 通过UNSAFE类来判断数组出的Segment对象是否还为初始化
if ((s = (Segment<K,V>)UNSAFE.getObject // nonvolatile; recheck
(segments, (j << SSHIFT) + SBASE)) == null) // in ensureSegment
// 初始化Segment对象
s = ensureSegment(j);
//
return s.put(key, hash, value, false);
}
初始化Segment对象:
- 取出Segment[0]处的相关信息进行复制
- 判断此处偏移量Segment对象是否为null
- new出Segment对象继续判断此处偏移量Segment对象是否为null
- 循环通过UNSAFE的CAS操作对Segment进行初始化赋值
private Segment<K,V> ensureSegment(int k) {
final Segment<K,V>[] ss = this.segments;
long u = (k << SSHIFT) + SBASE; // raw offset
Segment<K,V> seg;
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
Segment<K,V> proto = ss[0]; // use segment 0 as prototype
int cap = proto.table.length;
float lf = proto.loadFactor;
int threshold = (int)(cap * lf);
HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) { // recheck
Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) {
if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
break;
}
}
}
return seg;
}
调用Segment中的put方法:
与hashmap基本一直,只是多了tryLock()、unlock()等锁机制。
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
HashEntry<K,V> node = tryLock() ? null :
scanAndLockForPut(key, hash, value);
V oldValue;
try {
HashEntry<K,V>[] tab = table;
int index = (tab.length - 1) & hash;
HashEntry<K,V> first = entryAt(tab, index);
for (HashEntry<K,V> e = first;;) {
if (e != null) {
K k;
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
oldValue = e.value;
if (!onlyIfAbsent) {
e.value = value;
++modCount;
}
break;
}
e = e.next;
}
else {
if (node != null)
node.setNext(first);
else
node = new HashEntry<K,V>(hash, key, value, first);
int c = count + 1;
if (c > threshold && tab.length < MAXIMUM_CAPACITY)
rehash(node);
else
setEntryAt(tab, index, node);
++modCount;
count = c;
oldValue = null;
break;
}
}
} finally {
unlock();
}
return oldValue;
}
2.4 get方法
- 计算key的hash值
- 得到数组下标
- 取出Segment对象
- 循环此处的链表,找到对应的值
public V get(Object key) {
Segment<K,V> s; // manually integrate access methods to reduce overhead
HashEntry<K,V>[] tab;
// 计算key的hash值
int h = hash(key);
// 得到下标
long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
// 取出Segment对象
if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
(tab = s.table) != null) {
// 循环此处的链表,找到对应的值
for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
e != null; e = e.next) {
K k;
if ((k = e.key) == key || (e.hash == h && key.equals(k)))
return e.value;
}
}
return null;
}
2.5 remove方法
- 计算key的hash值
- 通过hash值找到对应的Segment对象
- 通过Segment对象的put删除
public V remove(Object key) {
int hash = hash(key);
Segment<K,V> s = segmentForHash(hash);
return s == null ? null : s.remove(key, hash, null);
}
通过hash值获取Segment对象
通过UNSAFE直接操作内存
private Segment<K,V> segmentForHash(int h) {
long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
return (Segment<K,V>) UNSAFE.getObjectVolatile(segments, u);
}
调用Segment对象中的remove方法
与hashmap基本一直,只是多了tryLock()、unlock()等锁机制。
final V remove(Object key, int hash, Object value) {
if (!tryLock())
scanAndLock(key, hash);
V oldValue = null;
try {
HashEntry<K,V>[] tab = table;
int index = (tab.length - 1) & hash;
HashEntry<K,V> e = entryAt(tab, index);
HashEntry<K,V> pred = null;
while (e != null) {
K k;
HashEntry<K,V> next = e.next;
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
V v = e.value;
if (value == null || value == v || value.equals(v)) {
if (pred == null)
setEntryAt(tab, index, next);
else
pred.setNext(next);
++modCount;
--count;
oldValue = v;
}
break;
}
pred = e;
e = next;
}
} finally {
unlock();
}
return oldValue;
}