基于netty实现的简单RPC调用

1、RPC的架构

在这里插入图片描述
从上图可看出RPC主要分为三个部分:
(1)服务提供者(RPC Server),运行在服务器端,提供服务接口定义与服务实现类。

(2)服务中心(Registry),运行在服务器端,负责将本地服务发布成远程服务,管理远程服务,提供给服务消费者使用。

(3)服务消费者(RPC Client),运行在客户端,通过远程代理对象调用远程服务。

2、基于netty实现RPC的思路

根据上面的rpc结构图,如果我们想实现一个rpc调用框架,则必须实现上述三部分。

2.1、服务提供者:

1、定义客户端调用的API(服务接口)。
2、自定义一套协议,即客户端和服务端交互的数据。
3、编写服务提供方,即API的具体实现。

2.2 服务中心

基于netty编写服务端程序作为注册中心,功能包括:
1、自定义协议的编解码。
2、处理客户端请求的handler,该handler的主要功能包括:
2.1、扫描对外服务实现类所在的包,将服务的提供方对象实例保存到容器中。
2.2、解析客户端传来的协议,根据协议类容调用容器中服务提供者的方法,并将结果返回。

2.3服务消费者

基于netty编写服务端程序,功能包括:
1、将自定义协议发送给服务中心,通过自定义handler接受服务端的响应。

3、项目包结构

在这里插入图片描述

3、相关代码如下

3.1服务提供者

package com.syx.rpc.api;
public interface IRpcHelloService {
    String hello(String name);
}

package com.syx.rpc.api;
public interface IRpcService {

    int add(int a,int b);

    int sub(int a,int b);

    int mult(int a,int b);
}
package com.syx.rpc.protocol;
import lombok.Data;
import java.io.Serializable;

/**
 * 客户端服务端交互的协议
 */
@Data
public class InvokerProtocol implements Serializable {

    //类名
    private String className;

    //方法名
    private String methodName;

    //参数类型
    private Class<?>[] paramTypes;

    //参数列表
    private Object[] params;
}
package com.syx.rpc.provider;

import com.syx.rpc.api.IRpcHelloService;

public class RpcHelloServiceImpl implements IRpcHelloService {
    @Override
    public String hello(String name) {
        return "hello "+name;
    }
}

package com.syx.rpc.provider;

import com.syx.rpc.api.IRpcService;

public class RpcServiceImpl implements IRpcService {
    @Override
    public int add(int a, int b) {
        return a+b;
    }

    @Override
    public int sub(int a, int b) {
        return a-b;
    }

    @Override
    public int mult(int a, int b) {
        return a*b;
    }
}

3.2服务中心

package com.syx.rpc.register;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.serialization.ClassResolver;
import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;
/**
 * Rpc注册中心相当于一个服务器
 */
public class RpcRegister {

    private int port;

    public RpcRegister(int port) {
        this.port = port;
    }


    //启动netty服务器
    public void start() {
        NioEventLoopGroup boss = new NioEventLoopGroup();

        NioEventLoopGroup work = new NioEventLoopGroup();

        try {
            ServerBootstrap serverBootstrap = new ServerBootstrap();


            ServerBootstrap bootstrap = serverBootstrap.group(boss, work)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) throws Exception {

                            ChannelPipeline pipeline = ch.pipeline();


                            //自定义协议解码器
                            pipeline.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0,
                                    4, 0, 4));

                            //自定义协议编码器
                            pipeline.addLast(new LengthFieldPrepender(4));

                            //对象参数类型解码器
                            pipeline.addLast("decoder", new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(null)));


                            //对象参数类型编码器
                            pipeline.addLast("encoder", new ObjectEncoder());

                            pipeline.addLast(new RegisterHandler("com.syx.rpc.provider"));


                        }
                    }).option(ChannelOption.SO_BACKLOG, 128)
                    .childOption(ChannelOption.SO_KEEPALIVE, true);

            ChannelFuture future = bootstrap.bind(port).sync();

            System.out.println("RPC Register Listening in port " + port);
            future.channel().closeFuture().sync();

        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            boss.shutdownGracefully();
            work.shutdownGracefully();
        }
    }

    public static void main(String[] args) {
        new RpcRegister(9999).start();
    }
}

package com.syx.rpc.register;

import com.syx.rpc.protocol.InvokerProtocol;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;

import java.io.File;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.Stack;
import java.util.concurrent.ConcurrentHashMap;

/**
 * 1、扫描对外服务实现类所在的包,将服务的提供方对象实例保存到容器中
 * 2、解析客户端传来的协议,根据协议类容调用容器中服务提供者的方法,并将结果返回
 */
public class RegisterHandler extends ChannelInboundHandlerAdapter {

    private static ConcurrentHashMap<String,Object> registerMap = new ConcurrentHashMap<>();

    public RegisterHandler(String providerPackage) {

        List<String> classNames = scanClass(providerPackage);

        doRegister(classNames);

    }

    /**
     * 扫描指定包下的类
     * @param packageName
     */
    private  static List<String> scanClass(String packageName){
        List<String> className = new ArrayList<>();
        URL url = Thread.currentThread().getContextClassLoader().getResource(packageName.replaceAll("\\.", "/"));
        File  file = new File(url.getFile());
        if(!file.exists()){
            return className;
        }
        Stack<File> stack = new Stack<>();
        stack.push(file);
        while (!stack.isEmpty()){
            File pop = stack.pop();
            for (File f : pop.listFiles()) {
                if(f.isDirectory()){
                    stack.push(f);
                }else {
                    String prefix = packageName.split("\\.")[0];
                    String suffix = ".class";
                    String path = f.getPath();
                    String s = path.substring(path.indexOf(prefix), path.indexOf(suffix)).replaceAll("\\\\", "\\.");
                    className.add(s);
                }
            }
        }
        return className;
    }


    /**
     * 注册实现类
     * @param className
     */
    private static void doRegister(List<String> className){
        if(className.size()==0)return;
        for (String s : className) {
            try{
                Class<?> clazz = Class.forName(s);
                Class<?>[]  c = clazz.getInterfaces();
                if(c.length>0){
                    registerMap.put(c[0].getName(),clazz.newInstance());
                }
            }catch (Exception e){
                e.printStackTrace();
            }

        }
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {

        //读取请求的信息
        InvokerProtocol request = (InvokerProtocol) msg;

        Object result = new Object();
        if(registerMap.containsKey(request.getClassName())){

            //本地协议中约定的接口示例对象
            Object instance = registerMap.get(request.getClassName());

            Method method = instance.getClass().getDeclaredMethod(request.getMethodName(), request.getParamTypes());

            //方法执行的返回值
             result = method.invoke(instance, request.getParams());

        }

        //将返回值写回客户端channel
        ctx.write(result);
        ctx.flush();
        ctx.close();
    }

    public static void main(String[] args) {
        List<String> classNames = scanClass("com.syx.rpc.provider");
        classNames.forEach(System.out::println);
        System.out.println("***********");
        doRegister(classNames);
        registerMap.forEach((k,v)->{
            System.out.println(k+"  "+v);
        });

    }
}

3.3消费者

package com.syx.rpc.consumer;

import com.syx.rpc.protocol.InvokerProtocol;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

public class RpcProxy {

    public static <T> T create(Class<?> clazz){
        MethodProxy methodProxy = new MethodProxy(clazz);

        Class<?> [] interfaces = clazz.isInterface()?new Class[]{clazz}:clazz.getInterfaces();

        T o = (T) Proxy.newProxyInstance(clazz.getClassLoader(), interfaces, methodProxy);

        return o;
    }


    private static class MethodProxy implements InvocationHandler{

        private Class<?> clazz;

        private MethodProxy(Class<?> clazz) {
            this.clazz = clazz;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            if(Object.class.equals(method.getDeclaringClass())){
                return method.invoke(this,args);
            }else {
                return rpcInvoker(proxy,method,args);
            }
        }

        private Object rpcInvoker(Object proxy, Method method, Object[] args) {

            InvokerProtocol request = new InvokerProtocol();

            request.setMethodName(method.getName());

            request.setClassName(clazz.getName());

            request.setParamTypes(method.getParameterTypes());

            request.setParams(args);

            NioEventLoopGroup eventLoopGroup = new NioEventLoopGroup();
            RpcProxyHandler rpcProxyHandler = new RpcProxyHandler();
            try {

                Bootstrap bootstrap = new Bootstrap();

                bootstrap.group(eventLoopGroup).channel(NioSocketChannel.class)
                        .option(ChannelOption.TCP_NODELAY,true)
                        .handler(new ChannelInitializer<SocketChannel>() {
                            @Override
                            protected void initChannel(SocketChannel ch) throws Exception {

                                ChannelPipeline pipeline = ch.pipeline();

                                //自定义协议解码器
                                pipeline.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0,
                                        4, 0, 4));

                                //自定义协议编码器
                                pipeline.addLast(new LengthFieldPrepender(4));


                                //对象参数类型解码器
                                pipeline.addLast("decoder", new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(null)));


                                //对象参数类型编码器
                                pipeline.addLast("encoder", new ObjectEncoder());

                                pipeline.addLast(rpcProxyHandler);
                            }
                        });

                ChannelFuture sync = bootstrap.connect("127.0.0.1", 9999).sync();

                sync.channel().writeAndFlush(request).sync();
                sync.channel().closeFuture().sync();
            }catch (Exception e){
                e.printStackTrace();
            }finally {
                eventLoopGroup.shutdownGracefully();
            }
            return rpcProxyHandler.getResponse();
        }
    }
}
package com.syx.rpc.consumer;

import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.concurrent.EventExecutorGroup;

public class RpcProxyHandler extends ChannelInboundHandlerAdapter {

    private Object response;

    public Object getResponse() {
        return response;
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        setResponse(msg);
    }

    public void setResponse(Object response) {
        this.response = response;
    }
}
package com.syx.rpc.consumer;

import com.syx.rpc.api.IRpcHelloService;
import com.syx.rpc.api.IRpcService;
import com.syx.rpc.provider.RpcServiceImpl;

public class RpcConsumer {

    public static void main(String[] args) {
        IRpcHelloService helloService = RpcProxy.create(IRpcHelloService.class);


        String whh = helloService.hello("whh");

        System.out.println(whh);

        System.out.println("rpc调用");
        long l1 = System.currentTimeMillis();
        IRpcService rpcService = RpcProxy.create(IRpcService.class);
        System.out.println(rpcService.add(1, 2));
        long l2 = System.currentTimeMillis();
        System.out.println( l2- l1);

        System.out.println("本地调用");
        long l3 = System.currentTimeMillis();
        RpcServiceImpl rpcService1 = new RpcServiceImpl();
        System.out.println(rpcService1.add(1, 2));
        System.out.println(System.currentTimeMillis() - l3);
    }
}

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值