gitee:https://gitee.com/niumazlb/mynettydemo
github:https://github.com/ishuaige/mynettydemo
参考:黑马程序员Netty教程
服务端
这里用 netty+WebSocket 举例
首先我们需要一个服务端 WebSocketServer
创建服务端大致步骤:
- 创建 ServerBootstrap - bootstrap
- bootstrap 关联两个事件循环组(EventLoopGroup) boss 和 work
- bootstrap 设置配置项 option
- bootstrap 指定 channel 类型
- bootstrap 设置 childHandler,这里可以把 Handler 的初始化封装成一个类
- bootstrap 绑定端口
- 监听 channel 的关闭事件(closeFuture),在连接关闭时优雅关闭 boss 和 work
以上第 2-5 步的顺序可以任意调整
import com.niuma.mynetty.server.handler.MyWebSocketChannelHandler;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
/**
 * Bootstraps the Netty WebSocket server.
 *
 * <p>{@code boss} accepts incoming connections and {@code work} performs I/O for them. Both
 * groups are shut down when the server channel closes, or as soon as binding fails.
 *
 * @author niumazlb
 * @create 2023-01-09 20:30
 */
@Component
public class WebSocketServer {

    /** Port used by {@link #run()}. */
    private static final int DEFAULT_PORT = 8888;

    /** Accept backlog of the listening socket. */
    private static final int BACKLOG = 1024;

    EventLoopGroup boss = new NioEventLoopGroup();
    EventLoopGroup work = new NioEventLoopGroup();

    @Resource
    MyWebSocketChannelHandler myWebSocketChannelHandler;

    /** Starts the server on the default port 8888; see {@link #run(int)}. */
    public void run() {
        run(DEFAULT_PORT);
    }

    /**
     * Binds the server to {@code port}. Returns once the bind has completed; it does not wait
     * for the server to shut down.
     *
     * @param port TCP port to listen on
     */
    public void run(int port) {
        ServerBootstrap bootstrap = new ServerBootstrap()
                .group(boss, work)
                .channel(NioServerSocketChannel.class)
                .option(ChannelOption.SO_BACKLOG, BACKLOG)
                // per-connection pipeline setup is encapsulated in MyWebSocketChannelHandler
                .childHandler(myWebSocketChannelHandler);

        ChannelFuture bindFuture = bootstrap.bind(port);
        // Cleanup is decided by a listener so the groups are released even if this thread is
        // interrupted while waiting below, or if the bind itself fails.
        bindFuture.addListener((ChannelFutureListener) future -> {
            if (future.isSuccess()) {
                future.channel().closeFuture().addListener(closed -> shutdownGroups());
            } else {
                shutdownGroups();
            }
        });
        try {
            // sync() rethrows a bind failure (e.g. port already in use) to the caller.
            bindFuture.sync();
        } catch (InterruptedException e) {
            // restore the interrupt flag so callers can still observe the interruption
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
    }

    /** Gracefully shuts down both event loop groups. */
    private void shutdownGroups() {
        boss.shutdownGracefully();
        work.shutdownGracefully();
    }
}
注意这个 Handler 继承的是 ChannelInitializer<SocketChannel>,它负责为每个连接初始化 pipeline 中的各个 Handler
- 这里的 Handler 是用来处理 WebSocket 的,其他协议需要具体情况具体分析
package com.niuma.mynetty.server.handler;
import com.niuma.mynetty.config.NettyConfig;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
/**
* 用于初始化注册 Channel 中的各个 Handler
* @author niumazlb
* @create 2023-01-09 20:35
*/
@Component
public class MyWebSocketChannelHandler extends ChannelInitializer<SocketChannel> {
@Resource
MyWebSocketHandler myWebSocketHandler ;
@Resource
RegisterHandler registerHandler ;
@Resource
SingleMessageHandler singleMessageHandler ;
@Override
protected void initChannel(SocketChannel ch) throws Exception {
//webSocket连接基本就是以下的几个handler
//http服务端的解码器
ch.pipeline().addLast("http-codec", new HttpServerCodec())
//日志
.addLast("log",new LoggingHandler(LogLevel.DEBUG))
/*
通过 HttpObjectAggregator 可以把 HttpMessage 和 HttpContent 聚合成一个 FullHttpRequest,
并定义可以接受的数据大小,在文件上传时,可以支持params+multipart
*/
.addLast("aggregator",new HttpObjectAggregator(65536)) //httpContent消息聚合
//块写入写出Handler
.addLast("http-chunked",new ChunkedWriteHandler()) // HttpContent 压缩
/*
netty内置的WebSocketServerProtocolHandler作为Websocket协议的主要处理器
是处理HTTP相关的数据的
*/
.addLast("protocolHandler",new WebSocketServerProtocolHandler("/websocket"))
// 到此数据都处理完了,后面就是我们定义的处理器了==============================
// 用于消息的分发,接收 WebSocketFrame 客户端传来的帧
.addLast(myWebSocketHandler)
// 处理各种的消息
.addLast(registerHandler)
.addLast(singleMessageHandler);
}
//客户端与服务端创建连接
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
NettyConfig.group.add(ctx.channel());
System.out.println("客户端与服务端连接开启....");
}
//客户端与服务端断开连接
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
NettyConfig.group.remove(ctx.channel());
System.out.println("客户端与服务端连接关闭....");
}
//接收结束之后 read相对于服务端
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
ctx.flush();
}
// 出现异常时调用
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
cause.printStackTrace();
ctx.close();
}
}
简单的 RPC 框架
参考:黑马程序员 Netty 课程
协议(自定义协议的编解码器与帧解码器)
package cn.itcast.protocol;
import cn.itcast.config.Config;
import cn.itcast.message.Message;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageCodec;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
@Slf4j
@ChannelHandler.Sharable
/**
* 必须和 LengthFieldBasedFrameDecoder 一起使用,确保接到的 ByteBuf 消息是完整的
*/
public class MessageCodecSharable extends MessageToMessageCodec<ByteBuf, Message> {
@Override
public void encode(ChannelHandlerContext ctx, Message msg, List<Object> outList) throws Exception {
ByteBuf out = ctx.alloc().buffer();
// 1. 4 字节的魔数
out.writeBytes(new byte[]{1, 2, 3, 4});
// 2. 1 字节的版本,
out.writeByte(1);
// 3. 1 字节的序列化方式 jdk 0 , json 1
out.writeByte(Config.getSerializerAlgorithm().ordinal());
// 4. 1 字节的指令类型
out.writeByte(msg.getMessageType());
// 5. 4 个字节
out.writeInt(msg.getSequenceId());
// 无意义,对齐填充
out.writeByte(0xff);
// 6. 获取内容的字节数组
byte[] bytes = Config.getSerializerAlgorithm().serialize(msg);
// 7. 长度
out.writeInt(bytes.length);
// 8. 写入内容
out.writeBytes(bytes);
outList.add(out);
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
int magicNum = in.readInt();
byte version = in.readByte();
byte serializerAlgorithm = in.readByte(); // 0 或 1
byte messageType = in.readByte(); // 0,1,2...
int sequenceId = in.readInt();
in.readByte();
int length = in.readInt();
byte[] bytes = new byte[length];
in.readBytes(bytes, 0, length);
// 找到反序列化算法
Serializer.Algorithm algorithm = Serializer.Algorithm.values()[serializerAlgorithm];
// 确定具体消息类型
Class<? extends Message> messageClass = Message.getMessageClass(messageType);
Message message = algorithm.deserialize(messageClass, bytes);
// log.debug("{}, {}, {}, {}, {}, {}", magicNum, version, serializerType, messageType, sequenceId, length);
// log.debug("{}", message);
out.add(message);
}
}
package cn.itcast.protocol;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
/**
 * Frame decoder that reassembles complete protocol frames from the TCP byte stream. It handles
 * both half packets and sticky (merged) packets.
 *
 * <p>The length field is a 4-byte int located after a 12-byte prefix. That prefix is magic 4,
 * version 1, serializer 1, message type 1, sequence id 4 and padding 1. The length counts only
 * the body, so no length adjustment and no byte stripping are needed.
 */
public class ProtocolFrameDecoder extends LengthFieldBasedFrameDecoder {

    private static final int MAX_FRAME_LENGTH = 1024;
    private static final int LENGTH_FIELD_OFFSET = 12;
    private static final int LENGTH_FIELD_LENGTH = 4;
    private static final int LENGTH_ADJUSTMENT = 0;
    private static final int INITIAL_BYTES_TO_STRIP = 0;

    /** Creates a decoder for the default protocol layout, limiting frames to 1024 bytes. */
    public ProtocolFrameDecoder() {
        this(MAX_FRAME_LENGTH,
                LENGTH_FIELD_OFFSET,
                LENGTH_FIELD_LENGTH,
                LENGTH_ADJUSTMENT,
                INITIAL_BYTES_TO_STRIP);
    }

    /**
     * Creates a decoder with explicit settings; each argument is passed unchanged to
     * {@link LengthFieldBasedFrameDecoder}.
     */
    public ProtocolFrameDecoder(int maxFrameLength, int lengthFieldOffset, int lengthFieldLength,
                                int lengthAdjustment, int initialBytesToStrip) {
        super(maxFrameLength, lengthFieldOffset, lengthFieldLength, lengthAdjustment, initialBytesToStrip);
    }
}
服务端
package cn.itcast.server;
import cn.itcast.protocol.MessageCodecSharable;
import cn.itcast.protocol.ProtocolFrameDecoder;
import cn.itcast.server.handler.RpcRequestMessageHandler;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import lombok.extern.slf4j.Slf4j;
/**
 * RPC server entry point: listens on port 8080 and answers RpcRequestMessages until the server
 * channel is closed.
 */
@Slf4j
public class RpcServer {

    private static final int PORT = 8080;

    public static void main(String[] args) {
        NioEventLoopGroup boss = new NioEventLoopGroup();
        NioEventLoopGroup worker = new NioEventLoopGroup();
        // Both handlers below are @Sharable, so one instance serves every connection.
        LoggingHandler loggingHandler = new LoggingHandler(LogLevel.DEBUG);
        MessageCodecSharable messageCodec = new MessageCodecSharable();
        try {
            ServerBootstrap serverBootstrap = new ServerBootstrap()
                    .group(boss, worker)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) throws Exception {
                            ch.pipeline()
                                    // must come first: reassembles complete frames (half/sticky packets)
                                    .addLast(new ProtocolFrameDecoder())
                                    .addLast(loggingHandler)
                                    .addLast(messageCodec)
                                    // One instance per connection, so the pipeline never depends on
                                    // RpcRequestMessageHandler being @Sharable.
                                    .addLast(new RpcRequestMessageHandler());
                        }
                    });
            Channel channel = serverBootstrap.bind(PORT).sync().channel();
            channel.closeFuture().sync();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("server error", e);
        } finally {
            boss.shutdownGracefully();
            worker.shutdownGracefully();
        }
    }
}
package cn.itcast.server.handler;
import cn.itcast.message.RpcRequestMessage;
import cn.itcast.message.RpcResponseMessage;
import cn.itcast.server.service.ServicesFactory;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.SimpleChannelInboundHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
/**
* 处理Rpc请求的handler
* @author niumazlb
* @create 2023-01-08 15:37
*/
public class RpcRequestMessageHandler extends SimpleChannelInboundHandler<RpcRequestMessage> {
@Override
protected void channelRead0(ChannelHandlerContext ctx, RpcRequestMessage msg) throws Exception {
String interfaceName = msg.getInterfaceName();
String methodName = msg.getMethodName();
int sequenceId = msg.getSequenceId();
Object[] parameterValue = msg.getParameterValue();
Class[] parameterTypes = msg.getParameterTypes();
RpcResponseMessage response = new RpcResponseMessage();
response.setSequenceId(sequenceId);
//通过反射来方法调用
try {
Object service = ServicesFactory.getService(Class.forName(interfaceName));
Method method = service.getClass().getMethod(methodName,parameterTypes);
Object invoke = method.invoke(service, parameterValue);
response.setReturnValue(invoke);
} catch (Exception e) {
e.printStackTrace();
response.setExceptionValue(new Exception("error:"+e.getCause().getMessage()));
}
ctx.writeAndFlush(response);
}
}
客户端
package cn.itcast.client;
import cn.itcast.client.handler.RpcResponseMessageHandler;
import cn.itcast.message.RpcRequestMessage;
import cn.itcast.protocol.MessageCodecSharable;
import cn.itcast.protocol.ProtocolFrameDecoder;
import cn.itcast.protocol.SequenceIdGenerator;
import cn.itcast.server.service.HelloService;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.concurrent.DefaultPromise;
import java.lang.reflect.Proxy;
/**
 * Client side of the RPC demo. It creates dynamic proxies whose method calls are sent to the
 * server over one lazily opened connection.
 *
 * @author niumazlb
 * @create 2023-01-08 15:51
 */
public class RpcClientManager {

    public static void main(String[] args) {
        HelloService proxyService = getProxyService(HelloService.class);
        System.out.println(proxyService.sayHello("zhangsan"));
    }

    /**
     * Creates a proxy implementing {@code serviceClass}. Each method call is sent as an
     * RpcRequestMessage and blocks until the matching response arrives.
     *
     * @param serviceClass the remote service interface
     * @return a proxy forwarding every call to the server; a call throws RuntimeException wrapping
     *     the remote exception or the transport failure
     */
    public static <T> T getProxyService(Class<T> serviceClass) {
        ClassLoader loader = serviceClass.getClassLoader();
        Class<?>[] interfaces = {serviceClass};
        Object instance = Proxy.newProxyInstance(loader, interfaces, (proxy, method, args) -> {
            int sequenceId = SequenceIdGenerator.nextId();
            RpcRequestMessage message = new RpcRequestMessage(
                    sequenceId,
                    serviceClass.getName(),
                    method.getName(),
                    method.getReturnType(),
                    method.getParameterTypes(),
                    args
            );
            Channel ch = getChannel();
            DefaultPromise<Object> promise = new DefaultPromise<>(ch.eventLoop());
            // Register BEFORE writing: the response may arrive before writeAndFlush returns, and
            // RpcResponseMessageHandler drops responses whose promise is not yet registered.
            RpcResponseMessageHandler.PROMISES.put(sequenceId, promise);
            ch.writeAndFlush(message).addListener(future -> {
                if (!future.isSuccess()) {
                    // the request never left, so no response will ever complete this promise
                    RpcResponseMessageHandler.PROMISES.remove(sequenceId);
                    promise.tryFailure(future.cause());
                }
            });
            promise.await();
            if (promise.isSuccess()) {
                return promise.getNow();
            }
            throw new RuntimeException(promise.cause());
        });
        return serviceClass.cast(instance);
    }

    /** Lazily opened connection shared by all proxies; volatile for double-checked locking. */
    private static volatile Channel channel = null;
    private static final Object LOCK = new Object();

    /**
     * Returns the shared connection, opening it on first use.
     *
     * @throws IllegalStateException if the connection attempt is interrupted
     */
    public static Channel getChannel() {
        Channel current = channel;
        if (current != null) {
            return current;
        }
        synchronized (LOCK) {
            if (channel == null) {
                initChannel();
            }
            return channel;
        }
    }

    /**
     * Connects to localhost:8080 and publishes the connection in {@link #channel}. The event
     * loop group is shut down when the connection closes, or right away if connecting fails.
     */
    private static void initChannel() {
        NioEventLoopGroup group = new NioEventLoopGroup();
        LoggingHandler loggingHandler = new LoggingHandler(LogLevel.DEBUG);
        MessageCodecSharable messageCodec = new MessageCodecSharable();
        RpcResponseMessageHandler rpcHandler = new RpcResponseMessageHandler();
        Bootstrap bootstrap = new Bootstrap()
                .group(group)
                .channel(NioSocketChannel.class)
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) throws Exception {
                        ch.pipeline()
                                .addLast(new ProtocolFrameDecoder())
                                .addLast(loggingHandler)
                                .addLast(messageCodec)
                                .addLast(rpcHandler);
                    }
                });
        boolean connected = false;
        try {
            // sync() rethrows a connect failure (e.g. connection refused) to the caller
            Channel opened = bootstrap.connect("localhost", 8080).sync().channel();
            opened.closeFuture().addListener(future -> group.shutdownGracefully());
            channel = opened;
            connected = true;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while connecting to the RPC server", e);
        } finally {
            if (!connected) {
                // no close listener was attached, so release the group here
                group.shutdownGracefully();
            }
        }
    }
}
package cn.itcast.client.handler;
import cn.itcast.message.RpcResponseMessage;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.util.concurrent.Promise;
import lombok.extern.slf4j.Slf4j;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Completes the caller's pending promise when an RpcResponseMessage arrives.
 *
 * <p>Callers register a promise in {@link #PROMISES} under the request's sequence id. This
 * handler removes that promise and fulfils it with the return value or the remote exception.
 */
@Slf4j
@ChannelHandler.Sharable
public class RpcResponseMessageHandler extends SimpleChannelInboundHandler<RpcResponseMessage> {

    /** Pending calls keyed by sequence id; an entry is removed when its response arrives. */
    public static final Map<Integer, Promise<Object>> PROMISES = new ConcurrentHashMap<>();

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcResponseMessage msg) throws Exception {
        log.debug("{}", msg);
        int sequenceId = msg.getSequenceId();
        Promise<Object> promise = PROMISES.remove(sequenceId);
        if (promise == null) {
            // previously dropped silently; log it so mismatched ids/timing bugs are visible
            log.warn("no pending request for sequenceId {}, response dropped", sequenceId);
            return;
        }
        Exception exceptionValue = msg.getExceptionValue();
        // try* instead of set*: never throw if the promise was already completed elsewhere
        if (exceptionValue != null) {
            promise.tryFailure(exceptionValue);
        } else {
            promise.trySuccess(msg.getReturnValue());
        }
    }
}
消息
package cn.itcast.message;
import lombok.Data;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
/**
 * Base class of every protocol message.
 *
 * <p>Subclasses report their wire type via {@link #getMessageType()}. The codec maps a decoded
 * type byte back to the concrete class with {@link #getMessageClass(int)}.
 */
@Data
public abstract class Message implements Serializable {

    /** Type of an RPC request message (written to the wire as one byte). */
    public static final int RPC_MESSAGE_TYPE_REQUEST = 101;

    /** Type of an RPC response message (written to the wire as one byte). */
    public static final int RPC_MESSAGE_TYPE_RESPONSE = 102;

    /** Message type -> concrete message class. */
    private static final Map<Integer, Class<? extends Message>> messageClasses = new HashMap<>();

    static {
        messageClasses.put(RPC_MESSAGE_TYPE_REQUEST, RpcRequestMessage.class);
        messageClasses.put(RPC_MESSAGE_TYPE_RESPONSE, RpcResponseMessage.class);
    }

    /** Correlates a response with its request. */
    private int sequenceId;

    // NOTE(review): appears unused — getMessageType() is abstract and overridden by subclasses;
    // confirm before removing.
    private int messageType;

    /**
     * Looks up the message class registered for {@code messageType}.
     *
     * @param messageType type code as written on the wire
     * @return the concrete message class, or {@code null} if the type is unknown
     */
    public static Class<? extends Message> getMessageClass(int messageType) {
        return messageClasses.get(messageType);
    }

    /** Returns this message's type code (one of the RPC_MESSAGE_TYPE_* constants). */
    public abstract int getMessageType();
}
package cn.itcast.message;
import lombok.Getter;
import lombok.ToString;
/**
 * RPC request: names the interface method to invoke on the server and carries its arguments.
 *
 * @author yihang
 */
@Getter
@ToString(callSuper = true)
public class RpcRequestMessage extends Message {

    /** Fully qualified name of the invoked interface; the server uses it to find the implementation. */
    private String interfaceName;

    /** Name of the invoked method. */
    private String methodName;

    /** Declared return type of the method. */
    private Class<?> returnType;

    /** Declared parameter types; together with the name they identify the method (overloads). */
    private Class<?>[] parameterTypes;

    /** Argument values; may be null for no-argument methods (java.lang.reflect.Proxy passes null). */
    private Object[] parameterValue;

    /**
     * @param sequenceId     id used to match the response to this request
     * @param interfaceName  fully qualified name of the service interface
     * @param methodName     name of the method to invoke
     * @param returnType     declared return type of the method
     * @param parameterTypes declared parameter types of the method
     * @param parameterValue argument values
     */
    public RpcRequestMessage(int sequenceId, String interfaceName, String methodName, Class<?> returnType,
                             Class<?>[] parameterTypes, Object[] parameterValue) {
        super.setSequenceId(sequenceId);
        this.interfaceName = interfaceName;
        this.methodName = methodName;
        this.returnType = returnType;
        this.parameterTypes = parameterTypes;
        this.parameterValue = parameterValue;
    }

    @Override
    public int getMessageType() {
        return RPC_MESSAGE_TYPE_REQUEST;
    }
}
package cn.itcast.message;
import lombok.Data;
import lombok.ToString;
/**
 * RPC response sent back to the client. It carries either the invoked method's return value or
 * the exception describing why the call failed.
 *
 * @author yihang
 */
@Data
@ToString(callSuper = true)
public class RpcResponseMessage extends Message {

    /** Value returned by the remote method; null on failure or for void methods. */
    private Object returnValue;

    /** Failure reported by the server; null when the call succeeded. */
    private Exception exceptionValue;

    /** Identifies this message on the wire as an RPC response. */
    @Override
    public int getMessageType() {
        return Message.RPC_MESSAGE_TYPE_RESPONSE;
    }
}