一、Netty实现简易RPC
项目结构:
api包:定义需要暴露的服务接口
package com.demo.netty.rpc.api;
/**
 * Greeting service exposed to remote callers through the RPC registry.
 */
public interface RpcHelloService {

    /**
     * Builds a greeting for the given name.
     *
     * @param name the name to greet
     * @return the greeting text
     */
    String sayHello(String name);
}
package com.demo.netty.rpc.api;
/**
 * Integer arithmetic service exposed to remote callers through the RPC registry.
 */
public interface RpcCalculationService {

    /**
     * Addition.
     *
     * @param a left operand
     * @param b right operand
     * @return the sum
     */
    int add(int a, int b);

    /**
     * Subtraction.
     *
     * @param a minuend
     * @param b subtrahend
     * @return the difference
     */
    int sub(int a, int b);

    /**
     * Multiplication.
     *
     * @param a left operand
     * @param b right operand
     * @return the product
     */
    int mult(int a, int b);

    /**
     * Division.
     *
     * @param a dividend
     * @param b divisor
     * @return the quotient
     */
    int div(int a, int b);
}
provider包:定义需要暴露的服务的具体实现
package com.demo.netty.rpc.provider;
import com.demo.netty.rpc.api.RpcHelloService;
/**
 * Provider-side implementation of {@link RpcHelloService}.
 */
public class RpcHelloServiceImpl implements RpcHelloService {

    /**
     * Returns {@code "Hello " + name + "!"}, e.g. {@code "Hello zhangsan!"}.
     * A {@code null} name is rendered as the text {@code "null"}.
     */
    @Override
    public String sayHello(String name) {
        return new StringBuilder("Hello ")
                .append(name)
                .append("!")
                .toString();
    }
}
package com.demo.netty.rpc.provider;
import com.demo.netty.rpc.api.RpcCalculationService;
/**
 * Provider-side implementation of {@link RpcCalculationService} using plain
 * {@code int} arithmetic (results wrap on overflow, as Java int operators do).
 */
public class RpcCalculationServiceImpl implements RpcCalculationService {

    /** Returns {@code a + b}. */
    @Override
    public int add(int a, int b) {
        return Integer.sum(a, b);
    }

    /** Returns {@code a - b}. */
    @Override
    public int sub(int a, int b) {
        final int difference = a - b;
        return difference;
    }

    /** Returns {@code a * b}. */
    @Override
    public int mult(int a, int b) {
        final int product = a * b;
        return product;
    }

    /**
     * Returns {@code a / b}, truncated toward zero.
     *
     * @throws ArithmeticException if {@code b} is zero
     */
    @Override
    public int div(int a, int b) {
        final int quotient = a / b;
        return quotient;
    }
}
protocol包:定义自定义协议
package com.demo.netty.rpc.protocol;
import java.io.Serializable;
/**
 * Request message describing a single remote call: which service interface, which method,
 * the method's parameter types and the argument values.
 *
 * <p>The class is Java-serializable so the object codec in the Netty pipelines can write it
 * on the consumer side and read it back in the registry.
 */
public class InvokerProtocol implements Serializable {

    /**
     * Explicit stream version id. The consumer and the registry may be compiled independently.
     * Without an explicit id the JVM derives one from the class structure, and any
     * compiler-dependent difference makes decoding fail with {@link java.io.InvalidClassException}.
     */
    private static final long serialVersionUID = 1L;

    /** Fully-qualified name of the service interface being invoked. */
    private String className;
    /** Name of the method to invoke on the service. */
    private String methodName;
    /** Declared parameter types of the method, used to select the right overload. */
    private Class<?>[] parames;
    /** Argument values, in the same order as {@link #parames}. */
    private Object[] values;

    /** @return the fully-qualified service interface name */
    public String getClassName() {
        return className;
    }

    /** @param className the fully-qualified service interface name */
    public void setClassName(String className) {
        this.className = className;
    }

    /** @return the name of the method to invoke */
    public String getMethodName() {
        return methodName;
    }

    /** @param methodName the name of the method to invoke */
    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }

    /** @return the declared parameter types (the array is returned as-is, not copied) */
    public Class<?>[] getParames() {
        return parames;
    }

    /** @param parames the declared parameter types (stored as-is, not copied) */
    public void setParames(Class<?>[] parames) {
        this.parames = parames;
    }

    /** @return the argument values (the array is returned as-is, not copied) */
    public Object[] getValues() {
        return values;
    }

    /** @param values the argument values (stored as-is, not copied) */
    public void setValues(Object[] values) {
        this.values = values;
    }
}
registry包:注册服务并暴露出来
RpcRegistry:服务端启动类,基于Netty监听端口,接收请求并对外暴露已注册的服务。
package com.demo.netty.rpc.registry;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;
/**
 * Netty server that accepts RPC requests and hands them to {@link RegistryHandler},
 * which dispatches them to the registered provider implementations.
 */
public class RpcRegistry {

    /** TCP port the server listens on. */
    private int port;

    /**
     * @param port TCP port to listen on
     */
    public RpcRegistry(int port) {
        this.port = port;
    }

    /**
     * Binds the server to {@link #port} and blocks until the server channel is closed.
     *
     * <p>Both event loop groups are shut down on every exit path: normal close, startup
     * failure, or interruption.
     */
    public void start() {
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) throws Exception {
                            ChannelPipeline pipeline = ch.pipeline();
                            // Inbound framing. LengthFieldBasedFrameDecoder arguments:
                            //   maxFrameLength      - larger frames raise TooLongFrameException
                            //   lengthFieldOffset   - position of the length field in the message
                            //   lengthFieldLength   - size of the length field (4 = int, 8 = long)
                            //   lengthAdjustment    - compensation added to the length value
                            //   initialBytesToStrip - leading bytes removed from the decoded frame
                            // Here: a 4-byte length prefix at offset 0, stripped before decoding.
                            pipeline.addLast("frameDecoder",
                                    new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
                            // Outbound framing: prepend a matching 4-byte length field.
                            pipeline.addLast("frameEncoder", new LengthFieldPrepender(4));
                            // Java-serialization codec for request/response objects.
                            // NOTE(review): this deserializes arbitrary classes from the network;
                            // only expose this port to trusted clients.
                            pipeline.addLast("encoder", new ObjectEncoder());
                            pipeline.addLast("decoder",
                                    new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(null)));
                            // NOTE(review): a new RegistryHandler per connection re-scans the
                            // provider package every time; consider sharing the registry.
                            pipeline.addLast(new RegistryHandler());
                        }
                    })
                    .option(ChannelOption.SO_BACKLOG, 128)
                    .childOption(ChannelOption.SO_KEEPALIVE, true);
            ChannelFuture future = b.bind(port).sync();
            System.out.println("RPC Registry start listen at " + port );
            // Block the calling thread until the server channel is closed.
            future.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            // Preserve the interrupt status for the caller.
            Thread.currentThread().interrupt();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }

    /** Starts a registry on port 8080. */
    public static void main(String[] args) throws Exception {
        new RpcRegistry(8080).start();
    }
}
RegistryHandler:实际处理通信数据时的类。
package com.demo.netty.rpc.registry;
import com.demo.netty.rpc.protocol.InvokerProtocol;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import java.io.File;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
public class RegistryHandler extends ChannelInboundHandlerAdapter {
//用保存所有可用的服务
public static ConcurrentHashMap<String, Object> registryMap = new ConcurrentHashMap<String,Object>();
//保存所有相关的服务类
private List<String> classNames = new ArrayList<String>();
public RegistryHandler(){
//完成递归扫描
scannerClass("com.demo.netty.rpc.provider");
doRegister();
}
//拿到通信数据后的处理方法
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
Object result = new Object();
InvokerProtocol request = (InvokerProtocol)msg;
//当客户端建立连接时,需要从自定义协议中获取信息,拿到具体的服务和实参
//使用反射调用
if(registryMap.containsKey(request.getClassName())){
Object clazz = registryMap.get(request.getClassName());
Method method = clazz.getClass().getMethod(request.getMethodName(), request.getParames());
result = method.invoke(clazz, request.getValues());
}
ctx.write(result);
ctx.flush();
ctx.close();
}
//通信时出异常时的处理方法
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
cause.printStackTrace();
ctx.close();
}
/*
* 递归扫描
*/
private void scannerClass(String packageName){
URL url = this.getClass().getClassLoader().getResource(packageName.replaceAll("\\.", "/"));
File dir = new File(url.getFile());
for (File file : dir.listFiles()) {
//如果是一个文件夹,继续递归
if(file.isDirectory()){
scannerClass(packageName + "." + file.getName());
}else{
classNames.add(packageName + "." + file.getName().replace(".class", "").trim());
}
}
}
/**
* 完成注册
*/
private void doRegister(){
if(classNames.size() == 0){ return; }
for (String className : classNames) {
try {
Class<?> clazz = Class.forName(className);
Class<?> i = clazz.getInterfaces()[0];
registryMap.put(i.getName(), clazz.newInstance());
} catch (Exception e) {
e.printStackTrace();
}
}
}
}
consumer包:服务消费端(客户端),通过动态代理把接口调用转换为远程调用。
RpcProxy:动态代理类
package com.demo.netty.rpc.consumer;
import com.demo.netty.rpc.protocol.InvokerProtocol;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;
import sun.security.jca.GetInstance;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
/**
 * Factory for client-side dynamic proxies. Calls on a proxy for a service interface are sent to
 * the RPC registry as {@link InvokerProtocol} requests. Calls on a proxy for a concrete class are
 * executed locally on a fresh instance of that class.
 */
public class RpcProxy {

    /** Host of the RPC registry. */
    private static final String HOST = "localhost";
    /** Port of the RPC registry; must match the port the registry was started on. */
    private static final int PORT = 8080;

    /**
     * Creates a proxy for the given type.
     *
     * @param clazz normally the service interface itself; a concrete class is also accepted,
     *              in which case the proxy implements that class's interfaces
     * @param <T>   the proxy type the caller expects (must be implemented by the proxy)
     * @return the proxy instance
     */
    public static <T> T create(Class<?> clazz) {
        MethodProxy proxy = new MethodProxy(clazz);
        Class<?>[] interfaces = clazz.isInterface() ? new Class<?>[]{clazz} : clazz.getInterfaces();
        @SuppressWarnings("unchecked") // the caller chooses T to match clazz
        T result = (T) Proxy.newProxyInstance(clazz.getClassLoader(), interfaces, proxy);
        return result;
    }

    private static class MethodProxy implements InvocationHandler {

        /** The type passed to {@link RpcProxy#create(Class)}. */
        private final Class<?> clazz;

        public MethodProxy(Class<?> clazz) {
            this.clazz = clazz;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            // equals/hashCode/toString are answered locally, so logging, debugging or using
            // the proxy as a map key never triggers a remote call.
            if (method.getDeclaringClass() == Object.class) {
                return invokeObjectMethod(proxy, method, args);
            }
            if (!clazz.isInterface()) {
                // A concrete class was passed in: run the method directly on a new instance.
                try {
                    return method.invoke(clazz.getDeclaredConstructor().newInstance(), args);
                } catch (Throwable t) {
                    t.printStackTrace();
                }
                return null;
            }
            // An interface was passed in: this is the core remote-invocation path.
            return rpcInvoke(proxy, method, args);
        }

        /** Local implementations of the java.lang.Object methods routed through a proxy. */
        private Object invokeObjectMethod(Object proxy, Method method, Object[] args) {
            String name = method.getName();
            if ("equals".equals(name)) {
                return proxy == args[0];
            }
            if ("hashCode".equals(name)) {
                return System.identityHashCode(proxy);
            }
            if ("toString".equals(name)) {
                return "RpcProxy[" + clazz.getName() + "]@" + Integer.toHexString(System.identityHashCode(proxy));
            }
            throw new UnsupportedOperationException(name);
        }

        /**
         * Sends one remote call to the registry and waits for the reply.
         *
         * @param proxy  the proxy instance (unused)
         * @param method the interface method being called
         * @param args   the call arguments (may be {@code null} for no-arg methods)
         * @return the response received, or {@code null} if the call failed or no response arrived
         */
        public Object rpcInvoke(Object proxy, Method method, Object[] args) {
            // Build the protocol message.
            InvokerProtocol msg = new InvokerProtocol();
            msg.setClassName(this.clazz.getName());
            msg.setMethodName(method.getName());
            msg.setValues(args);
            msg.setParames(method.getParameterTypes());

            final RpcProxyHandler consumerHandler = new RpcProxyHandler();
            // NOTE(review): a new event loop group (and connection) per call is costly; consider reuse.
            EventLoopGroup group = new NioEventLoopGroup();
            try {
                Bootstrap b = new Bootstrap();
                b.group(group)
                        .channel(NioSocketChannel.class)
                        .option(ChannelOption.TCP_NODELAY, true)
                        .handler(new ChannelInitializer<SocketChannel>() {
                            @Override
                            public void initChannel(SocketChannel ch) throws Exception {
                                ChannelPipeline pipeline = ch.pipeline();
                                // Inbound framing: a 4-byte length prefix at offset 0, stripped before
                                // decoding (maxFrameLength, lengthFieldOffset, lengthFieldLength,
                                // lengthAdjustment, initialBytesToStrip).
                                pipeline.addLast("frameDecoder",
                                        new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
                                // Outbound framing: prepend a matching 4-byte length field.
                                pipeline.addLast("frameEncoder", new LengthFieldPrepender(4));
                                // Java-serialization codec for request/response objects.
                                pipeline.addLast("encoder", new ObjectEncoder());
                                pipeline.addLast("decoder",
                                        new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(null)));
                                pipeline.addLast("handler", consumerHandler);
                            }
                        });
                ChannelFuture future = b.connect(HOST, PORT).sync();
                future.channel().writeAndFlush(msg).sync();
                // The registry closes the connection after replying, so waiting for the close
                // ensures the handler has seen the response (if any).
                future.channel().closeFuture().sync();
            } catch (InterruptedException e) {
                // Preserve the interrupt status for the caller.
                Thread.currentThread().interrupt();
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                group.shutdownGracefully();
            }
            return consumerHandler.getResponse();
        }
    }
}
RpcProxyHandler:Netty进行通信时的处理类
package com.demo.netty.rpc.consumer;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
/**
 * Client-side handler that stores the message received from the registry so the caller can
 * read it once the connection is finished.
 */
public class RpcProxyHandler extends ChannelInboundHandlerAdapter {

    /**
     * Last message received. It is written on the Netty event loop thread and read by the
     * thread that issued the call after the connection closes, hence {@code volatile}.
     */
    private volatile Object response;

    /** @return the received response, or {@code null} if none arrived */
    public Object getResponse() {
        return response;
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        response = msg;
    }

    /**
     * Logs the failure and closes the channel. Closing releases anyone waiting on the
     * channel's close future instead of leaving the connection open.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        System.out.println("client exception is general");
        cause.printStackTrace();
        ctx.close();
    }
}
RpcConsumer:客户端调用服务端服务
package com.demo.netty.rpc.consumer;
import com.demo.netty.rpc.api.RpcCalculationService;
import com.demo.netty.rpc.api.RpcHelloService;
import com.demo.netty.rpc.provider.RpcHelloServiceImpl;
/**
 * Demo client: calls the greeting and calculation services through RPC proxies and prints
 * each result.
 */
public class RpcConsumer {

    public static void main(String [] args) {
        // Greeting service.
        RpcHelloService helloService = RpcProxy.create(RpcHelloService.class);
        String greeting = helloService.sayHello("zhangsan");
        System.out.println(greeting);

        // Calculation service: each call is a separate remote round trip.
        RpcCalculationService calculator = RpcProxy.create(RpcCalculationService.class);
        int sum = calculator.add(8, 2);
        System.out.println("8 + 2 = " + sum);
        int difference = calculator.sub(8, 2);
        System.out.println("8 - 2 = " + difference);
        int product = calculator.mult(8, 2);
        System.out.println("8 * 2 = " + product);
        int quotient = calculator.div(8, 2);
        System.out.println("8 / 2 = " + quotient);
    }
}