【RPC】自己动手实现RPC框架

参考视频:https://www.imooc.com/learn/1158

1 什么是RPC

RPC (Remote Procedure Call)
主机A上的进程a想调用主机B上的进程b,就叫远程过程调用。由于不在同一个内存空间,不能直接调用,需要通过网络来表达调用的语义和传达调用的数据。

在研究RPC之前,我们先看看本地过程调用是什么样的

作者:洪春涛
链接:https://www.zhihu.com/question/25536695/answer/221638079

1.1 本地过程调用

本地过程调用RPC就是要像调用本地的函数一样去调远程函数。在研究RPC前,我们先看看本地调用是怎么调的。假设我们要调用函数Multiply来计算lvalue * rvalue的结果:

int Multiply(int l, int r) {
   int y = l * r;
   return y;
}

int lvalue = 10;
int rvalue = 20;
int l_times_r = Multiply(lvalue, rvalue);

那么在第8行时,我们实际上执行了以下操作:

  1. 将 lvalue 和 rvalue 的值压栈
  2. 进入Multiply函数,取出栈中的值10 和 20,将其赋予 l 和 r
  3. 执行第2行代码,计算 l * r ,并将结果存在 y
  4. 将 y 的值压栈,然后从Multiply返回
  5. 第8行,从栈中取出返回值 200 ,并赋值给 l_times_r

以上5步就是执行本地调用的过程。(20190116注:以上步骤只是为了说明原理。事实上编译器经常会做优化,对于参数和返回值少的情况会直接将其存放在寄存器,而不需要压栈弹栈的过程,甚至都不需要调用call,而直接做inline操作。仅就原理来说,这5步是没有问题的。)

1.2 远程过程调用

远程过程调用带来的新问题是,在远程调用时,我们需要执行的函数体是在远程的机器上的,也就是说,Multiply是在另一个进程中执行的。这就带来了几个新问题:

  1. Call ID映射。我们怎么告诉远程机器我们要调用Multiply,而不是Add或者FooBar呢?在本地调用中,函数体是直接通过函数指针来指定的,我们调用Multiply,编译器就自动帮我们调用它相应的函数指针。但是在远程调用中,函数指针是不行的,因为两个进程的地址空间是完全不一样的。所以,在RPC中,所有的函数都必须有自己的一个ID。这个ID在所有进程中都是唯一确定的。客户端在做远程过程调用时,必须附上这个ID。然后我们还需要在客户端和服务端分别维护一个 {函数 <–> Call ID} 的对应表。两者的表不一定需要完全相同,但相同的函数对应的Call ID必须相同。当客户端需要进行远程调用时,它就查一下这个表,找出相应的Call ID,然后把它传给服务端,服务端也通过查表,来确定客户端需要调用的函数,然后执行相应函数的代码。
  2. 序列化和反序列化。客户端怎么把参数值传给远程的函数呢?在本地调用中,我们只需要把参数压到栈里,然后让函数自己去栈里读就行。但是在远程过程调用时,客户端跟服务端是不同的进程,不能通过内存来传递参数。甚至有时候客户端和服务端使用的都不是同一种语言(比如服务端用C++,客户端用Java或者Python)。这时候就需要客户端把参数先转成一个字节流,传给服务端后,再把字节流转成自己能读取的格式。这个过程叫序列化和反序列化。同理,从服务端返回的值也需要序列化反序列化的过程。
  3. 网络传输。远程调用往往用在网络上,客户端和服务端是通过网络连接的。所有的数据都需要通过网络传输,因此就需要有一个网络传输层。网络传输层需要把Call ID和序列化后的参数字节流传给服务端,然后再把序列化后的调用结果传回客户端。只要能完成这两者的,都可以作为传输层使用。因此,它所使用的协议其实是不限的,能完成传输就行。尽管大部分RPC框架都使用TCP协议,但其实UDP也可以,而gRPC干脆就用了HTTP2。Java的Netty也属于这层的东西。有了这三个机制,就能实现RPC了,具体过程如下:
// Client端 
//    int l_times_r = Call(ServerAddr, Multiply, lvalue, rvalue)
将这个调用映射为Call ID。这里假设用最简单的字符串当Call ID的方法
将Call ID,lvalue和rvalue序列化。可以直接将它们的值以二进制形式打包
把2中得到的数据包发送给ServerAddr,这需要使用网络传输层
等待服务器返回结果
如果服务器调用成功,那么就将结果反序列化,并赋给l_times_r

// Server端
在本地维护一个Call ID到函数指针的映射call_id_map,可以用std::map<std::string, std::function<>>
等待请求
得到一个请求后,将其数据包反序列化,得到Call ID
通过在call_id_map中查找,得到相应的函数指针
将lvalue和rvalue反序列化后,在本地调用Multiply函数,得到结果
将结果序列化后通过网络返回给Client

所以要实现一个RPC框架,其实只需要按以上流程实现就基本完成了。

2 理论基础

2.1 跨进程数据交换

跨进程交互形式:RESTful、WebService、HTTP、基于DB做数据交换、基于MQ(Message Queue消息队列)做数据交换,以及RPC。

2.1.1 依赖中间件做数据交换

在这里插入图片描述
系统A放数据到中间件,系统B从中间件中取数据。

2.1.2 直接交互

在这里插入图片描述
这种交互方式,两个系统是同步执行的。服务端的速度会直接影响到客户端,这种情况下对响应速度的要求是非常高的,比用中间件做数据交换的情况要高得多。在中间件交互的情况下,上游系统把数据放在中间件里就继续执行自己的任务了,数据可以较长时间存储在中间件,下游系统想什么取数据就什么时候取数据。而直接交互的情况下,客户端会一直等待服务端返回数据。

名词说明:
在RPC中
Server : Provider、服务提供者
Client : Consumer、服务消费者
Stub:存根、服务描述

这么多直接交互的方式,相比于其他方式,RPC的优点是什么?RPC可以像调用本地方法一样调用远程方法。

2.2 现有RPC框架

在这里插入图片描述

2.3 RPC整体架构

在这里插入图片描述

  1. Server把它可以提供的服务以及地址在注册中心注册
  2. Client订阅注册中心,关注它需要的服务
  3. 如果Server的服务发生改变,Server会再次注册到注册中心,注册中心通知Client服务发生改变
  4. Client已经有了Server的服务信息和地址,就可以向Server发起调用【整个RPC里面最关键的一步】

调用过程描述:
在这里插入图片描述

3 自己动手实现RPC框架

在这里插入图片描述
RPC框架主要由5个主要的模块组成.

  1. 协议模块
    描述Server与Client之间的通信协议。
    Request类:需要请求Server的哪个服务,请求带的参数
    Reponse类:Server响应给Client的信息,如是否成功、返回值等
    ServiceDescriptor类:一个服务的描述信息
  2. 序列化模块
    对象与二进制之间的互转
    Encoder类:把对象编码成二进制数据
    Decoder类:把二进制数据反编码成对象
    互转是基于JSON实现的
  3. 网络传输模块
    基于HTTP实现网络传输
  4. 服务端模块
    ServiceManager类:Server把服务注册到这里
    ServiceInstance类:服务的具体实现类
  5. 客户端模块
    Remotelnvoker类:RpcClient通过这个类和Server交互,交互信息是通过Request和Response封装的
    TransportSelector类:Client连接时可以选择一个client与一个server连接,也可以与多个server连接

3.0 创建工程

3.0.1 写依赖

    <dependencyManagement>
        <dependencies>
<!--            帮助IO开发-->
            <dependency>
                <groupId>commons-io</groupId>
                <artifactId>commons-io</artifactId>
                <version>${commons.version}</version>
            </dependency>
<!--            为servlet提供运行时环境-->
            <dependency>
                <groupId>org.eclipse.jetty</groupId>
                <artifactId>jetty-servlet</artifactId>
                <version>${jetty.version}</version>
            </dependency>
<!--            序列化-->
            <dependency>
                <groupId>com.alibaba</groupId>
                <artifactId>fastjson</artifactId>
                <version>${fastjson.version}</version>
            </dependency>

        </dependencies>

    </dependencyManagement>

    <dependencies>
<!--        单元测试-->
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>${junit.version}</version>
        </dependency>
<!--        注解简化Java代码-->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>${lombok.version}</version>
        </dependency>
<!--        日志门面-->
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
<!--        日志实现-->
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>

    </dependencies>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <java.version>1.8</java.version>
        <maven.version>3.8.1</maven.version>
        <commons.version>2.5</commons.version>
        <jetty.version>9.2.28.v20190418</jetty.version>
        <fastjson.version>1.2.50</fastjson.version>
        <lombok.version>1.18.8</lombok.version>
        <junit.version>4.12</junit.version>
        <slf4j.version>1.7.26</slf4j.version>
        <logback.version>1.2.3</logback.version>
    </properties>

3.0.2 Lombok配置

在这里插入图片描述
在这里插入图片描述

3.1 协议模块(proto)

3.1.1 网络传输端点

/**
 * 表示网络传输的一个端点
 */
@Data//给变量创造get/set/toString方法
@AllArgsConstructor//带所有字段的构造方法
public class Peer {
    private String host;
    private String port;
}

3.1.2 服务

/**
 * 表示服务
 */

@Data
@AllArgsConstructor
@NoArgsConstructor//不带字段的默认构造方法
public class ServiceDescriptor {
    private String clazz;//类名
    private String method;//方法名
    private String returnType;//返回值类型
    private String[] parameterTypes;//参数类型
}

3.1.3 请求

/**
 * 表示RPC的一个请求
 */
@Data
public class Request {
    private ServiceDescriptor service;//要请求什么服务?
    private Object[] parameters;//请求带的参数
}

3.1.4 响应

/**
 * 表示RPC的响应
 */
@Data
public class Response {
    private int code=0;//返回一个code表示成功与否,0-成功;非0-失败
    private String message="ok";//如果错误,返回错误原因
    private Object data;//具体返回的数据
}

3.2 公用模块(common)

3.2.1 写工具类方法(用static修饰)

package com.wx.gkrpc.utils;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.List;

/**
 * 反射工具类
 * 工具类方法一般用static修饰,这样不用实例化也可以调用方法
 */
public class ReflectionUtils {
    /**
     * 根据class创建对象
     * @param clazz 待创建对象的类
     * @param <T> 对象类型
     * @return 创建好的对象
     */
    public static <T> T newInstance(Class<T> clazz){
        try {
            return clazz.newInstance();
        }catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * 获取某个类的公共方法
     * @param clazz 类名
     * @return 该类的公共方法
     */
    public static Method[] getPublicMethods(Class clazz){
        Method[] methods = clazz.getDeclaredMethods();
        List<Method> pmethods = new ArrayList<>();

        for(Method m : methods){
            if(Modifier.isPublic(m.getModifiers())){
                pmethods.add(m);
            }
        }
        return pmethods.toArray(new Method[0]);//将ArrayList对象转化为Method类型的数组
    }

    /**
     * 调用指定对象的指定方法
     * @param obj 被调用方法的对象
     * @param method 被调用的方法
     * @param args 该方法的参数
     * @return 返回结果
     */
    public static Object invoke(Object obj,
                                Method method,
                                Object... args){
        try {
            return method.invoke(obj, args);
        }  catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }
}

3.2.2 测试

在准备被测试的类中任意空白处按:ctrl+shift+t,自动生成测试类

package com.wx.gkrpc.utils;

public class TestClass {
    private String a(){
        return "a";
    }
    public String b(){
        return "b";
    }

    protected String c(){
        return "c";
    }
}

package com.wx.gkrpc.utils;

import org.junit.Test;

import java.lang.reflect.Method;

import static org.junit.Assert.*;

public class ReflectionUtilsTest {

    @Test
    public void newInstance() {
        TestClass t = ReflectionUtils.newInstance(TestClass.class);
        //不空:继续执行;为空:抛出异常
        assertNotNull(t);
    }

    @Test
    public void getPublicMethods() {
        Method[] methods = ReflectionUtils.getPublicMethods(TestClass.class);
        //1. 如果两者一致, 程序继续往下运行. 2. 如果两者不一致, 中断测试方法, 抛出异常信息 AssertionFailedError .
        assertEquals(1, methods.length);

        String mName = methods[0].getName();
        assertEquals("b",mName);
    }

    @Test
    public void invoke() {
        Method[] methods = ReflectionUtils.getPublicMethods(TestClass.class);
        Method b = methods[0];

        TestClass t = new TestClass();
        Object obj = ReflectionUtils.invoke(t,b);
        assertEquals("b",obj);
    }
}

3.3 序列化模块(codec)

Encoder

package com.wx.gkrpc.codec;

/**
 * 序列化,将对象转成二进制
 */
public interface Encoder {
    byte[] encode(Object obj);
}

Decoder

/**
 * 反序列化,将二进制转成对象
 */
public interface Decoder {
    <T> T decode(byte[] bytes,Class<T> clazz);
}

JSONEncoder

/**
 * 基于JSON的序列化实现
 */
public class JSONEncoder implements Encoder{
    @Override
    public byte[] encode(Object obj) {
        return JSON.toJSONBytes(obj);
    }
}

JSONDecoder

/**
 * 基于JSON的反序列化实现
 */
public class JSONDecoder implements Decoder{
    @Override
    public <T> T decode(byte[] bytes, Class<T> clazz) {
        return JSON.parseObject(bytes, clazz);
    }
}

测试

public class JSONEncoderTest {

    @Test
    public void encode() {
        JSONEncoder encoder = new JSONEncoder();

        TestBean bean = new TestBean();
        bean.setName("wx");
        bean.setAge(18);

        byte[] bytes = encoder.encode(bean);

        assertNotNull(bytes);
    }
}
public class JSONDecoderTest {

    @Test
    public void decode() {
        JSONEncoder encoder = new JSONEncoder();

        TestBean bean = new TestBean();
        bean.setName("wx");
        bean.setAge(18);

        byte[] bytes = encoder.encode(bean);

        JSONDecoder decoder = new JSONDecoder();
        TestBean bean2 = decoder.decode(bytes, TestBean.class);
        assertEquals("wx",bean2.getName());
    }
}

3.4 网络通信模块(transport)

Client接口

/**
 * 1.创建连接
 * 2.发送数据并且等待响应
 * 3.关闭连接
 */
public interface TransportClient {
    void connect(Peer peer);

    InputStream write(InputStream data);

    void close();

}

处理请求的handler

/**
 * 处理网络请求的handler
 */
public interface RequestHandler {
    void onRequest(InputStream recive, OutputStream toResp);
}

Service接口

/**
 * 1.启动,监听窗口
 * 2.接收请求
 * 3.关闭监听
 */
public interface TransportServer {
    void init(int port, RequestHandler handler);

    void start();

    void stop() ;
}

基于Http的Client实现类

public class HttpTransportClient implements TransportClient{
    private String url;

    @Override
    public void connect(Peer peer) {
        this.url = "http://" + peer.getHost()+":"+peer.getPort();

    }

    @Override
    public InputStream write(InputStream data) {
        try {
            HttpURLConnection httpConn = (HttpURLConnection) new URL(url).openConnection();
            httpConn.setDoOutput(true);//有输出
            httpConn.setDoInput(true);//读数据
            httpConn.setUseCaches(false);//不用缓存
            httpConn.setRequestMethod("POST");//请求方法为post

            httpConn.connect();//连接

            IOUtils.copy(data, httpConn.getOutputStream());//数据发送给server

            int resultCode = httpConn.getResponseCode();

            if(resultCode == HttpURLConnection.HTTP_OK){
                return httpConn.getInputStream();//获取响应
            }else{
                return httpConn.getErrorStream();
            }
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }

    @Override
    public void close() {

    }
}

基于Http的Server实现类

@Slf4j
public class HttpTransportServer implements TransportServer{
    private RequestHandler handler;
    private Server server;

    @Override
    public void init(int port, RequestHandler handler) {
        this.handler=handler;
        this.server = new Server(port);

        //servlet接收请求
        ServletContextHandler ctx = new ServletContextHandler();
        server.setHandler(ctx);

        //ServletHolder:网络请求抽象
        ServletHolder holder = new ServletHolder(new RequestServlet());
        ctx.addServlet(holder, "/*");
    }

    @Override
    public void start() {
        try {
            server.start();
            //让server一直挂起不要立即返回
            server.join();
        } catch (Exception e) {
//            e.printStackTrace();
            log.error(e.getMessage(), e);
        }
    }

    @Override
    public void stop() {
        try {
            server.stop();
        } catch (Exception e) {
            e.printStackTrace();
            log.error(e.getMessage(), e);
        }

    }

    class RequestServlet extends HttpServlet {
        @Override
        protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
            log.info("client connect");
            ServletInputStream in = req.getInputStream();
            ServletOutputStream out = resp.getOutputStream();

            if (handler != null) {
                handler.onRequest(in, out);
            }

            out.flush();
        }
    }
}

3.5 服务端模块(server)

3.5.0 引入模块

    <dependencies>
        <!--协议模块-->
        <dependency>
            <groupId>com.wx</groupId>
            <artifactId>gk-rpc-proto</artifactId>
            <version>${project.version}</version>
        </dependency>
        <!--网络模块-->
        <dependency>
            <groupId>com.wx</groupId>
            <artifactId>gk-rpc-transport</artifactId>
            <version>${project.version}</version>
        </dependency>
        <!--序列化模块-->
        <dependency>
            <groupId>com.wx</groupId>
            <artifactId>gk-rpc-codec</artifactId>
            <version>${project.version}</version>
        </dependency>
        <!--共同模块-->
        <dependency>
            <groupId>com.wx</groupId>
            <artifactId>gk-rpc-common</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>commons-io</groupId>
            <artifactId>commons-io</artifactId>
        </dependency>
    </dependencies>

3.5.1 server配置类

/**
 * server配置类
 */
 @Data
public class RpcServerConfig {
    //网络协议
    private Class<? extends TransportServer> transportClass = HttpTransportServer.class;
    //序列化
    private Class<? extends Encoder> encoderClass = JSONEncoder.class;
    private Class<? extends Decoder> decoderClass = JSONDecoder.class;

    private int port = 3000;
}

3.5.2 Server具体实例

/**
 * 表示一个具体的服务
 */
@Data
@AllArgsConstructor
public class ServiceInstance {
    private Object target;//这个服务是由哪个对象提供的
    private Method method;//该对象的哪个方法
}

3.5.3 Server管理类

/**
 * 管理rpc暴露的服务
 * 1.注册服务 - Map
 * 2.查找服务
 */
@Slf4j
public class ServiceManager {
    //key是服务的描述,value是服务实例
    private Map<ServiceDescriptor,ServiceInstance> services;

    public ServiceManager() {
        this.services = new ConcurrentHashMap<>();
    }

    //注册服务
    public <T> void register(Class<T> interfaceClass, T bean) {
        Method[] methods = ReflectionUtils.getPublicMethods(interfaceClass);
        for (Method method : methods) {
            ServiceInstance serviceInstance = new ServiceInstance(bean, method);
            ServiceDescriptor sdp = ServiceDescriptor.from(interfaceClass, method);
            services.put(sdp, serviceInstance);
            log.info("register service: {}:{}", sdp.getClazz(), sdp.getMethod());
        }
    }

    //查找服务
    public ServiceInstance lookup(Request request) {
        ServiceDescriptor sdp = request.getService();
        return services.get(sdp);
    }

}

ServiceDescriptor.from方法

public static ServiceDescriptor from(Class clazz, Method method){
        ServiceDescriptor sdp = new ServiceDescriptor();
        sdp.setClazz(clazz.getName());
        sdp.setMethod(method.getName());
        sdp.setReturnType(method.getReturnType().getName());

        Class[] parameterClasses = method.getParameterTypes();
        String[] parameterTypes = new String[parameterClasses.length];

        for(int i=0; i< parameterClasses.length; i++){
            parameterTypes[i] = parameterClasses[i].getName();
        }

        return sdp;
    }

因为Map的key是我们自己定义的类ServiceDescriptor,Map在get的时候用的是该类的equal方法判断的,所以要重写equal方法
ServiceDescriptor类添加如下方法:

    @Override
    public int hashCode() {
        return toString().hashCode();
    }

    @Override
    public boolean equals(Object obj) {
        if(this==obj)return true;
        else if(obj==null || getClass()!= obj.getClass()) return false;

        ServiceDescriptor that = (ServiceDescriptor)obj;
        return this.toString().equals(that.toString());
    }

    @Override
    public String toString() {
        return "clazz="+clazz+
                ",method="+method+
                ",returnType="+returnType+
                ",parameterTypes="+Arrays.toString(parameterTypes);
    }

测试管理类

public interface TestInterface {
    void hello();
}
public class TestClass implements TestInterface{
    @Override
    public void hello() {}
}
public class ServiceManagerTest {
    ServiceManager sm;

    @Before
    public void init(){
        sm = new ServiceManager();

        TestInterface bean = new TestClass();

        sm.register(TestInterface.class,bean);
    }

    @Test
    public void register() {
        TestInterface bean = new TestClass();

        sm.register(TestInterface.class,bean);
    }

    @Test
    public void lookup() {
        Method method = ReflectionUtils.getPublicMethods(TestInterface.class)[0];
        ServiceDescriptor sdp = ServiceDescriptor.from(TestInterface.class, method);

        Request request = new Request();
        request.setService(sdp);

        ServiceInstance sis = sm.lookup(request);
        assertNotNull(sis);
        assertEquals(sis.getMethod().getName(),"hello");
    }
}

3.5.4 Server调用具体服务

/**
 * 调用具体服务
 */
public class ServiceInvoker {
    public Object invoke(ServiceInstance service, Request request) {
        return ReflectionUtils.invoke(service.getTarget(), service.getMethod(), request.getParameters());
    }
}

3.5.5 RpcServer

@Data
@Slf4j
public class RpcServer {
    private RpcServerConfig config;
    private TransportServer net;
    private Encoder encoder;
    private Decoder decoder;
    private ServiceManager serviceManager;
    private ServiceInvoker serviceInvoker;
    private RequestHandler handler = new RequestHandler() {

        @Override
        public void onRequest(InputStream in, OutputStream out) {
            Response response = new Response();
            try {
                byte[] bytes = IOUtils.readFully(in, in.available(), true);
                Request request = decoder.decode(bytes, Request.class);
                ServiceInstance instance = serviceManager.lookup(request);
                Object data = serviceInvoker.invoke(instance, request);
                response.setData(data);
            } catch (Exception e) {
                e.printStackTrace();
                log.error(e.getMessage(), e);
                response.setCode(-1);
                response.setMessage("RpcServer error: " + e.getMessage());
            } finally {
                byte[] bytes = encoder.encode(response);
                try {
                    out.write(bytes);
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    };

    public RpcServer() {
        this(new RpcServerConfig());
    }

    public RpcServer(RpcServerConfig config) {
        this.config = config;
        //net
        this.net = ReflectionUtils.newInstance(config.getTransportClass());
        this.net.init(config.getPort(), handler);
        //encode
        this.encoder = ReflectionUtils.newInstance(config.getEncoderClass());
        //decode
        this.decoder = ReflectionUtils.newInstance(config.getDecoderClass());
        //service
        this.serviceManager = new ServiceManager();
        this.serviceInvoker = new ServiceInvoker();
    }

    //注册服务
    public <T> void register(Class<T> interfaceClass, T bean) {
        serviceManager.register(interfaceClass, bean);
    }

    public void start() {
        this.net.start();
    }

    public void stop() {
        this.net.stop();
    }
}

3.6 客户端模块(client)

3.6.0 添加依赖

    <dependencies>
        <dependency>
            <groupId>com.wx</groupId>
            <artifactId>gk-rpc-proto</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>com.wx</groupId>
            <artifactId>gk-rpc-codec</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>com.wx</groupId>
            <artifactId>gk-rpc-transport</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>com.wx</groupId>
            <artifactId>gk-rpc-common</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>commons-io</groupId>
            <artifactId>commons-io</artifactId>
        </dependency>
    </dependencies>

3.6.1 TransportSelector

/**
 * 选择哪个Server去连接
 */
public interface TransportSelector {

    /**
     * 初始化selector
     * @param peers 可以连接的server端信息
     * @param count client与server建立多少个连接
     * @param clazz  client实现class
     */
    void init(List<Peer> peers, int count, Class<? extends TransportClient> clazz);

    //选择一个transport与server交互
    TransportClient select();

    //释放用完的client
    void release(TransportClient client);

    void close();
}

3.6.2 实现接口类RandomTransportSelector

@Slf4j
public class RandomTransportSelector implements TransportSelector {

    //存储所有已经连接好的client
    private List<TransportClient> clients;

    public RandomTransportSelector() {
        clients = new ArrayList<>();
    }

    @Override
    public synchronized void init(List<Peer> peers, int count, Class<? extends TransportClient> clazz) {
        //count必须大于等于1
        count = Math.max(count, 1);
        for (Peer peer : peers) {
            for (int i = 0; i < count; i++) {
                TransportClient client = ReflectionUtils.newInstance(clazz);
                client.connect(peer);//和Server连接
                clients.add(client);
            }
            log.info("connect server:{}", peer);
        }
    }

    @Override
    public synchronized TransportClient select() {
        int i = new Random().nextInt(clients.size());
        return clients.remove(i);
    }

    @Override
    public synchronized void release(TransportClient client) {
        clients.add(client);
    }

    @Override
    public synchronized void close() {
        for (TransportClient client : clients) {
            client.close();
        }
        clients.clear();
    }
}

3.6.3 配置类

@Data
public class RpcClientConfig {
    //网络连接
    private Class<? extends TransportClient> transportClass = HttpTransportClient.class;
    //序列化与反序列化的信息
    private Class<? extends Encoder> encoderClass = JSONEncoder.class;
    private Class<? extends Decoder> decoderClass = JSONDecoder.class;
    //路由选择的策略信息,默认随机策略
    private Class<? extends TransportSelector> selectorClass = RandomTransportSelector.class;
    //每一个Server可以建立几个连接,默认1个
    private int connectCount = 1;
    //可以连哪些网络端点,默认为“127.0.0.1”、3000
    private List<Peer> servers = Arrays.asList(new Peer("127.0.0.1", 3000));
}

3.6.4 RpcClient

public class RpcClient {
    private RpcClientConfig config;//配置
    private Encoder encoder;//序列化
    private Decoder decoder;//反序列化
    private TransportSelector selector;//选择器

    public RpcClient() {//无参构造方法
        this(new RpcClientConfig());
    }

    public RpcClient(RpcClientConfig config) {//基于配置的构造方法
        this.config = config;

        this.encoder = ReflectionUtils.newInstance(this.config.getEncoderClass());
        this.decoder = ReflectionUtils.newInstance(this.config.getDecoderClass());
        this.selector = ReflectionUtils.newInstance(this.config.getSelectorClass());

        this.selector.init(this.config.getServers(), this.config.getConnectCount(), this.config.getTransportClass());
    }

    //获取接口的代理对象
    public <T> T getProxy(Class clazz) {
        return (T) Proxy.newProxyInstance(getClass().getClassLoader(), new Class[]{clazz}, new RemoteInvoker(clazz, encoder, decoder, selector));
    }
}

3.6.5 调用远程服务的代理类

/**
 * 调用远程服务的代理类
 */
@Slf4j
public class RemoteInvoker implements InvocationHandler {
    private Class clazz;
    private Encoder encoder;
    private Decoder decoder;
    private TransportSelector selector;

    RemoteInvoker(Class clazz, Encoder encoder, Decoder decoder, TransportSelector selector) {
        this.clazz = clazz;
        this.encoder = encoder;
        this.decoder = decoder;
        this.selector = selector;
    }

    //调用远程服务:构造一个服务,通过网络把服务发送给Server,等待Server的响应,从响应里拿到响应数据,这次调用就结束了
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        //构造服务
        Request request = new Request();
        request.setService(ServiceDescriptor.from(clazz, method));
        request.setParameters(args);

        //通过网络传输调用远程服务
        Response response = invokeRemote(request);
        if (response == null || response.getCode() != 0) {
            throw new IllegalStateException("fail to invoke remote: " + response);
        }
        return response.getData();
    }

    private Response invokeRemote(Request request) {
        //网络连接信息
        TransportClient client = null;
        Response response = null;
        try {
            client = selector.select();//选择一个client
            byte[] bytes = encoder.encode(request);//请求序列化成二进制
            InputStream in = client.write(new ByteArrayInputStream(bytes));
            byte[] inBytes = IOUtils.readFully(in, in.available(), true);
            response = decoder.decode(inBytes, Response.class);//二进制反序列化成响应
        } catch (Exception e) {
            log.error(e.getMessage(), e);
            response.setCode(-1);
            response.setMessage("RpcClient got error:" + e.getClass() + ":" + e.getMessage());
        } finally {
            if (client != null) {
                selector.release(client);
            }
        }
        return response;
    }
}

3.7 gk-rpc使用案例

3.7.0 引入依赖

    <dependencies>
        <dependency>
            <groupId>com.wx</groupId>
            <artifactId>gk-rpc-client</artifactId>
            <version>1.0-SNAPSHOT</version>
        </dependency>
        <dependency>
            <groupId>com.wx</groupId>
            <artifactId>gk-rpc-server</artifactId>
            <version>1.0-SNAPSHOT</version>
        </dependency>
    </dependencies>

Client

public class Client {
    public static void main(String[] args) {
        RpcClient client = new RpcClient();
        CalcService service = client.getProxy(CalcService.class);
        int result = service.add(1, 2);
        System.out.println(result);
    }
}

Server

public class Server {
    public static void main(String[] args) {
        RpcServer server = new RpcServer();
        server.register(CalcService.class, new CalcServiceImpl());
        server.start();
    }
}

CalcService

public interface CalcService {
    int add(int a,int b);
}

CalcServiceImpl

public class CalcServiceImpl implements CalcService{
    @Override
    public int add(int a, int b) {
        return a+b;
    }
}

先把server跑起来,再跑Client,就能得到结果

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值