ConcurrentHashMap源码分析(JDK1.7)

一、JDK1.7ConcurrentHashMap介绍

JDK1.7中ConcurrentHashMap底层是Segment数组,每一个Segment对象中包含一个HashEntry数组,保存的元素会封装成HashEntry对象,当遇到Hash冲突时,会形成链表
在这里插入图片描述
Segment继承ReentrantLock,当需要控制线程安全时,对单独的Segment进行加锁,即分段锁
几个默认值

  • DEFAULT_INITIAL_CAPACITY 默认初始化容量 16
  • MAXIMUM_CAPACITY 最大容量 2的30次幂
  • DEFAULT_LOAD_FACTOR 默认的负载因子 0.75
  • DEFAULT_CONCURRENCY_LEVEL 并发等级 16

二:构造器

	// 空参构造,调用重载的构造方法
    public ConcurrentHashMap() {
    	//调用本类的带参构造
    	//DEFAULT_INITIAL_CAPACITY = 16
    	//DEFAULT_LOAD_FACTOR = 0.75f
    	//int DEFAULT_CONCURRENCY_LEVEL = 16
        this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
    }
    // 指定数组初始化长度
	public ConcurrentHashMap(int initialCapacity) {
        this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
    }
    // 指定数组初始化长度与默认负载因子
	public ConcurrentHashMap(int initialCapacity, float loadFactor) {
        this(initialCapacity, loadFactor, DEFAULT_CONCURRENCY_LEVEL);
    }
	// 所有参数全部指定
	public ConcurrentHashMap(int initialCapacity,
                             float loadFactor, int concurrencyLevel) {
        // 做一些数据判断
        if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
            throw new IllegalArgumentException();
        if (concurrencyLevel > MAX_SEGMENTS)
            concurrencyLevel = MAX_SEGMENTS;
        // Find power-of-two sizes best matching arguments
        int sshift = 0; // 移动的位数,用于后续计算高4位的值
        int ssize = 1; // Segment数组的大小
        while (ssize < concurrencyLevel) { 
            ++sshift;
            // 保证Segment大小为2的次幂,如指定concurrencyLevel 为15则 ssize为16,指定concurrencyLevel为17 则ssize为32
            ssize <<= 1;
        }
        // 这两个值用于后面计算Segment[]的角标
        this.segmentShift = 32 - sshift;
        this.segmentMask = ssize - 1;
        // 计算每个Segment中存储的元素个数
        if (initialCapacity > MAXIMUM_CAPACITY)
            initialCapacity = MAXIMUM_CAPACITY;
        int c = initialCapacity / ssize;
        if (c * ssize < initialCapacity)
            ++c;
        // 每个Segment最少存储2个HashEntry对象
        int cap = MIN_SEGMENT_TABLE_CAPACITY;
        // 保证Segment中存储的HashEntry个数是2的次幂
        while (cap < c)
            cap <<= 1;
        // create segments and segments[0]
        // 创建一个Segment对象,作为模板,后续创建Segment对象时,属性值直接复用,不再重新计算
        Segment<K,V> s0 =
            new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                             (HashEntry<K,V>[])new HashEntry[cap]);
        // 创建出底层的Segment数组
        Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
        // 使用Unsafe类,将创建的Segment对象放入数组下标为0的位置
        UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
        this.segments = ss;
    }
   
    // 根据已有的Map去构造
	public ConcurrentHashMap(Map<? extends K, ? extends V> m) {
        this(Math.max((int) (m.size() / DEFAULT_LOAD_FACTOR) + 1,
                      DEFAULT_INITIAL_CAPACITY),
             DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
        putAll(m);
    }

默认情况:在new一个ConcurrentHashMap时,底层创建了一个长度为16的Segment数组,每个Segment中保存了一个长度为2的HshEntry数组,之后保存的数据都是在HashEntry中。ConcurrentHashMap中默认元素长度为32,而不是16。

三:put方法

/**
* 保存/修改方法,可能会涉及到扩容
* 不允许存储Null键和Null值
*/
public V put(K key, V value) {
    Segment<K,V> s;
    if (value == null)
    	// 值不可以为Null
        throw new NullPointerException();
    // 基于key,计算hash值,计算hash值时,key若为Null会报空指针异常
    int hash = hash(key);
    // 取高位计算要保存的Segment数组的下标(HashEntry数组用的低位计算下标)
    int j = (hash >>> segmentShift) & segmentMask;
    // 如果该位置没有Segment对象,则创建一个Segment对象
    if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
         (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
        s = ensureSegment(j);
    // 调用Segmetn的put方法实现元素添加
    return s.put(key, hash, value, false);
}
/**
* 返回对应索引k位置的Segment对象,没有则使用索引位置0的Segment对象创建一个
* 这里没有使用锁机制,仅依靠Unsafe类进行操作
*/
private Segment<K,V> ensureSegment(int k) {
    final Segment<K,V>[] ss = this.segments;
    // k索引位置对应的偏移量
    long u = (k << SSHIFT) + SBASE; // raw offset
    Segment<K,V> seg;
    // 如果数组ss索引k的位置对应的Segment对象为null
    if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
    	// 使用构造器初始化时生成的Segment对象作为原型
        Segment<K,V> proto = ss[0]; // use segment 0 as prototype
        int cap = proto.table.length;
        float lf = proto.loadFactor;
        int threshold = (int)(cap * lf);
        HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
        // 再次确认在此过程中其他线程没有将该Segment对象创建出来
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
            == null) { // recheck
            Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
            while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                   == null) {
                   // 使用CAS将新建的Segment对象放置在数组的该位置中,如果放置成功,则break返回,如果失败再判断其余线程有没有放置成功,循环此操作
                if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                    break;
            }
        }
    }
    return seg;
}
/**
* Segment对象的put方法,加锁,然后保存值
*/
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
	// 尝试加锁,获取到锁则继续,获取不到则调用scanAndLockForPut方法获取锁
    HashEntry<K,V> node = tryLock() ? null :
        scanAndLockForPut(key, hash, value);
    V oldValue;
    try {
        HashEntry<K,V>[] tab = table;
        // 获取HashEntry数组的下标,这里用低位
        int index = (tab.length - 1) & hash;
        // 获取数组tab对应下标index的值,有可能为0
        HashEntry<K,V> first = entryAt(tab, index);
        for (HashEntry<K,V> e = first;;) {
        	// 获取的元素不为空
            if (e != null) {
                K k;
                // 如果重复,覆盖
                if ((k = e.key) == key ||
                    (e.hash == hash && key.equals(k))) {
                    oldValue = e.value;
                    if (!onlyIfAbsent) {
                        e.value = value;
                        ++modCount;
                    }
                    break;
                }
                // 不重复,获取链表的下一个元素
                e = e.next;
            }
            else {
                if (node != null)
                	// 如果node在等待锁的时候已经创建出来,则采用头插法直接插入即可
                    node.setNext(first);
                else
                	// 创建一个新的HashEntry,next属性指向first(头插法)
                    node = new HashEntry<K,V>(hash, key, value, first);
                int c = count + 1;
                if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                	// 如果超过扩容阈值,进行扩容,再将node插入
                    rehash(node);
                else 
                	// 不扩容,直接放到Entry数组中
                    setEntryAt(tab, index, node);
                ++modCount;
                count = c;
                oldValue = null;
                break;
            }
        }
    } finally {
    	// 释放锁
        unlock();
    }
    return oldValue;
}
/**
* 使用Unsafe将元素e放到数组tab的下标i位置处
*/
static final <K,V> void setEntryAt(HashEntry<K,V>[] tab, int i,
                               HashEntry<K,V> e) {
    UNSAFE.putOrderedObject(tab, ((long)i << TSHIFT) + TBASE, e);
}

总结一下put流程:

  1. 首先判断不允许null值和null键
  2. 计算hash值,通过hash值的高位计算出要保存的Segment数组下标
  3. 判断Segment数组该下标位置是否有值,如果没有值,使用Unsafe类并自旋且使用Segment数组下标0位置的原型去创建一个Segment对象,创建对象时要判断其他线程有没有将此对象创建出来,如果获取到了其他线程创建的Segment对象,则将该对象返回
  4. 调用Segment对象的put方法去保存值
  5. 保存值时,首先尝试获得该Segment对象的锁,获取不到则自旋不停尝试获取锁,在此过程中会将要保存的值封装成一个HashEntry对象,节省之后的时间,如果自旋重试一定次数(64/1)后,不再自旋,强制加锁,阻塞等待锁
  6. 获取锁以后,先计算该值应该保存在Segment的HashEntry数组的哪一个位置,采用低位进行计算
  7. 获取HashEntry数组对应下标的值,记为first
  8. 遍历该位置的链表,如果first为null或者遍历完都不重复,则插入(头插法),插入时,如果在自旋获取锁时已经创建好HashEntry对象,直接使用,未创建好则去创建。如果在遍历链表过程中,找到重复的元素,进行替换。
  9. 在插入过程中,先判断是否需要扩容,需要扩容则进行扩容,扩容完将该元素插入到新数组中,不需要扩容则将此元素放到头结点。
  10. 最终释放锁

四:resize()扩容

扩容只针对加锁的单个Segment的HashEntry数组进行扩容,HashEntry数组大小变成2倍,但仅局限于当前的Segment,其余Segment中的HashEntry数组大小不会变化,另外,Segment数组大小在初始化以后也不会发生变化。

/**
* 扩容方法,扩容2倍,将
*/
private void rehash(HashEntry<K,V> node) {
    HashEntry<K,V>[] oldTable = table;
    int oldCapacity = oldTable.length;
    // 扩容2倍
    int newCapacity = oldCapacity << 1;
    threshold = (int)(newCapacity * loadFactor);
    // 创建一个新的HashEntry数组
    HashEntry<K,V>[] newTable =
        (HashEntry<K,V>[]) new HashEntry[newCapacity];
    int sizeMask = newCapacity - 1;
    // 进行数据迁移,将旧数组上面的
    for (int i = 0; i < oldCapacity ; i++) {
        HashEntry<K,V> e = oldTable[i];
        if (e != null) {
            HashEntry<K,V> next = e.next;
            // 计算新数组的下标
            int idx = e.hash & sizeMask;
            if (next == null)   //  Single node on list
            	// 该位置只有一个元素,直接放过去
                newTable[idx] = e;
            else { // Reuse consecutive sequence at same slot
                HashEntry<K,V> lastRun = e;
                int lastIdx = idx;
                for (HashEntry<K,V> last = next;
                     last != null;
                     last = last.next) {
                    // lastRun机制,对数据转移进行了一点优化,获取该位置最后计算出在新数组位置连续的元素,整体转移到新数组中去
                    int k = last.hash & sizeMask;
                    if (k != lastIdx) {
                        lastIdx = k;
                        lastRun = last;
                    }
                }
                newTable[lastIdx] = lastRun;
                // Clone remaining nodes
                // 转移剩余的元素,从原数组的头结点开始,到原数组的lastRun结点
                for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                    V v = p.value;
                    int h = p.hash;
                    int k = h & sizeMask;
                    HashEntry<K,V> n = newTable[k];
                    newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                }
            }
        }
    }
    // 将新的元素采用头插法插入到新数组中
    int nodeIndex = node.hash & sizeMask; // add the new node
    node.setNext(newTable[nodeIndex]);
    newTable[nodeIndex] = node;
    // 将新数组赋值给Segment对象的HashEntry数组属性
    table = newTable;
}

resize()方法总结:

  1. 因为是在put方法中,已经获取到lock,所以可以确保线程安全
  2. 扩容时首先生成一个新的数组,大小为原数组的2倍
  3. 将原数组的元素迁移到新数组中,遍历原HashEntry数组,获取到每一个位置的HashEntry链表
  4. 计算原数据应该转移到新数组的下标,如果原HashEntry链表只有一个元素,直接转移到新链表
  5. 如果有多个,采用lastRun机制,先获取最后几个在新数组连续的元素,先转移过去
  6. 之后遍历剩余的链表,根据原HashEntry元素生成新的HashEntry对象,并采用头插法放到新数组中
  7. 数据转移完之后,将新的元素再计算下标之后,头插法放到新数组中
  8. 将新数组更新到Segment对象中

五:get方法

public V get(Object key) {
    Segment<K,V> s; // manually integrate access methods to reduce overhead
    HashEntry<K,V>[] tab;
    // 如果key为null,在这一步会抛异常
    int h = hash(key);
    long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
    // 获取Segment数组对应下标的Segment对象以及该Segment对象的HashEntry数组,并且两者都不为空
    if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
        (tab = s.table) != null) {
        // 获取HashEntry数组对应下标的链表,遍历获取值
        for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
                 (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
             e != null; e = e.next) {
            K k;
            if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                return e.value;
        }
    }
    return null;
}

get方法总结:

  1. 先计算获取该key对应哪一个Segment对象(key若为null,则空指针)
  2. 再计算该key对应Segment对象的哪一个HashEntry链表
  3. 遍历链表获取值
  4. 这几步有任一步骤获取不到值,返回null。

六:size方法

/**
* 获取ConcurrentHashMap的大小
* 先计算两遍,看一下两遍计算的值是否相等,若相等则返回
* 若两次计算结果不一致,则再计算一遍并比较
* 若再次计算以后结果与上一次计算结果还不一致,则加锁计算
*/
public int size() {
    // Try a few times to get accurate count. On failure due to
    // continuous async changes in table, resort to locking.
    final Segment<K,V>[] segments = this.segments;
    int size;
    boolean overflow; // true if size overflows 32 bits
    long sum;         // sum of modCounts
    long last = 0L;   // previous sum
    int retries = -1; // first iteration isn't retry
    try {
        for (;;) {
        	// 先比较再自加,RETRIES_BEFORE_LOCK为2
        	// -1 0 1 共循环三次
            if (retries++ == RETRIES_BEFORE_LOCK) {
            	// 对每一个Segment进行加锁
                for (int j = 0; j < segments.length; ++j)
                    ensureSegment(j).lock(); // force creation
            }
            sum = 0L;
            size = 0;
            overflow = false;
            for (int j = 0; j < segments.length; ++j) {
                Segment<K,V> seg = segmentAt(segments, j);
                if (seg != null) {
                    sum += seg.modCount;
                    int c = seg.count;
                    // 计算每一个segment的count之和
                    if (c < 0 || (size += c) < 0)
                        overflow = true;
                }
            }
            // 如果连续两次计算值一致,则返回
            if (sum == last)
                break;
            // 修改次数之和不一致,将此次计算结果记为last,再计算一遍
            last = sum;
        }
    } finally {
        if (retries > RETRIES_BEFORE_LOCK) {
        	// 如果已经加锁,则释放锁
            for (int j = 0; j < segments.length; ++j)
                segmentAt(segments, j).unlock();
        }
    }
    return overflow ? Integer.MAX_VALUE : size;
}

size方法总结:

  1. 因为在计算ConcurrentHashMap的size时,有可能会并发插入/删除数据
  2. 首先不加锁,计算两次所有Segment对象的modCount之和,判断是否一致,如果不一致再循环计算一次
  3. 如果再循环一次计算的和与之前计算还不一致,就加锁进行计算,这里加锁前最多计算三次

结语:还在学习过程中,做一个学习记录,如有不对的地方,欢迎批评指正。

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值