ConcurrentHashMap 底层实现原理

一、线程安全集合概述

        线程安全集合类可以分为三大类:遗留的线程安全集合、使用 Collections 创建的以及 Concurrent 包下的。

1.1 遗留的线程安全集合

        遗留的线程安全集合如 Hashtable Vector。里面的方法实现都是使用 synchronized 修饰的。

1.2 使用 Collections 创建

        使用 Collections 装饰的线程安全集合,本质上也是在 方法里面加上 synchronized 进行修饰。通过传入一个不安全的集合,返回一个安全的集合,如下:

Collections.synchronizedCollection()
Collections.synchronizedList()
Collections.synchronizedMap()
Collections.synchronizedSet()
Collections.synchronizedNavigableMap()
Collections.synchronizedNavigableSet()
Collections.synchronizedSortedMap()
Collections.synchronizedSortedSet()

1.3 concurrent

        重点介绍 java.util.concurrent.* 下的线程安全集合类,可以发现它们有规律,里面包含三类关键词:BlockingCopyOnWriteConcurrent

1.3.1 Blocking

        Blocking 大部分实现基于锁,并提供用来阻塞的方法。

1.3.2 CopyOnWrite

        CopyOnWrite 之类容器修改开销相对较重,采用的是修改时拷贝的方式来避免多线程访问时的并发安全,适用于读多写少的场景。

1.3.3 Concurrent

        Concurrent 类型的容器,内部很多操作使用 cas 优化,一般可以提供较高吞吐量。但是存在弱一致性问题。弱一致性问题主要体现在下面的三个方面。

        1、遍历时弱一致性,例如,当利用迭代器遍历时,如果容器发生修改,迭代器仍然可以继续进行遍历,这时内容是旧的。

        2、求大小弱一致性,size 操作未必是 100% 准确。

        3、读取弱一致性,有可能你读取的同时其他线程已经把它的值给改了。

注意:

        1、对于非安全容器来讲,遍历时如果发生了修改,会使用 fail-fast 机制也就是让遍历立刻失败,抛出 ConcurrentModificationException,不再继续遍历。

        2、而对于线程安全的集合,遍历时如果发生了修改,它不会失败。它会让你的遍历继续运行,这种叫做 fail-safe 机制。

二、ConcurrentHashMap

2.1 入门实例

        接下来进行单词计数的一个练习,即统计每个单词出现的数量。

2.1.1 创建数据

        首先使用代码生成测试数据,需要手动在 src 同级目录下创建 tmp 文件夹,代码如下所示:

public class ProductData {
    // 准备 26 个英文字母,每个字母当成一个单词
    static final String Path = "abcedfghijklmnopqrstuvwxyz";

    // 为了使将来的结果可预测,每个字母循环 200 次,最终 26 个字母每个都会出现 200 次
    public static void main(String[] args) {
        int length = Path.length();
        int count = 200;
        List<String> list = new ArrayList<>(length * count);
        for(int i=0;i<length;i++){
            char ch = Path.charAt(i);
            for(int j=0;j<count;j++){
                list.add(String.valueOf(ch));
            }
        }
        // 打乱集合里面元素的顺序
        Collections.shuffle(list);
        // 打乱之后存入到 26 个 txt 文件中
        for(int i=0;i<26;i++){
            try(PrintWriter out = new PrintWriter(new OutputStreamWriter(new FileOutputStream("tmp/"+(i+1)+".txt")))) {
                String collect = list.subList(i * count, (i + 1) * count).stream()
                        .collect(Collectors.joining("\n"));
                out.print(collect);
            }catch (IOException e){

            }
        }
    }
}
// 最终我们统计的时候,每个代码出现的次数也都得是 200 次

        执行上面的代码,最终可以生成 26 txt 文件,如下

2.1.2 统计数据

        接下来我们使用多线程来统计数据,即一个线程读一个文件,代码如下:

package com.springbootrabbitmq.controller;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

public class TestWordCount {

    public static void main(String[] args) {
        demo(() -> new HashMap<String, Integer>(),
                // 进行计数
                (map, words) -> {
                    for (String word : words) {
                        // 先判断集合类有没有这个单词
                        Integer counter = map.get(word);
                        // 如果为 null,那就从 1 开始计数;如果不为 null,那就计数器 +1
                        int newValue = counter == null ? 1 : counter + 1;
                        // 放入到 map 当中
                        map.put(word, newValue);
                    }
                }
        );
    }
    // demo 方法接收两个参数,一个是保存结果的集合类,另一个是带两个参数的消费器
    private static <V> void demo(Supplier<Map<String,V>> supplier,
                                 BiConsumer<Map<String,V>,List<String>> consumer) {
        Map<String, V> counterMap = supplier.get();
        List<Thread> ts = new ArrayList<>();
        // 循环创建 26 个线程
        for (int i = 1; i <= 26; i++) {
            int idx = i;
            Thread thread = new Thread(() -> {
                // 每个线程根据文件名称去读取文件,将读取的结果放入到 list 中
                List<String> words = readFromFile(idx);
                consumer.accept(counterMap, words);
            });
            ts.add(thread);
        }
        ts.forEach(t->t.start());
        ts.forEach(t-> {
            try {
                t.join();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        });
        System.out.println(counterMap);
    }

    // 读取文件的方法
    public static List<String> readFromFile(int i) {
        ArrayList<String> words = new ArrayList<>();
        try (BufferedReader in = new BufferedReader(new InputStreamReader(new FileInputStream("tmp/" + i +".txt")))) {
            while(true) {
                String word = in.readLine();
                if(word == null) {
                    break;
                }
                words.add(word);
            }
            return words;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}

        运行结果如下,可以发现统计的结果都不是 200,出现了线程安全问题,这就是没有使用线程安全保护导致的计数不准

         那么该如何修改才能够让计数准确呢?使用 ConcurrentHashMap 来测试下,如下:

public static void main(String[] args) {
	demo(() -> new ConcurrentHashMap<String, Integer>(),
			// 进行计数
			(map, words) -> {
				for (String word : words) {
					// 先判断集合类有没有这个单词
					Integer counter = map.get(word);
					// 如果为 null,那就从 1 开始计数;如果不为 null,那就计数器 +1
					int newValue = counter == null ? 1 : counter + 1;
					// 放入到 map 当中
					map.put(word, newValue);
				}
			}
	);
}

        运行结果如下,我们发现还是出现了线程安全问题。

        ConcurrentHashMap 作为线程安全的集合,它里面的每个方法都可以认为是一个原子的,如果只调用一个方法是可以保证线程安全的,但是我们在使用的时候不止调用了一个方法,我们先调用了 get() 方法,又做了一些运算,还调用了 put() 方法,get put 方法都是原子的,但是它俩加起来就不是原子的了。

        第一种解决方式,使用 synchronized 关键字,代码如下所示

public static void main(String[] args) {
	demo(() -> new ConcurrentHashMap<String, Integer>(),
			// 进行计数
			(map, words) -> {
				for (String word : words) {
					synchronized(map){
						// 先判断集合类有没有这个单词
						Integer counter = map.get(word);
						// 如果为 null,那就从 1 开始计数;如果不为 null,那就计数器 +1
						int newValue = counter == null ? 1 : counter + 1;
						// 放入到 map 当中
						map.put(word, newValue);
					}
				}
			}
	);
}

        可以发现打印是正常的,但是毕竟 synchronized 不是细粒度的锁,并发度上还是有一些欠缺的。

        第二种解决方式,使用 computeIfAbsent() 方法,代码如下

public static void main(String[] args) {
	demo(() -> new ConcurrentHashMap<String, LongAdder>(),
			// 进行计数
			(map, words) -> {
				for (String word : words) {
					// 如果缺少一个 key,则计算生成一个值,然后将 key 和 value 放入 map
					LongAdder value = map.computeIfAbsent(word, (key) -> new LongAdder());
					// 执行累加
					value.increment();
				}
			}
	);
}

        可以发现,打印结果也是没有问题的。 

2.2 HashMap 并发死链

        在 jdk1.7 的环境下,在多线程的环境下使用 hashmap,在扩容时就会出现并发死链的问题。这个问题的根本原因是在多线程的环境下使用了非线程安全的 map 集合。

        在 jdk1.8 的环境下,虽然将扩容算法做了调整,不再将元素加入链表头,而是保持与扩容前一样的顺序,但仍不意味着能够在多线程环境下安全扩容,还会出现其他问题,比如扩容丢失数据等问题。

2.3 重要属性和内部类

        ConcurrentHashMap 内部的属性和内部类,如下所示,第一个属性是 sizeCtl,默认是 0,当哈希表的数组初始化的时候(懒惰初始化),它变为 -1,当数组发生扩容时也是一个负数,当初始化或扩容完成之后,这个 sizeCtl 就是下一次扩容时的阈值大小,即容量的四分之三。

        第二个属性是一个静态内部类 Node,哈希表里面的链表结构就是靠它来链起来的。

        第三个属性是 Node 类型的数组 table,就是内部维护的哈希表。

        第四个属性是发生扩容时的数组 nextTable。他们都用 volatile 修饰了,所以效率会很高。

// 默认为 0
// 当初始化时, 为 -1
// 当扩容时, 为 -(1 + 扩容线程数)
// 当初始化或扩容完成后,为 下一次的扩容的阈值大小
private transient volatile int sizeCtl;

// 整个 ConcurrentHashMap 就是一个 Node[]
static class Node<K,V> implements Map.Entry<K,V> {}

// hash 表
transient volatile Node<K,V>[] table;

// 扩容时的 新 hash 表
private transient volatile Node<K,V>[] nextTable;

// 扩容时如果某个 bin 迁移完毕, 用 ForwardingNode 作为旧 table bin 的头结点
static final class ForwardingNode<K,V> extends Node<K,V> {}

// 用在 compute 以及 computeIfAbsent 时, 用来占位, 计算完成后替换为普通 Node
static final class ReservationNode<K,V> extends Node<K,V> {}

// 作为 treebin 的头节点, 存储 root 和 first
static final class TreeBin<K,V> extends Node<K,V> {}

// 作为 treebin 的节点, 存储 parent, left, right
static final class TreeNode<K,V> extends Node<K,V> {}

        假设下面的 concurrentHashMap 要发生扩容了,当然了,得需要容量达到四分之三才可以扩容,不要在意我只有 4 个元素,我这里只是比喻。

        扩容之后就会创建一个新的数组,把上面的 node 节点一个一个的搬迁过去。由于我们是一个线程安全的 hashtable,需要防止其他线程对我的搬迁工作产生影响,那该如何操作呢?

        首先它是从后往前一个一个的下标进行处理,假设先处理第 15 个下标,发现处理完了,此时就会往它下面加一个头节点 ForwardingNode,当其他线程发现这个节点是 ForwardingNode 时,就表示这个节点已经处理过了,就不会对当前链表进行任何操作。

        等到处理到 4 了,此时就会进行搬迁工作,如下图:

         如果在搬迁的过程中,有其他线程来 get(),如果是 ForwardingNode 则去新的 table 中去找 key

        链表长度达到 8 之后,此时会触发扩容,扩容会使数组的长度扩张一倍,而链表里面的元素会重新计算哈希值,可以有效的减少链表的长度,一般都会减少一半,当数组长度达到 64 的时候,此时就不用扩容了,而是采用链表转红黑树的方式,TreeBin 就是红黑树的头节点,而 TreeNode 就是红黑树的子节点。

2.4 重要方法

        ConcurrentHashMap 内部的重要方法,如下所示:

        其中 tabAt() 是获取哈希表中的第 i node 节点。

        casTabAt() 方法是通过 cas 的方式修改哈希表中第 i 个元素的值,c 为旧值,v 为新值。

        setTabAt() 方法是直接修改哈希表中第 i 个元素的值,v 为新值。

// 获取 Node[] 中第 i 个 Node
static final <K,V> Node<K,V> tabAt(Node<K,V>[] tab, int i)
 
// cas 修改 Node[] 中第 i 个 Node 的值, c 为旧值, v 为新值
static final <K,V> boolean casTabAt(Node<K,V>[] tab, int i, Node<K,V> c, Node<K,V> v)
 
// 直接修改 Node[] 中第 i 个 Node 的值, v 为新值
static final <K,V> void setTabAt(Node<K,V>[] tab, int i, Node<K,V> v)

2.5 构造器分析

        以下面的 3 个参数的构造方法为例,如下代码,第一个参数是初始容量,第二个参数为负载因子,第三个参数是并发度。

        当初始容量小于并发度的时候,比如说初始容量是 8,并发度是 16,此时就会把初始容量改成并发度,即最少也得保持并发度那么大。

        jdk8 中的 ConcurrentHashMap 是实现了懒惰初始化的,并不在构造方法里面创建出数组,只有在以后第一次用到的时候才会真正创建,在这个构造方法中,仅仅是计算出 table 的大小。

        而 jdk7 中的 ConcurrentHashMap 不管用不用,上来就会创建一个数组,还是比较占用内存的,等于是 jdk8 中是做了改进的。

    public ConcurrentHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) {
        if (!(loadFactor > 0.0f) || initialCapacity < 0 || concurrencyLevel <= 0)
            throw new IllegalArgumentException();
        if (initialCapacity < concurrencyLevel) // Use at least as many bins
            initialCapacity = concurrencyLevel; // as estimated threads
        long size = (long)(1.0 + (long)initialCapacity / loadFactor);
        // tableSizeFor 仍然是保证计算的大小是 2^n, 即 16,32,64 ... 
        int cap = (size >= (long)MAXIMUM_CAPACITY) ?
                MAXIMUM_CAPACITY : tableSizeFor((int)size);
        this.sizeCtl = cap;
    }

        计算 size 时比较有意思,假设初始容量是 8,那么 size 就等于 8 除以 0.75,然后 +1,最终等于 11.67 左右,最终 size 被强转等于 11

        下面计算 cap 时调用了 tableSizeFor() 方法,保证最终计算的大小是 2 n 次方。因为后续的一些 hash 算法要求 hash 表的大小必须是 2 n 次方才可以正常工作。

        得出结论:即便我们自己设置了初始容量的大小,但最终创建的数组容量的也不一定是我们设置的那么大。

2.6 get 流程

        get 流程也是 ConcurrentHashMap 中的一个亮点,因为它全称都没有加锁,所以效率很高

    public V get(Object key) {
        Node<K,V>[] tab; Node<K,V> e, p; int n, eh; K ek;
        // spread 方法能确保返回结果是正数
        int h = spread(key.hashCode());
        if ((tab = table) != null && (n = tab.length) > 0 &&
                (e = tabAt(tab, (n - 1) & h)) != null) {
            // 如果头结点已经是要查找的 key
            if ((eh = e.hash) == h) {
                if ((ek = e.key) == key || (ek != null && key.equals(ek)))
                    return e.val;
            }
            // hash 为负数表示该 bin 在扩容中或是 treebin, 这时调用 find 方法来查找
            else if (eh < 0)
                return (p = e.find(h, key)) != null ? p.val : null;
            // 正常遍历链表, 用 equals 比较
            while ((e = e.next) != null) {
                if (e.hash == h &&
                        ((ek = e.key) == key || (ek != null && key.equals(ek))))
                    return e.val;
            }
        }
        return null;
    }

        首先调用 spread() 方法,确保返回的结果是个正整数,如果正常调用 hashCode() 方法,结果有可能返回正数也有可能返回负数。返回的这个正数 h 后面用得到。

        然后检查哈希表是否为空即内部的哈希表已经创建好了,并且里面有元素,才会继续向下寻找,否则就返回 null 了;然后调用上面所提到的  tabAt() 方法获取桶下标的头节点,看它是不是不为 null,如果不为 null 则继续比较头节点的哈希码与刚刚 key 的哈希码,若一致,则进一步判断这个 key 是否和我们查找的 key 一致,若是则直接返回。

        第二种情况是头节点是负数,它分为两种情况,第一种是链表处在扩容之中,此时就会调用 find() 方法去新的 table 中找那个 key。第二种情况是存在于树节点中,也是调用重写过的 find() 方法去查找元素。

        最后一种情况是头节点既不是负数,头节点也不是我们想找的 key,此时就会去遍历整个链表,然后进行比较。

        可以看到整个 get() 方法之中,没有任何的锁。

2.7 put 流程

        下面的描述中数组简称为 table,链表简称为 binput() 方法里面会调用 putVal() 方法,其中第三个参数 onlyIfAbsent,如果为 true,只有第一次 put 这个键和值的时才会放入到 map 中,第二次再次 put 时因为已经有了就不做任何操作了,即不会用新值覆盖掉旧值。它的默认值为 false,即每次都会用新值覆盖掉旧值。

    public V put(K key, V value) {

        return putVal(key, value, false);
    }

    final V putVal(K key, V value, boolean onlyIfAbsent) {
        if (key == null || value == null) throw new NullPointerException();
        // 其中 spread 方法会综合高位低位, 具有更好的 hash 性
        int hash = spread(key.hashCode());
        int binCount = 0;
        for (Node<K, V>[] tab = table; ; ) {
            // f 是链表头节点
            // fh 是链表头结点的 hash
            // i 是链表在 table 中的下标
            Node<K, V> f;
            int n, i, fh;
            // 要创建 table
            if (tab == null || (n = tab.length) == 0)
                // 初始化 table 使用了 cas, 无需 synchronized 创建成功, 进入下一轮循环
                tab = initTable();
                // 要创建链表头节点
            else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
                // 添加链表头使用了 cas, 无需 synchronized
                if (casTabAt(tab, i, null,
                        new Node<K, V>(hash, key, value, null)))
                    break;
            }
            // 帮忙扩容
            else if ((fh = f.hash) == MOVED)
                // 帮忙之后, 进入下一轮循环
                tab = helpTransfer(tab, f);
            else {
                V oldVal = null;
                // 锁住链表头节点
                synchronized (f) {
                    // 再次确认链表头节点没有被移动
                    if (tabAt(tab, i) == f) {
                        // 链表
                        if (fh >= 0) {
                            binCount = 1;
                            // 遍历链表
                            for (Node<K, V> e = f; ; ++binCount) {
                                K ek;
                                // 找到相同的 key
                                if (e.hash == hash &&
                                        ((ek = e.key) == key ||
                                                (ek != null && key.equals(ek)))) {
                                    oldVal = e.val;
                                    // 更新
                                    if (!onlyIfAbsent)
                                        e.val = value;
                                    break;
                                }
                                Node<K, V> pred = e;
                                // 已经是最后的节点了, 新增 Node, 追加至链表尾
                                if ((e = e.next) == null) {

                                    pred.next = new Node<K, V>(hash, key,
                                            value, null);
                                    break;
                                }
                            }
                        }
                        // 红黑树
                        else if (f instanceof TreeBin) {
                            Node<K, V> p;
                            binCount = 2;
                            // putTreeVal 会看 key 是否已经在树中, 是, 则返回对应的 TreeNode
                            if ((p = ((TreeBin<K, V>) f).putTreeVal(hash, key,
                                    value)) != null) {
                                oldVal = p.val;
                                if (!onlyIfAbsent)
                                    p.val = value;
                            }
                        }
                    }
                    // 释放链表头节点的锁
                }

                if (binCount != 0) {
                    if (binCount >= TREEIFY_THRESHOLD)
                        // 如果链表长度 >= 树化阈值(8), 进行链表转为红黑树
                        treeifyBin(tab, i);
                    if (oldVal != null)
                        return oldVal;
                    break;
                }
            }
        }
        // 增加 size 计数
        addCount(1L, binCount);
        return null;
    }

        首先进行非空判断,不允许存 null 键或者 null 值,避免产生二义性。

        然后进入一个死循环,在循环里面判断哈希表是否创建,若没有创建,则创建一个哈希表,创建的过程使用的是 cas,只有一个线程可以创建成功。假设哈希表已经创建好了,此时判断头节点有没有人占,没人占我就占上。

        然后判断头节点是否为 ForwardingNode ,若是,则锁住当前链表帮忙去扩容。

        当桶的下标冲突了,此时就需要加锁,锁住链表的头节点,然后判断头节点是否被移动过,针对于普通节点,如果有 key 则进行更新操作,如果没有则进行追加操作。针对于红黑树节点,则进行更新或者追加操作

2.8 size 计算流程

        size 计算实际发生在 put,remove 改变集合元素的操作之中。没有竞争发生,向 baseCount 累加计数;有竞争发生,新建 counterCells,向其中的一个 cell 累加计数,counterCells 初始有两个 cell,如果计数竞争比较激烈,会创建新的 cell 来累加计数,如下代码:

    public int size() {
        long n = sumCount();
        return ((n < 0L) ? 0 :
                (n > (long)Integer.MAX_VALUE) ? Integer.MAX_VALUE :
                        (int)n);
    }
    final long sumCount() {
        CounterCell[] as = counterCells; CounterCell a;
        // 将 baseCount 计数与所有 cell 计数累加
        long sum = baseCount;
        if (as != null) {
            for (int i = 0; i < as.length; ++i) {
                if ((a = as[i]) != null)
                    sum += a.value;
            }
        }
        return sum;
    }

        在 ConcurrentHashMap 中,计数是有一定的误差的,因为是多线程访问 map 集合,有的是增加有的是删除,所以计数不是特别准确,虽然做了一些优化,但得到的也就说大概值。

2.9 小结

        针对于 jdk1.8 来说,底层是数组 + 链表 | 红黑树构成。

        初始化,使用 cas 来保证并发安全,懒惰初始化 table 数组

        树化,当 table.length < 64 时,先尝试扩容,超过 64 时,并且链表的 length > 8 时,会将链表树化,树化过程会用 synchronized 锁住链表头。

        put,如果该链表尚未创建,只需要使用 cas 创建链表;如果已经有了,锁住链表头进行后续 put 操作,元素添加至链表的尾部。

        get,无锁操作仅需要保证可见性,扩容过程中 get 操作拿到的是 ForwardingNode 它会让 get 操作在新 table 进行搜索。

        扩容,扩容时以链表为单位进行,需要对链表进行 synchronized,但这时妙的是其它竞争线程也不是无事可做,它们会帮助把其它链表进行扩容,扩容时平均只有 1/6 的节点会把复制到新 table 中。

        size,元素个数保存在 baseCount 中,并发时的个数变动保存在 CounterCell[] 当中。最后统计数量时累加即可。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

快乐的小三菊

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值