继续上一节的内容,解析代码。
编码器
参考LengthFieldBasedFrameDecoder解码器的协议,在协议里规定传输哪些类型的数据, 以及每一种类型的数据应该占多少字节。这样我们在接收到二级制数据之后,就可以正确的解析出我们需要的数据。
下面是本次使用的传输协议:
* 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
* +-----+-----+-----+-----+--------+----+----+----+------+-----------+-------+----- --+-----+-----+-------+
* | magic code |version | full length | messageType| codec|compress| RequestId |
* +-----------------------+--------+---------------------+-----------+-----------+-----------+------------+
* | |
* | body |
* | |
* | ... ... |
* +-------------------------------------------------------------------------------------------------------+
* 4B magic code(魔法数) 1B version(版本) 4B full length(消息长度) 1B messageType(消息类型)
* 1B compress(压缩类型) 1B codec(序列化类型) 4B requestId(请求的Id)
* body(object类型数据)
首先是RpcMessageEncoder.java,这个RpcMessageEncoder类的主要作用是将RpcMessage编码为字节,以便可以通过网络发送。它首先将RpcMessage的各个字段(如魔法数、版本号、消息类型等)写入到一个ByteBuf中,然后如果消息类型不是心跳请求类型和心跳响应类型,它还会将消息数据序列化和压缩,然后将序列化和压缩后的字节数组写入到ByteBuf中。最后,它会在ByteBuf的适当位置写入消息的全长度。
@Slf4j
public class RpcMessageEncoder extends MessageToByteEncoder<RpcMessage> {
private static final AtomicInteger ATOMIC_INTEGER = new AtomicInteger(0);// 定义一个原子整数,用于生成请求ID
@Override
protected void encode(ChannelHandlerContext ctx, RpcMessage rpcMessage, ByteBuf out) {
// 当需要将RpcMessage编码为字节时被调用
try {
out.writeBytes(RpcConstants.MAGIC_NUMBER);// 写入魔法数(常量)
out.writeByte(RpcConstants.VERSION);// 写入版本号(常量)
out.writerIndex(out.writerIndex() + 4);// 留出一个位置来写入消息的全长度
byte messageType = rpcMessage.getMessageType();// 获取消息类型
out.writeByte(messageType);// 写入消息类型
out.writeByte(rpcMessage.getCodec());// 写入编解码类型 hessian、kyro或protostuff
out.writeByte(CompressTypeEnum.GZIP.getCode());// 写入压缩类型
out.writeInt(ATOMIC_INTEGER.getAndIncrement());// 写入请求ID,并将原子整数加1
byte[] bodyBytes = null;// 定义一个字节数组来存储消息体
int fullLength = RpcConstants.HEAD_LENGTH;// 定义一个整数来存储消息的全长度,初始值为头部长度16
// 消息类型不是心跳消息,则全长=头部长度+正文长度
if (messageType != RpcConstants.HEARTBEAT_REQUEST_TYPE
&& messageType != RpcConstants.HEARTBEAT_RESPONSE_TYPE) {
// 如果消息类型不是心跳请求类型和心跳响应类型
// 序列化对象
String codecName = SerializationTypeEnum.getName(rpcMessage.getCodec());// 获取编解码类型的名字
log.info("codec name: [{}] ", codecName);
Serializer serializer = ExtensionLoader.getExtensionLoader(Serializer.class)
.getExtension(codecName);// 通过ExtensionLoader加载扩展类——序列化器
bodyBytes = serializer.serialize(rpcMessage.getData());// 将消息数据序列化为字节数组
// 压缩字节数组
String compressName = CompressTypeEnum.getName(rpcMessage.getCompress());// 获取压缩类型的名字
Compress compress = ExtensionLoader.getExtensionLoader(Compress.class)
.getExtension(compressName);// 通过ExtensionLoader加载扩展类——压缩器
bodyBytes = compress.compress(bodyBytes);// 将字节数组压缩
fullLength += bodyBytes.length;// 将字节数组的长度加到消息的全长度上
}
if (bodyBytes != null) {
out.writeBytes(bodyBytes);// 如果字节数组不为空,就将字节数组写入到输出中
}
int writeIndex = out.writerIndex();// 获取写入的索引
//回退到消息长度字段的位置,以便写入消息的全长度。
out.writerIndex(writeIndex - fullLength + RpcConstants.MAGIC_NUMBER.length + 1);// 设置写入的索引到合适位置
out.writeInt(fullLength);// 写入消息的全长度
out.writerIndex(writeIndex);// 恢复写入的索引
} catch (Exception e) {
log.error("Encode request error!", e);
}
}
}
然后是RpcMessageDecoder.java,这个RpcMessageDecoder类的主要作用是将字节解码为RpcMessage。它首先从ByteBuf中读取各个字段(如魔法数、版本号、全长度等),然后根据消息类型,可能会从ByteBuf中读取消息体,然后将消息体解压缩和反序列化,最后将反序列化后的对象设置到RpcMessage的数据中。
@Slf4j
public class RpcMessageDecoder extends LengthFieldBasedFrameDecoder {
public RpcMessageDecoder() {
// 调用父类的构造函数,设置各个参数
// lengthFieldOffset: 魔法数是4B,版本是1B,然后才是消息长度。所以值是5
// lengthFieldLength: 消息长度是4B。所以值是4
// lengthAdjustment: 消息长度加上之前读取的所有数据,9个字节,所以剩下的长度是(fullLength-9)。所以值是-9
// initialBytesToStrip: 我们将手动检查魔术代码和版本,所以不要剥离任何字节。因此值为0
this(RpcConstants.MAX_FRAME_LENGTH, 5, 4, -9, 0);
}
/**
* @param maxFrameLength 最大帧长度。它决定了可以接收的数据的最大长度。
* 如果超过,数据将被丢弃。
* @param lengthFieldOffset 这是长度字段的偏移量。也就是说,数据帧的开始到消息长度的开始的字节数。
* @param lengthFieldLength 消息长度的调整值。
* @param lengthAdjustment 消息长度补偿值。lengthAdjustment +数据长度取值 = 数据长度字段之后剩下包的字节数
* @param initialBytesToStrip 需要剥离ByteBuf的长度(一般为0)
* 如果需要接收所有标头+正文数据,则此值为0
* 如果只想接收正文数据,则需要跳过标头所消耗的字节数。
*/
public RpcMessageDecoder(int maxFrameLength, int lengthFieldOffset, int lengthFieldLength,
int lengthAdjustment, int initialBytesToStrip) {
// 调用父类的构造函数,设置各个参数
super(maxFrameLength, lengthFieldOffset, lengthFieldLength, lengthAdjustment, initialBytesToStrip);
}
// 当需要将字节解码为RpcMessage时被调用
@Override
protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
Object decoded = super.decode(ctx, in);// 调用父类的decode方法,获取解码后的对象
if (decoded instanceof ByteBuf) {
// 如果解码后的对象是ByteBuf类型
ByteBuf frame = (ByteBuf) decoded;// 将解码后的对象转换为ByteBuf类型
if (frame.readableBytes() >= RpcConstants.TOTAL_LENGTH) {
// 如果ByteBuf中可读的字节数大于或等于16 16是所有消息头的长度和
try {
return decodeFrame(frame);// 解码帧
} catch (Exception e) {
log.error("Decode frame error!", e);
throw e;
} finally {
frame.release();// 释放帧
}
}
}
return decoded;
}
private Object decodeFrame(ByteBuf in) {
// 解码帧
// note: must read ByteBuf in order
checkMagicNumber(in);// 检查魔法数
checkVersion(in);// 检查版本号
int fullLength = in.readInt();// 读取消息长度
// build RpcMessage object
byte messageType = in.readByte();// 读取消息类型
byte codecType = in.readByte();// 读取编解码类型
byte compressType = in.readByte();// 读取压缩类型
int requestId = in.readInt();// 读取请求ID
RpcMessage rpcMessage = RpcMessage.builder()//构建RpcMessage
.codec(codecType)
.requestId(requestId)
.messageType(messageType).build();
if (messageType == RpcConstants.HEARTBEAT_REQUEST_TYPE) {
// 如果消息类型是心跳请求类型
rpcMessage.setData(RpcConstants.PING);
return rpcMessage;
}
if (messageType == RpcConstants.HEARTBEAT_RESPONSE_TYPE) {
// 如果消息类型是心跳响应类型
rpcMessage.setData(RpcConstants.PONG);
return rpcMessage;
}
int bodyLength = fullLength - RpcConstants.HEAD_LENGTH;// 计算消息体data的长度
if (bodyLength > 0) {
// 如果消息体的长度大于0
byte[] bs = new byte[bodyLength];// 创建一个新的字节数组来存储消息体
in.readBytes(bs);// 从ByteBuf中读取字节到字节数组中
// 解压字节数组
String compressName = CompressTypeEnum.getName(compressType);// 获取压缩类型的名字
Compress compress = ExtensionLoader.getExtensionLoader(Compress.class)
.getExtension(compressName);// 通过ExtensionLoader加载扩展类——压缩器
bs = compress.decompress(bs);// 将字节数组解压缩
// 反序列化
String codecName = SerializationTypeEnum.getName(rpcMessage.getCodec())