自定义协议的要素:
- 魔数,用来在第一时间判定是否是无效数据包
- 版本号,可以支持协议的升级
- 序列化算法,消息正文到底采用哪种序列化反序列化方式,可以由此扩展,例如:json、hessian、jdk
- 指令类型,是请求还是响应或其他跟业务相关
- 请求序号,为了双工通信,提供异步能力,通过这个请求ID将响应关联起来,也可以通过请求ID做链路追踪。
- 正文长度,标注传输数据内容的长度,用于判断是否是一个完整的数据包
- 消息正文,主要传递的消息内容
魔数的作用
通信协议中的"魔数"是指协议中预定义的一个固定的标识符,通常是一个固定长度的字节序列。在实际通信中,数据包的开头会包含这个固定的魔数,用来表明该数据包所使用的通信协议类型。
魔数的作用主要有以下几点:
- 协议识别:魔数可以帮助接收方快速识别数据包所使用的通信协议类型。接收方在接收到数据包后,首先会读取数据包中的魔数,并与预定义的协议魔数进行比对,以确定所使用的通信协议类型。
- 安全性:通过魔数可以增加通信协议的安全性。在通信过程中,接收方可以根据魔数判断数据包的合法性,从而避免恶意数据包的攻击。
- 版本兼容性:魔数还可以用于确定通信协议的版本,从而实现不同版本之间的兼容性。接收方可以根据魔数判断发送方所使用的协议版本,以便做出相应的处理和解析。
/*
+-----------------------------------------------------------------+
| 魔数 4byte | 协议版本号 1byte | 序列化算法 1byte | 指令类型 1byte |
+-----------------------------------------------------------------+
| 请求序列 4byte | 状态 1byte | 数据长度 4byte | 消息 ID 不定 |
+-----------------------------------------------------------------+
*/
编码器
@Slf4j
public class RpcEncoder<T> extends MessageToByteEncoder<MessageProtocol<T>> {
/*
+-----------------------------------------------------------------+
| 魔数 4byte | 协议版本号 1byte | 序列化算法 1byte | 指令类型 1byte |
+-----------------------------------------------------------------+
| 请求序列 4byte | 状态 1byte | 数据长度 4byte | 消息 ID 不定 |
+-----------------------------------------------------------------+
*/
/**
*
* @param byteBuf
* @throws Exception
*/
@Override
protected void encode(ChannelHandlerContext ctx, MessageProtocol<T> msg, ByteBuf byteBuf) throws Exception {
MessageHeader header = msg.getMessageHeader();
T data = msg.getData();
// 魔数
byteBuf.writeInt(header.getMagicNum());// 4byte
// 版本号
byteBuf.writeByte(header.getVersion());// 1byte
// 序列化算法类型 0:jdk 1:json
// ordinal可以得到顺序
byte serializeAlgorithm = header.getSerializeAlgorithm();
byteBuf.writeByte(serializeAlgorithm);// 1byte
// 指令类型
byteBuf.writeByte(header.getMessageType());// 1byte
//请求序列
byteBuf.writeInt(header.getSequenceId());// 4byte
// 状态
byteBuf.writeByte(header.getStatus());// 1byte
// 获取内容字节数组
byte[] byteArray = Config.getSerializerAlgorithm(serializeAlgorithm).serialize(data);
// 数据长度
byteBuf.writeInt(byteArray.length);// 4byte
// 写入内容
byteBuf.writeBytes(byteArray);
}
}
解码器
@Slf4j
public class RpcDecoder extends ByteToMessageDecoder {
/*
+-----------------------------------------------------------------+
| 魔数 4byte | 协议版本号 1byte | 序列化算法 1byte | 指令类型 1byte |
+-----------------------------------------------------------------+
| 请求序列 4byte | 状态 1byte | 数据长度 4byte | 消息 ID 不定 |
+-----------------------------------------------------------------+
*/
/**
*
* @param ctx
* @param in
* @param out
* @throws Exception
*/
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
// 读四个字节的魔数
int magicNum = in.readInt();
// 读一个字节版本号
byte version = in.readByte();
// 读一个字节序列化算法类型
byte serializeAlgorithm = in.readByte();
// 读一个字节指令类型
byte messageType = in.readByte();
// 读四个字节请求序列
int sequenceId = in.readInt();
byte status = in.readByte();// 读一个字节填充数
// 数据长度
int len = in.readInt();
// 读出内容
byte[] byteArray = new byte[len];
in.readBytes(byteArray, 0, len);
MessageHeader header = new MessageHeader();
header.setMagicNum(magicNum);
header.setVersion(version);
header.setSerializeAlgorithm(serializeAlgorithm);
header.setStatus(status);
header.setSequenceId(sequenceId);
header.setMessageType(messageType);
header.setDataLength(len);
// 获取反序列化算法
Serializer.Algorithm algorithm = Serializer.Algorithm.values()[serializeAlgorithm];
// 反序列化
if (messageType == MsgType.REQUEST.getType()) {
RpcRequest request = algorithm.deserialize(RpcRequest.class, byteArray);// 反序列化
MessageProtocol<RpcRequest> protocol = new MessageProtocol<>();
protocol.setMessageHeader(header);
protocol.setData(request);
out.add(protocol);
} else if (messageType == MsgType.RESPONSE.getType()) {
RpcResponse response = algorithm.deserialize(RpcResponse.class, byteArray);// 反序列化
MessageProtocol<RpcResponse> protocol = new MessageProtocol<>();
protocol.setMessageHeader(header);
protocol.setData(response);
out.add(protocol);
}
}
}
或者使用一体编解码器
一体编解码器
/**
* 需要配合LengthFieldBasedFrameDecoder 编码器使用
* Sharable可以被EventLoop共享使用
* MessageToMessageCodec 完整的消息,转为对象a
* @description: 自定义消息编解码器
* @author xiaonan
* @date 2024/2/18
*
*/
@ChannelHandler.Sharable
public class MyMessageCodecSharable<T> extends MessageToMessageCodec<ByteBuf, MessageProtocol<T>> {
/*
+-----------------------------------------------------------------+
| 魔数 4byte | 协议版本号 1byte | 序列化算法 1byte | 指令类型 1byte |
+-----------------------------------------------------------------+
| 请求序列 4byte | 状态 1byte | 数据长度 4byte | 消息 ID 不定 |
+-----------------------------------------------------------------+
*/
@Override
protected void encode(ChannelHandlerContext ctx, MessageProtocol<T> msg, List<Object> list) throws Exception {
MessageHeader header = msg.getMessageHeader();
T data = msg.getData();
ByteBuf out = ByteBufAllocator.DEFAULT.buffer();
// 魔数
out.writeInt(header.getMagicNum());// 4byte
// 版本号
out.writeByte(header.getVersion());// 1byte
// 序列化算法类型 0:jdk 1:json
// ordinal可以得到顺序
byte serializeAlgorithm = header.getSerializeAlgorithm();
out.writeByte(serializeAlgorithm);// 1byte
// 指令类型
out.writeByte(header.getMessageType());// 1byte
//请求序列
out.writeInt(header.getSequenceId());// 4byte
// 状态
out.writeByte(header.getStatus());// 1byte
// 获取内容字节数组
byte[] byteArray = Config.getSerializerAlgorithm(serializeAlgorithm).serialize(data);
// 数据长度
out.writeInt(byteArray.length);// 4byte
// 写入内容
out.writeBytes(byteArray);
list.add(out);
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
// 读四个字节的魔数
int magicNum = in.readInt();
// 读一个字节版本号
byte version = in.readByte();
// 读一个字节序列化算法类型
byte serializeAlgorithm = in.readByte();
// 读一个字节指令类型
byte messageType = in.readByte();
// 读四个字节请求序列
int sequenceId = in.readInt();
byte status = in.readByte();// 读一个字节填充数
// 数据长度
int len = in.readInt();
// 读出内容
byte[] byteArray = new byte[len];
in.readBytes(byteArray, 0, len);
MessageHeader header = new MessageHeader();
header.setMagicNum(magicNum);
header.setVersion(version);
header.setSerializeAlgorithm(serializeAlgorithm);
header.setStatus(status);
header.setSequenceId(sequenceId);
header.setMessageType(messageType);
header.setDataLength(len);
// 获取反序列化算法
Serializer.Algorithm algorithm = Serializer.Algorithm.values()[serializeAlgorithm];
// 反序列化
if(messageType == MsgType.REQUEST.getType()){
RpcRequest request = algorithm.deserialize(RpcRequest.class, byteArray);// 反序列化
MessageProtocol<RpcRequest> protocol = new MessageProtocol<>();
protocol.setMessageHeader(header);
protocol.setData(request);
out.add(protocol);
}
else if(messageType == MsgType.RESPONSE.getType()){
RpcResponse response = algorithm.deserialize(RpcResponse.class, byteArray);// 反序列化
MessageProtocol<RpcResponse> protocol = new MessageProtocol<>();
protocol.setMessageHeader(header);
protocol.setData(response);
out.add(protocol);
}
}
}