RocketMQ 提供了一致性hash 算法来做Consumer 和 MessageQueue的负载均衡。 源码中一致性hash 环的实现是很优秀的,我们一步一步分析。
-
一个Hash环包含多个节点, 我们用 MyNode 去封装节点, 方法 getKey() 封装获取节点的key。我们可以实现MyNode 去描述一个物理节点或虚拟节点。MyVirtualNode 实现 MyNode, 表示一个虚拟节点。这里注意:一个虚拟节点是依赖于一个物理节点,所以MyVirtualNode 中封装了 一个 泛型 T physicalNode。物理节点MyClientNode也是实现了这个MyNode接口,很好的设计。代码加注释如下:
/** * 表示hash环的一个节点 */ public interface MyNode { /** * @return 节点的key */ String getKey(); } /** * 虚拟节点 */ public class MyVirtualNode<T extends MyNode> implements MyNode { final T physicalNode; // 主节点 final int replicaIndex; // 虚节点下标 public MyVirtualNode(T physicalNode, int replicaIndex) { this.physicalNode = physicalNode; this.replicaIndex = replicaIndex; } @Override public String getKey() { return physicalNode.getKey() + "-" + replicaIndex; } /** * thisMyVirtualNode 是否是pNode 的 虚节点 */ public boolean isVirtualNodeOf(T pNode) { return physicalNode.getKey().equals(pNode.getKey()); } public T getPhysicalNode() { return physicalNode; } } private static class MyClientNode implements MyNode { private final String clientID; public MyClientNode(String clientID) { this.clientID = clientID; } @Override public String getKey() { return clientID; } }
-
上面实现了节点, 一致性hash 下一个问题是怎么封装hash算法呢?RocketMQ 使用 MyHashFunction 接口定义hash算法。使用MD5 + bit 位hash的方式实现hash算法。我们完全可以自己实现hash算法,具体见我的“常见的一些hash函数”文章。MyMD5Hash 算法代码的如下:
// MD5 hash 算法, 这里hash算法可以用常用的 hash 算法替换。 private static class MyMD5Hash implements MyHashFunction { MessageDigest instance; public MyMD5Hash() { try { instance = MessageDigest.getInstance("MD5"); } catch (NoSuchAlgorithmException e) { } } @Override public long hash(String key) { instance.reset(); instance.update(key.getBytes()); byte[] digest = instance.digest(); long h = 0; for (int i = 0; i < 4; i++) { h <<= 8; h |= ((int)digest[i]) & 0xFF; } return h; } }
-
现在,hash环的节点有了, hash算法也有了,最重要的是描述一个一致性hash 环。 想一想,这个环可以由N 个物理节点, 每个物理节点对应m个虚拟节点,节点位置用hash算法值描述。每个物理节点就是每个Consumer, 每个Consumer 的 id 就是 物理节点的key。 每个MessageQueue 的toString() 值 hash 后,用来找环上对应的最近的下一个物理节点。源码如下,这里展示主要的代码,其中最巧妙地是routeNode 方法, addNode 方法 注意我的注释:
public class MyConsistentHashRouter<T extends MyNode> { private final SortedMap<Long, MyVirtualNode<T>> ring = new TreeMap<>(); // key是虚节点key的哈希值, value 是虚节点 private final MyHashFunction myHashFunction; /** * @param pNodes 物理节点集合 * @param vNodeCount 每个物理节点对应的虚节点数量 * @param hashFunction hash 函数 用于 hash 各个节点 */ public MyConsistentHashRouter(Collection<T> pNodes, int vNodeCount, MyHashFunction hashFunction) { if (hashFunction == null) { throw new NullPointerException("Hash Function is null"); } this.myHashFunction = hashFunction; if (pNodes != null) { for (T pNode : pNodes) { this.addNode(pNode, vNodeCount); } } } /** * 添加物理节点和它的虚节点到hash环。 * @param pNode 物理节点 * @param vNodeCount 虚节点数量。 */ public void addNode(T pNode, int vNodeCount) { if (vNodeCount < 0) { throw new IllegalArgumentException("ill virtual node counts :" + vNodeCount); } int existingReplicas = this.getExistingReplicas(pNode); for (int i = 0; i < vNodeCount; i++) { MyVirtualNode<T> vNode = new MyVirtualNode<T>(pNode, i + existingReplicas); // 创建一个新的虚节点,位置是 i+existingReplicas ring.put(this.myHashFunction.hash(vNode.getKey()), vNode); // 将新的虚节点放到hash环中 } } /** * 根据一个给定的key 在 hash环中 找到离这个key最近的下一个物理节点 * @param key 一个key, 用于找这个key 在环上最近的节点 */ public T routeNode(String key) { if (ring.isEmpty()) { return null; } Long hashVal = this.myHashFunction.hash(key); SortedMap<Long, MyVirtualNode<T>> tailMap = ring.tailMap(hashVal); Long nodeHashVal = !tailMap.isEmpty() ? tailMap.firstKey() : ring.firstKey(); return ring.get(nodeHashVal).getPhysicalNode(); } /** * @param pNode 物理节点 * @return 当前这个物理节点对应的虚节点的个数 */ public int getExistingReplicas(T pNode) { int replicas = 0; for (MyVirtualNode<T> vNode : ring.values()) { if (vNode.isVirtualNodeOf(pNode)) { replicas++; } } return replicas; }
-
现在一致性hash 环有了, 剩下的就是 和rocketmq 的 consumer, mq 构成负载均衡策略了。比较简单, 代码如下:
/** * 基于一致性性hash环的Consumer负载均衡. */ public class MyAllocateMessageQueueConsistentHash implements AllocateMessageQueueStrategy { // 每个物理节点对应的虚节点的个数 private final int virtualNodeCnt; private final MyHashFunction customHashFunction; public MyAllocateMessageQueueConsistentHash() { this(10); // 默认10个虚拟节点 } public MyAllocateMessageQueueConsistentHash(int virtualNodeCnt) { this(virtualNodeCnt, null); } public MyAllocateMessageQueueConsistentHash(int virtualNodeCnt, MyHashFunction customHashFunction) { if (virtualNodeCnt < 0) { throw new IllegalArgumentException("illegal virtualNodeCnt : " + virtualNodeCnt); } this.virtualNodeCnt = virtualNodeCnt; this.customHashFunction = customHashFunction; } @Override public List<MessageQueue> allocate(String consumerGroup, String currentCID, List<MessageQueue> mqAll, List<String> cidAll) { // 省去一系列非空校验 Collection<MyClientNode> cidNodes = new ArrayList<>(); for (String cid : cidAll) { cidNodes.add(new MyClientNode(cid)); } final MyConsistentHashRouter<MyClientNode> router; if (this.customHashFunction != null) { router = new MyConsistentHashRouter<MyClientNode>(cidNodes, virtualNodeCnt, customHashFunction); }else { router = new MyConsistentHashRouter<MyClientNode>(cidNodes, virtualNodeCnt); } List<MessageQueue> results = new ArrayList<MessageQueue>(); // 当前 currentCID 对应的 mq // 将每个mq 根据一致性hash 算法找到对应的物理节点(Consumer) for (MessageQueue mq : mqAll) { MyClientNode clientNode = router.routeNode(mq.toString()); // 根据 mq toString() 方法做hash 和环上节点比较 if (clientNode != null && currentCID.equals(clientNode.getKey())) { results.add(mq); } } return results; } @Override public String getName() { return "CONSISTENT_HASH"; } private static class MyClientNode implements MyNode { private final String clientID; public MyClientNode(String clientID) { this.clientID = clientID; } @Override public String getKey() { return clientID; } } }