RocketMQ 提供了一致性hash 算法来做Consumer 和 MessageQueue的负载均衡。 源码中一致性hash 环的实现是很优秀的,我们一步一步分析。
一个Hash环包含多个节点, 我们用 MyNode 去封装节点, 方法 getKey() 封装获取节点的key。我们可以实现MyNode 去描述一个物理节点或虚拟节点。MyVirtualNode 实现 MyNode, 表示一个虚拟节点。这里注意:一个虚拟节点是依赖于一个物理节点,所以MyVirtualNode 中封装了 一个 泛型 T physicalNode。物理节点MyClientNode也是实现了这个MyNode接口,很好的设计。代码加注释如下:
/**
* 表示hash环的一个节点
*/
public interface MyNode {
/**
* @return 节点的key
*/
String getKey();
}
/**
* 虚拟节点
*/
public class MyVirtualNode<T extends MyNode> implements MyNode {
final T physicalNode; // 主节点
final int replicaIndex; // 虚节点下标
public MyVirtualNode(T physicalNode, int replicaIndex) {
this.physicalNode = physicalNode;
this.replicaIndex = replicaIndex;
}
@Override
public String getKey() {
return physicalNode.getKey() + "-" + replicaIndex;
}
/**
* thisMyVirtualNode 是否是pNode 的 虚节点
*/
public boolean isVirtualNodeOf(T pNode) {
return physicalNode.getKey().equals(pNode.getKey());
}
public T getPhysicalNode() {
return physicalNode;
}
}
private static class MyClientNode implements MyNode {
private final String clientID;
public MyClientNode(String clientID) {
this.clientID = clientID;
}
@Override
public String getKey() {
return clientID;
}
}
上面实现了节点, 一致性hash 下一个问题是怎么封装hash算法呢?RocketMQ 使用 MyHashFunction 接口定义hash算法。使用MD5 + bit 位hash的方式实现hash算法。我们完全可以自己实现hash算法,具体见我的“常见的一些hash函数”文章。MyMD5Hash 算法代码的如下:
// MD5 hash 算法, 这里hash算法可以用常用的 hash 算法替换。
private static class MyMD5Hash implements MyHashFunction {
MessageDigest instance;
public MyMD5Hash() {
try {
instance = MessageDigest.getInstance("MD5");
} catch (NoSuchAlgorithmException e) {
}
}
@Override
public long hash(String key) {
instance.reset();
instance.update(key.getBytes());
byte[] digest = instance.digest();
long h = 0;
for (int i = 0; i < 4; i++) {
h <<= 8;
h |= ((int)digest[i]) & 0xFF;
}
return h;
}
}
现在,hash环的节点有了, hash算法也有了,最重要的是描述一个一致性hash 环。 想一想,这个环可以由N 个物理节点, 每个物理节点对应m个虚拟节点,节点位置用hash算法值描述。每个物理节点就是每个Consumer, 每个Consumer 的 id 就是 物理节点的key。 每个MessageQueue 的toString() 值 hash 后,用来找环上对应的最近的下一个物理节点。源码如下,这里展示主要的代码,其中最巧妙地是routeNode 方法, addNode 方法 注意我的注释:
public class MyConsistentHashRouter<T extends MyNode> {
private final SortedMap<Long, MyVirtualNode<T>> ring = new TreeMap<>(); // key是虚节点key的哈希值, value 是虚节点
private final MyHashFunction myHashFunction;
/**
* @param pNodes 物理节点集合
* @param vNodeCount 每个物理节点对应的虚节点数量
* @param hashFunction hash 函数 用于 hash 各个节点
*/
public MyConsistentHashRouter(Collection<T> pNodes, int vNodeCount, MyHashFunction hashFunction) {
if (hashFunction == null) {
throw new NullPointerException("Hash Function is null");
}
this.myHashFunction = hashFunction;
if (pNodes != null) {
for (T pNode : pNodes) {
this.addNode(pNode, vNodeCount);
}
}
}
/**
* 添加物理节点和它的虚节点到hash环。
* @param pNode 物理节点
* @param vNodeCount 虚节点数量。
*/
public void addNode(T pNode, int vNodeCount) {
if (vNodeCount < 0) {
throw new IllegalArgumentException("ill virtual node counts :" + vNodeCount);
}
int existingReplicas = this.getExistingReplicas(pNode);
for (int i = 0; i < vNodeCount; i++) {
MyVirtualNode<T> vNode = new MyVirtualNode<T>(pNode, i + existingReplicas); // 创建一个新的虚节点,位置是 i+existingReplicas
ring.put(this.myHashFunction.hash(vNode.getKey()), vNode); // 将新的虚节点放到hash环中
}
}
/**
* 根据一个给定的key 在 hash环中 找到离这个key最近的下一个物理节点
* @param key 一个key, 用于找这个key 在环上最近的节点
*/
public T routeNode(String key) {
if (ring.isEmpty()) {
return null;
}
Long hashVal = this.myHashFunction.hash(key);
SortedMap<Long, MyVirtualNode<T>> tailMap = ring.tailMap(hashVal);
Long nodeHashVal = !tailMap.isEmpty() ? tailMap.firstKey() : ring.firstKey();
return ring.get(nodeHashVal).getPhysicalNode();
}
/**
* @param pNode 物理节点
* @return 当前这个物理节点对应的虚节点的个数
*/
public int getExistingReplicas(T pNode) {
int replicas = 0;
for (MyVirtualNode<T> vNode : ring.values()) {
if (vNode.isVirtualNodeOf(pNode)) {
replicas++;
}
}
return replicas;
}
现在一致性hash 环有了, 剩下的就是 和rocketmq 的 consumer, mq 构成负载均衡策略了。比较简单, 代码如下:
/**
* 基于一致性性hash环的Consumer负载均衡.
*/
public class MyAllocateMessageQueueConsistentHash implements AllocateMessageQueueStrategy {
// 每个物理节点对应的虚节点的个数
private final int virtualNodeCnt;
private final MyHashFunction customHashFunction;
public MyAllocateMessageQueueConsistentHash() {
this(10); // 默认10个虚拟节点
}
public MyAllocateMessageQueueConsistentHash(int virtualNodeCnt) {
this(virtualNodeCnt, null);
}
public MyAllocateMessageQueueConsistentHash(int virtualNodeCnt, MyHashFunction customHashFunction) {
if (virtualNodeCnt < 0) {
throw new IllegalArgumentException("illegal virtualNodeCnt : " + virtualNodeCnt);
}
this.virtualNodeCnt = virtualNodeCnt;
this.customHashFunction = customHashFunction;
}
@Override
public List<MessageQueue> allocate(String consumerGroup, String currentCID, List<MessageQueue> mqAll, List<String> cidAll) {
// 省去一系列非空校验
Collection<MyClientNode> cidNodes = new ArrayList<>();
for (String cid : cidAll) {
cidNodes.add(new MyClientNode(cid));
}
final MyConsistentHashRouter<MyClientNode> router;
if (this.customHashFunction != null) {
router = new MyConsistentHashRouter<MyClientNode>(cidNodes, virtualNodeCnt, customHashFunction);
}else {
router = new MyConsistentHashRouter<MyClientNode>(cidNodes, virtualNodeCnt);
}
List<MessageQueue> results = new ArrayList<MessageQueue>(); // 当前 currentCID 对应的 mq
// 将每个mq 根据一致性hash 算法找到对应的物理节点(Consumer)
for (MessageQueue mq : mqAll) {
MyClientNode clientNode = router.routeNode(mq.toString()); // 根据 mq toString() 方法做hash 和环上节点比较
if (clientNode != null && currentCID.equals(clientNode.getKey())) {
results.add(mq);
}
}
return results;
}
@Override
public String getName() {
return "CONSISTENT_HASH";
}
private static class MyClientNode implements MyNode {
private final String clientID;
public MyClientNode(String clientID) {
this.clientID = clientID;
}
@Override
public String getKey() {
return clientID;
}
}
}
————————————————
版权声明:本文为CSDN博主「昊haohao」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/ZHANGYONGHAO604/article/details/82426373