用java实现RPC编解码



前言

编解码器在RPC框架中处于最基础也是最重要的部分之一,上一篇文章的编写中,并没有对数据进行过深入的编解码,只是使用字符串进行序列化然后进行的传输,在实际的项目中是不能这么做的,所以这篇文章主要对编解码这块进行了优化,为了保证代码的完整性,会将所有的代码在这片文章中进行粘贴

一、jar包引用

由于不在将对象转换成字符串进行传输,所以不在需要使用fastJson,而是使用对象序列化工具类

 		<dependency>
            <groupId>com.caucho</groupId>
            <artifactId>hessian</artifactId>
            <version>3.1.5</version>
        </dependency>

二、公共基础类

1.定义一个请求协议封装类

import java.util.Arrays;

public class RequestObject {
    /**
     * 接口名称
     */
    private String interfaceName;
    /**
     * 方法名称
     */
    private String methodName;
    /**
     * 方法参数签名
     */
    private String[] argsSig;
    /**
     * 方法参数值,方法参数可能是一个复杂的对象,所以添加transient关机字在对对象进行序列化时,参数不
     * 参加序列化,而是单独进行序列化
     */
    private transient Object[] args;
    
    public RequestObject(String interfaceName, Method method, Object[] args) {
        this.interfaceName = interfaceName;
        this.methodName = method.getName();
        this.args = args;
        String[] argsSig = new String[args.length];
        for (int i = 0; i < args.length; i++) {
            argsSig[i] = args[i].getClass().getTypeName();
        }
        this.argsSig = argsSig;
    }
}

2.定义一个RPC请求对象

import java.util.Arrays;

public class RpcRequest {
    private RequestObject requestObject;
    /**
     * 总长度
     */
    private int count = 8;
    /**
     * 接口名称长度
     */
    private short interfaceNameLen;
    /**
     * 方法名称长度
     */
    private short methodNameLen;
    /**
     * 包体长度
     */
    private int contentLen;
    /**
     * 方法名称序列化数据
     */
    private byte[] interfaceNameByte;
    /**
     * 接口名称序列化数据
     */
    private byte[] methodNameByte;
    /**
     * 包体序列化数据
     */
    private byte[] contentByte;
    }

3.定义响应数据对象

public class ResponseObject {
    /**
     * 返回对象签名
     */
    private String responseSig;
    /**
     * 返回值
     */
    private Object result;

    public ResponseObject(Object result) {
        this.result = result;
        this.responseSig = result.getClass().getTypeName();
    }
}

4.定义响应数据包体对象

public class RpcResponse {
    /**
     * 响应长度
     */
    private int responseLen = 4;
    /**
     * 包体
     */
    private byte[] content;

    public void setContent(byte[] content) {
        this.content = content;
        if (null != content) {
            this.responseLen+= content.length;
        }
    }
}

5.定义一个自定义的SerializerFactory

import com.caucho.hessian.io.JavaSerializer;
import com.caucho.hessian.io.Serializer;
import com.caucho.hessian.io.SerializerFactory;

public class MySerializerFactory extends SerializerFactory {
    @Override
    protected Serializer getDefaultSerializer(Class cl) {
        if (this._defaultSerializer != null) {
            return this._defaultSerializer;
        }
        return new JavaSerializer(cl);
    }
}

三、服务消费端代码编写(Consumer)

public class RequestProxyHandler implements InvocationHandler, Serializable {
    /**
     * 发送数据使用的socket
     */
    private final Socket socket;
    /**
     * 被代理的对象信息
     */
    private final String interfaceName;

    public RequestProxyHandler(Socket socket, String interfaceName) {
        this.socket = socket;
        this.interfaceName = interfaceName;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        /**
         * 1.构建出请求协议对象
         */
        RequestObject requestObject = new RequestObject(interfaceName, method, args);
        RpcRequest rpcRequest = new RpcRequest(requestObject);
        /**
         * 获取socket的输出流
         */
        OutputStream outputStream = socket.getOutputStream();
        /**
         * 序列化接口名称
         */
        String interfaceName = requestObject.getInterfaceName();
        rpcRequest.setInterfaceNameByte(interfaceName.getBytes(StandardCharsets.UTF_8));
        /**
         * 序列化接口名称
         */
        String methodName = requestObject.getMethodName();
        rpcRequest.setMethodNameByte(methodName.getBytes(StandardCharsets.UTF_8));
        /**
         * 序列化包体
         */
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        Hessian2Output hessian2Output = new Hessian2Output(byteArrayOutputStream);
        /**
         * 由于Hessian2Output默认的序列化工厂会检查类是否实现了Serializable接口,所采用自己实现的序列化工厂
         */
        hessian2Output.setSerializerFactory(new MySerializerFactory());
        try {
            hessian2Output.writeObject(requestObject);
            if (args.length > 0) {
                for (Object arg : requestObject.getArgs()) {
                    hessian2Output.writeObject(arg);
                }
            }
            hessian2Output.close();
            rpcRequest.setContentByte(byteArrayOutputStream.toByteArray());
            byteArrayOutputStream.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
        /**
         * 使用byteBuffer对byte[]进行拼接
         */
        ByteBuffer byteBuffer = ByteBuffer.allocate(rpcRequest.getCount());
        byteBuffer.putShort(rpcRequest.getInterfaceNameLen())
                .putShort(rpcRequest.getMethodNameLen()).putInt(rpcRequest.getContentLen())
                .put(rpcRequest.getInterfaceNameByte()).put(rpcRequest.getMethodNameByte()).put(rpcRequest.getContentByte());
        /**
         * 将转换好的byte[]写入到输出流中
         */
        outputStream.write(byteBuffer.array());
        /**
         * 调用flush方法将数据全部发送出去
         */
        outputStream.flush();
        /**
         * 阻塞获取服务端返回的数据流
         */
        InputStream inputStream = socket.getInputStream();
        /**
         * 读取包体长度
         */
        byte[] contentLength = new byte[4];
        inputStream.read(contentLength);
        ByteBuffer lengthBuffer = ByteBuffer.wrap(contentLength);
        /**
         * 读取整个包体
         */
        byte[] content = new byte[lengthBuffer.getInt()];
        inputStream.read(content);
        /**
         * 进行反序列化
         */
        ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(content);
        Hessian2Input hessian2Input = new Hessian2Input(byteArrayInputStream);
        hessian2Input.setSerializerFactory(new MySerializerFactory());
        Object object = hessian2Input.readObject();
        if (object instanceof ResponseObject) {
            ResponseObject responseObject = (ResponseObject) object;
            return responseObject.getResult();
        }
        /**
         * 如果没有拿到返回值则直接返回空,这里也可以根据实际业务需求做其他操作,或者抛出异常
         */
        return null;
    }
}

消费端测试代码

public class Consumer {
    public static void main(String[] args) throws IOException {
        /**
         * 创建socket链接
         */
        Socket socket = new Socket("127.0.0.1", 8888);
        /**
         * 创建动态代理对象
         */
        HelloRpc helloRpc = (HelloRpc) Proxy.newProxyInstance(HelloRpc.class.getClassLoader()
                , new Class[]{HelloRpc.class}, new RequestProxyHandler(socket, HelloRpc.class.getTypeName()));
        /**
         * 获取到返回结果并打印
         */
        String result = helloRpc.hello("张三");
        System.out.println(result);
    }
}

四.生产者端代码

public class Provider {
    public static void main(String[] arg) throws IOException, NoSuchMethodException, InvocationTargetException, IllegalAccessException, ClassNotFoundException {
        /**
         * 用map存储接口和实现类的对于信息
         */
        Map<String, Object> providerCache = new HashMap<>();
        /**
         * 将测试接口和实现类信息,注册到map中
         */
        providerCache.put(HelloRpc.class.getName(), new HelloRpcImpl());
        /**
         * 创建ServerSocket,绑定8888端口
         */
        ServerSocket serverSocket = new ServerSocket(8888);
        while (true) {
            /**
             * 等待客户端链接
             */
            Socket clientSocket = serverSocket.accept();
            /**
             * 获取数据流
             */
            InputStream inputStream = clientSocket.getInputStream();
            /**
             * 读取字段长度信息
             */
            byte[] length = new byte[8];
            inputStream.read(length);
            ByteBuffer allLengthByte = ByteBuffer.wrap(length);
            /**
             * 读取接口名称
             */
            byte[] interfaceNameByte = new byte[allLengthByte.getShort()];
            inputStream.read(interfaceNameByte);
            String interfaceName = new String(interfaceNameByte, StandardCharsets.UTF_8);
            /**
             * 如果接口名称为空字符串或者不在缓存中则结束本次请求处理
             */
            if ("".equals(interfaceName) || null == providerCache.get(interfaceName)) {
                continue;
            }
            /**
             * 读取接口名称
             */
            byte[] methodNameByte = new byte[allLengthByte.getShort()];
            String methodName = new String(methodNameByte, StandardCharsets.UTF_8);
            /**
             * 读取包体
             */
            byte[] content = new byte[allLengthByte.getInt()];
            Hessian2Input hessian2Input = new Hessian2Input(new ByteArrayInputStream(content));
            hessian2Input.setSerializerFactory(new MySerializerFactory());
            Object object = hessian2Input.readObject();
            if (object instanceof RequestObject) {
                RequestObject requestObject = (RequestObject) object;
                String[] argsSig = requestObject.getArgsSig();
                Object[] args = new Object[argsSig.length];
                Class<?>[] argsClazz = new Class<?>[argsSig.length];
                for (int i = 0; i < argsSig.length; i++) {
                    /**
                     * 这里copy spring的ClassUtils进行类的加载
                     */
                    Class<?> aClass = ClassUtils.forName(argsSig[i], Thread.currentThread().getContextClassLoader());
                    argsClazz[i] = aClass;
                    args[i] = hessian2Input.readObject(aClass);
                }
                requestObject.setArgs(args);
                Object o = providerCache.get(interfaceName);
                Method declaredMethod = o.getClass().getDeclaredMethod(methodName, argsClazz);
                Object result = declaredMethod.invoke(o, args);
                /**
                 * 构建响应对象
                 */
                ResponseObject responseObject = new ResponseObject(result);
                RpcResponse rpcResponse = new RpcResponse();
                ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
                Hessian2Output hessian2Output = new Hessian2Output(byteArrayOutputStream);
                hessian2Output.setSerializerFactory(new MySerializerFactory());
                hessian2Output.writeObject(responseObject);
                hessian2Output.close();
                rpcResponse.setContent(byteArrayOutputStream.toByteArray());
                byteArrayOutputStream.close();
                ByteBuffer byteBuffer = ByteBuffer.allocate(rpcResponse.getResponseLen());
                ByteBuffer put = byteBuffer.putInt(rpcResponse.getResponseLen()).put(rpcResponse.getContent());
                OutputStream outputStream = clientSocket.getOutputStream();
                outputStream.write(put.array());
                outputStream.flush();
            }

        }
    }
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值