ConcurrentHashMap类与 Hashtable
相似,都是线程安全的,但与 HashMap
不同,它不 允许将 null 用作键或值。
可以使用Iterator和Enumeration进行遍历,且不会抛出ConcurrentModificationException。不过,迭代器被设计成每次仅由一个线程使用。
ConcurrentHashMap可以做到读取数据不加锁,并且其内部的结构可以让其在进行写操作的时候能够将锁的粒度保持地尽量地小,不用对整个ConcurrentHashMap加锁。
ConcurrentHashMap的内部结构
ConcurrentHashMap为了提高本身的并发能力,在内部采用了一个叫做Segment的结构,一个Segment其实就是一个类Hash Table的结构,Segment内部维护了一个链表数组,我们用下面这一幅图来看下ConcurrentHashMap的内部结构:
从上面的结构我们可以了解到,ConcurrentHashMap定位一个元素的过程需要进行两次Hash操作,第一次Hash定位到Segment,第二次Hash定位到元素所在的链表的头部,因此,这一种结构的带来的副作用是Hash的过程要比普通的HashMap要长,但是带来的好处是写操作的时候可以只对元素所在的Segment进行加锁即可,不会影响到其他的Segment,这样,在最理想的情况下,ConcurrentHashMap可以最高同时支持Segment数量大小的写操作(刚好这些写操作都非常平均地分布在所有的Segment上),所以,通过这一种结构,ConcurrentHashMap的并发能力可以大大的提高。 默认的Segment数量为16个,可以通知参数初始化时进行设置。
再来看看Segment的结构:
static final class Segment<K,V> extends ReentrantLock implements Serializable {
transient volatile int count;
transient int modCount;
transient int threshold;
transient volatile HashEntry<K,V>[] table;
final float loadFactor;
}
由源码可知,Segment继承了可重入锁ReentrantLock类,Segment里面的成员变量的意义:
- count:Segment中元素的数量
- modCount:对table的大小造成影响的操作的数量(比如put或者remove操作)
- threshold:阈值,Segment里面元素的数量超过这个值依旧就会对Segment进行扩容
- table:链表数组,数组中的每一个元素代表了一个链表的头部
- loadFactor:负载因子,同HashMap的loadFactor意义一样
Segment中的元素是以HashEntry的形式存放在链表数组中的,看一下HashEntry的结构:
static final class HashEntry<K,V> {
final K key;
final int hash;
volatile V value;
final HashEntry<K,V> next;
}
可以看到HashEntry的一个特点,除了value以外,其他的几个变量都是final的,这样做是为了防止链表结构被破坏,出现ConcurrentModification的情况。
看一下ConcurrentHashMap的构造函数:
public ConcurrentHashMap(int initialCapacity,
float loadFactor, int concurrencyLevel) {
if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
throw new IllegalArgumentException();
if (concurrencyLevel > MAX_SEGMENTS)
concurrencyLevel = MAX_SEGMENTS; //最大为65536个Segment
// Find power-of-two sizes best matching arguments
int sshift = 0;
int ssize = 1;
while (ssize < concurrencyLevel) {
++sshift;
ssize <<= 1;
}
segmentShift = 32 - sshift;
segmentMask = ssize - 1;
this.segments = Segment.newArray(ssize);
if (initialCapacity > MAXIMUM_CAPACITY)
initialCapacity = MAXIMUM_CAPACITY;
int c = initialCapacity / ssize;
if (c * ssize < initialCapacity)
++c;
int cap = 1;
while (cap < c)
cap <<= 1;
for (int i = 0; i < this.segments.length; ++i)
this.segments[i] = new Segment<K,V>(cap, loadFactor);
}
CurrentHashMap的初始化一共有三个参数,一个initialCapacity,表示初始的容量(默认16),一个loadFactor,表示负载参数(默认0.75),最后一个是concurrentLevel(默认16),代表ConcurrentHashMap内部的Segment的数量,concurrentLevel一经指定,不可改变,后续如果ConcurrentHashMap的元素数量增加导致ConrruentHashMap需要扩容,ConcurrentHashMap不会增加Segment的数量,而只会增加Segment中链表数组的容量大小,这样的好处是扩容过程不需要对整个ConcurrentHashMap做rehash,而只需要对Segment里面的元素做一次rehash就可以了。
整个ConcurrentHashMap的初始化方法还是非常简单的,先是根据concurrentLevel来new出Segment,这里Segment的数量是不大于concurrentLevel的最大的2的指数,就是说Segment的数量永远是2的指数个,这样的好处是方便采用移位操作来进行hash,加快hash的过程。接下来就是根据intialCapacity确定Segment的容量的大小,每一个Segment的容量大小也是2的指数,同样使为了加快hash的过程。
这边需要特别注意一下两个变量,分别是segmentShift和segmentMask,这两个变量在后面将会起到很大的作用,假设构造函数确定了Segment的数量是2的n次方,那么segmentShift就等于32减去n,而segmentMask就等于2的n次方减一。
ConcurrentHashMap的get操作
ConcurrentHashMap的get操作是不用加锁的:
public V get(Object key) {
int hash = hash(key.hashCode());
return segmentFor(hash).get(key, hash);
}
// Segment的get操作
V get(Object key, int hash) {
if (count != 0) { // read-volatile
HashEntry<K,V> e = getFirst(hash);
while (e != null) {
if (e.hash == hash && key.equals(e.key)) {
V v = e.value;
if (v != null)
return v;
return readValueUnderLock(e); // recheck
}
e = e.next;
}
}
return null;
}
count表示Segment中元素的数量,定义为transient volatile int count;
ConcurrentHashMap的put(key, value)、remove(key)等操作与上面类似,也是先找到存储的Segment,在对该Segment进行加锁,再做相应操作。
在ConcurrentMap接口中定义了以下四个方法:
public interface ConcurrentMap<K, V> extends Map<K, V> {
V putIfAbsent(K key, V value);
boolean remove(Object key, Object value);
boolean replace(K key, V oldValue, V newValue);
V replace(K key, V value);
}
它们都要求是原子操作,putIfAbsent相当于:
if (!map.containsKey(key))
return map.put(key, value);
else
return map.get(key);
这样就可以用来多个线程同时进行put操作,但是却只允许一个操作成功时使用。
remove(Object key, Object value)相当于:
if (map.containsKey(key) && map.get(key).equals(value)) {
map.remove(key);
return true;
} else return false;
boolean replace(K key, V oldValue, V newValue)相当于:
if (map.containsKey(key) && map.get(key).equals(oldValue)) {
map.put(key, newValue);
return true;
} else return false;
V replace(K key, V value)相当于:
if (map.containsKey(key)) {
return map.put(key, value);
} else return null;
这四种操作在并发条件下非常有用。下面会有一些应用场景来说明。
上面涉及到的操作都是在单个Segment中进行的,但是ConcurrentHashMap有一些操作是在多个Segment中进行,比如size操作,ConcurrentHashMap的size操作也采用了一种比较巧的方式,来尽量避免对所有的Segment都加锁。
在每一个Segment中的有一个modCount变量,代表的是对Segment中元素的数量造成影响的操作的次数,这个值只增不减,size操作就是遍历了两次Segment,每次记录Segment的modCount值,然后将两次的modCount进行比较,如果相同,则表示期间没有发生过写入操作,就将原先遍历的结果返回,如果不相同,则把这个过程再重复做一次,如果再不相同,则就需要将所有的Segment都锁住 ,然后逐个遍历。
ConcurrentHashMap应用实例
下面使用ConcurrentHashMap模拟统计网站页面的访问量,首先使用putIfAbsent方法来进行操作,如果没有访问过的页面,则把值设为1,此时返回值会为null,如果返回值不是null,则证明该页面被访问过了,此时使用replace方法进行操作。该方法使用了CAS语法,因此要使用while循环方式。
class ConcurrentMapCounter {
private ConcurrentHashMap<String, Integer> map = new ConcurrentHashMap<String, Integer>();
public Integer increment(String s) {
//如果s不存在,则放入key->1,否则执行replace方法
if(!(map.putIfAbsent(s, 1) == null)){
boolean flag = false;
while(!flag) {
flag = map.replace(s, map.get(s), map.get(s) + 1);
}
}
return map.get(s);
}
public Map<String, Integer> getMap() {
return map;
}
}
验证如下:
class Counter extends Thread {
static Random random = new Random();
ConcurrentMapCounter cc;
CountDownLatch latch;
final String[] pages = {"a","b", "c"};
public Counter(ConcurrentMapCounter cc, CountDownLatch latch) {
this.cc = cc;
this.latch = latch;
}
public void run() {
String view = pages[random.nextInt(3)];
int i = cc.increment(view);
System.out.println(Thread.currentThread().getName()
+ "第" + i + "个访问页面:" + view);
latch.countDown();
}
}
public static void main(String[] args) throws InterruptedException {
ConcurrentMapCounter cc = new ConcurrentMapCounter();
int count = 1000;
CountDownLatch latch = new CountDownLatch(count);
ExecutorService service = Executors.newFixedThreadPool(30);
for (int i = 0; i < count; i++) {// 100 个线程
service.execute(new Counter(cc, latch));
}
latch.await();
long result = 0;
Map<String, Integer> map = cc.getMap();
for(Map.Entry<String, Integer> s : map.entrySet()) {
System.out.println("页面:" + s.getKey() + "被访问:" + s.getValue());
result += s.getValue();
}
System.out.println("result:" + result);
service.shutdown();
}