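Dubbo Internals: Extending the RPC Protocol (an Optimized Variant of the dubbo Protocol)

This article is based on the dubbo 2.7.5 source code.

The article "Dubbo Internals: How to Design an RPC Protocol, Analyzed Through the dubbo Protocol" described the structure of the dubbo protocol in detail and proposed some ideas of mine for optimizing it. This article implements an optimized protocol by extending the codec of the dubbo protocol. Because dubbo binds each protocol to its codec, this article also extends the dubbo Protocol implementation.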

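I. Optimizations

  1. Compress the message body and use four bits in the header to indicate the compression format, so up to 16 compression algorithms can be supported;
  2. Shrink the status field: status has 10 possible values and 4 bits can represent 16, so status no longer uses 8 bits, only 4;
  3. Encode the RPC Request ID with a variable length. The original dubbo protocol stores the RPC Request ID as a long. A long uses its highest bit as the sign bit, and since the ID can only be positive, that bit is always 0. We can use it to indicate how many bytes the ID occupies: when the highest bit is 0 the ID takes 4 bytes, and when it is 1 the ID takes 8 bytes.

Based on these three optimizations, the status field of the original protocol is split into two parts: the first 4 bits indicate the compression algorithm and the second 4 bits indicate the status. The optimized status values correspond to the original ones as follows: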

Original status    Optimized status
20                 0
30                 1
31                 2
40                 3
50                 4
60                 5
70                 6
80                 7
90                 8
100                9
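
To make the split of the old status byte concrete, here is a minimal sketch of packing a compression type and an optimized status into one byte and reading them back. The class StatusByte and its method names are hypothetical helpers for illustration and are not part of dubbo; only the two mask values match the ones used later in DubboexCodec.

```java
// Hypothetical helper, not part of dubbo: models the redesigned status byte.
final class StatusByte {
    static final int STATUS_MASK = 0x0f;   // low 4 bits: optimized status
    static final int COMPRESS_MASK = 0xf0; // high 4 bits: compression type

    // Combine a compression type (0..15) and an optimized status (0..15) into one byte.
    static byte pack(int compression, int status) {
        if (compression < 0 || compression > 15 || status < 0 || status > 15) {
            throw new IllegalArgumentException("both values must fit in 4 bits");
        }
        return (byte) ((compression << 4) | status);
    }

    // Extract the compression type from the high 4 bits.
    static int compression(byte b) {
        return (b & COMPRESS_MASK) >>> 4;
    }

    // Extract the optimized status from the low 4 bits.
    static int status(byte b) {
        return b & STATUS_MASK;
    }

    public static void main(String[] args) {
        byte b = pack(0, 9); // compression type 0, optimized status 9 (original status 100)
        System.out.println("compression=" + compression(b) + ", status=" + status(b));
    }
}
```

Masking before shifting matters here: a Java byte is signed, so a byte whose high bit is set would otherwise be sign-extended when promoted to int.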

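Because the RPC Request ID is variable-length, the optimized header is variable-length as well: 12 bytes at the shortest and 16 bytes at the longest.
The optimized protocol is named dubboex.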
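
The variable-length ID rule can be sketched with the plain JDK as follows. This is an illustration under my own assumptions, not dubbo code: the class VarLengthId is hypothetical, and it uses java.nio.ByteBuffer (big-endian by default) instead of dubbo's Bytes utility.

```java
import java.nio.ByteBuffer;

// Hypothetical helper, not part of dubbo: variable-length encoding of a non-negative ID.
final class VarLengthId {
    // 4 bytes if the ID fits in 31 bits, otherwise 8 bytes with the highest bit set as a marker.
    static byte[] encode(long id) {
        if (id < 0) {
            throw new IllegalArgumentException("id must be non-negative");
        }
        if (id <= Integer.MAX_VALUE) {
            return ByteBuffer.allocate(4).putInt((int) id).array();
        }
        return ByteBuffer.allocate(8).putLong(id | Long.MIN_VALUE).array();
    }

    // The highest bit of the first ID byte tells how many bytes the ID occupies.
    static int length(byte[] buf, int offset) {
        return (buf[offset] & 0x80) == 0 ? 4 : 8;
    }

    // Read the ID back, clearing the marker bit for the 8-byte form.
    static long decode(byte[] buf, int offset) {
        ByteBuffer bb = ByteBuffer.wrap(buf);
        if (length(buf, offset) == 4) {
            return bb.getInt(offset);
        }
        return bb.getLong(offset) & Long.MAX_VALUE;
    }

    public static void main(String[] args) {
        long big = 1L << 40;
        System.out.println(encode(7L).length + " bytes, " + encode(big).length + " bytes");
        System.out.println(decode(encode(big), 0) == big);
    }
}
```

Note that the sender has to set the marker bit explicitly: a positive long already has 0 in its highest bit, so writing it unchanged would make an 8-byte ID indistinguishable from a 4-byte one.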

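II. Implementation plan

We need to register the dubboex protocol and its implementation class in the org.apache.dubbo.rpc.Protocol file.
We also need to register the codec for the dubboex protocol in the org.apache.dubbo.remoting.Codec2 file.
The codec of the dubbo protocol is implemented by the DubboCodec class, which extends ExchangeCodec. DubboCodec mainly handles the message body by overriding the body-handling methods of ExchangeCodec, while ExchangeCodec mainly handles the header. The dubboex codec therefore also extends ExchangeCodec and overrides its methods for handling both the header and the body.
Finally, set dubbo.protocol.name=dubboex in the Spring configuration file to use the dubboex protocol.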
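
As a sketch of the registration described above, the entries could look like the following. It assumes both classes live in the com.dubbo.protocol package used by the listings in this article, that the extension files sit under META-INF/dubbo/ on the classpath, and that the Spring configuration is a properties file (the file name application.properties is my assumption).

```properties
# META-INF/dubbo/org.apache.dubbo.rpc.Protocol
dubboex=com.dubbo.protocol.DubboExProtocol

# META-INF/dubbo/org.apache.dubbo.remoting.Codec2
dubboex=com.dubbo.protocol.DubboexCodec

# application.properties (assumed file name)
dubbo.protocol.name=dubboex
```

The three blocks belong to three separate files; they are shown together only for brevity.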

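III. Writing the dubboex protocol implementation class

The codec implementation class for the dubboex protocol is named DubboexCodec.
The methods of ExchangeCodec that need to be overridden are:
1. public Object decode(Channel channel, ChannelBuffer buffer) throws IOException;
2. void encodeResponse(Channel channel, ChannelBuffer buffer, Response res) throws IOException;
3. void encodeRequest(Channel channel, ChannelBuffer buffer, Request req) throws IOException.
The three methods do the following:
1. decode parses a received response or request message into a Response or Request object;
2. encodeResponse assembles a response message;
3. encodeRequest assembles a request message.
In this section's implementation, gzip is the only supported compression algorithm; others are not supported yet, and the field indicating the compression type is set to 0.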
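
Since the body is gzip-compressed, two details of java.util.zip are worth illustrating before the full codec: the compressed length is only final once the GZIPOutputStream has been closed (or finished), and the decompressed data should be read until EOF instead of being sized from available(). The sketch below uses only the JDK; the class GzipBody is a hypothetical helper and not part of dubbo.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

// Hypothetical helper, not part of dubbo: gzip round trip for a message body.
final class GzipBody {
    static byte[] compress(byte[] body) {
        ByteArrayOutputStream sink = new ByteArrayOutputStream();
        try (GZIPOutputStream gz = new GZIPOutputStream(sink)) {
            gz.write(body);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        // Only after close() has the gzip trailer been written, so only now is the length final.
        return sink.toByteArray();
    }

    static byte[] decompress(byte[] compressed) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (GZIPInputStream gz = new GZIPInputStream(new ByteArrayInputStream(compressed))) {
            byte[] chunk = new byte[4096];
            int n;
            // Read until EOF; available() on an inflating stream does not report the remaining size.
            while ((n = gz.read(chunk)) != -1) {
                out.write(chunk, 0, n);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return out.toByteArray();
    }

    public static void main(String[] args) {
        byte[] body = "hello dubboex".getBytes(StandardCharsets.UTF_8);
        byte[] restored = decompress(compress(body));
        System.out.println(Arrays.equals(body, restored));
    }
}
```

For very small bodies the gzip header and trailer can make the compressed form larger than the original, which is one reason a per-message compression flag in the header is useful.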
Below is the complete implementation of the DubboexCodec class.

package com.dubbo.protocol;

import org.apache.dubbo.common.Version;
import org.apache.dubbo.common.io.Bytes;
import org.apache.dubbo.common.io.StreamUtils;
import org.apache.dubbo.common.io.UnsafeByteArrayInputStream;
import org.apache.dubbo.common.logger.Logger;
import org.apache.dubbo.common.logger.LoggerFactory;
import org.apache.dubbo.common.serialize.Cleanable;
import org.apache.dubbo.common.serialize.ObjectInput;
import org.apache.dubbo.common.serialize.ObjectOutput;
import org.apache.dubbo.common.serialize.Serialization;
import org.apache.dubbo.common.utils.StringUtils;
import org.apache.dubbo.remoting.Channel;
import org.apache.dubbo.remoting.RemotingException;
import org.apache.dubbo.remoting.buffer.ChannelBuffer;
import org.apache.dubbo.remoting.buffer.ChannelBufferInputStream;
import org.apache.dubbo.remoting.buffer.ChannelBufferOutputStream;
import org.apache.dubbo.remoting.exchange.Request;
import org.apache.dubbo.remoting.exchange.Response;
import org.apache.dubbo.remoting.exchange.codec.ExchangeCodec;
import org.apache.dubbo.remoting.transport.CodecSupport;
import org.apache.dubbo.remoting.transport.ExceedPayloadLimitException;
import org.apache.dubbo.rpc.Invocation;
import org.apache.dubbo.rpc.protocol.dubbo.DecodeableRpcInvocation;
import org.apache.dubbo.rpc.protocol.dubbo.DecodeableRpcResult;

import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import static org.apache.dubbo.rpc.protocol.dubbo.Constants.DECODE_IN_IO_THREAD_KEY;
import static org.apache.dubbo.rpc.protocol.dubbo.Constants.DEFAULT_DECODE_IN_IO_THREAD;

public class DubboexCodec extends ExchangeCodec {
    protected static final short MAGIC_EX = (short) 0xdabc; // new magic number: the original magic number (0xdabb) plus 1
    protected static final byte MAGIC_HIGH_EX = Bytes.short2bytes(MAGIC_EX)[0];
    protected static final byte MAGIC_LOW_EX = Bytes.short2bytes(MAGIC_EX)[1];
    protected static final byte STATUS_MASK = (byte)0x0f;
    protected static final byte COMPRESS_MASK = (byte)0xf0;
    // Mapping between the optimized status values and the original status values
    private static Map<Byte,Byte> STATUS_NEW_OLD_MAP=new HashMap<Byte,Byte>();
    private static Map<Byte,Byte> STATUS_OLD_NEW_MAP=new HashMap<Byte,Byte>();
    static{
        STATUS_NEW_OLD_MAP.put((byte)0,  (byte)20);        STATUS_NEW_OLD_MAP.put((byte)1,(byte)30);
        STATUS_NEW_OLD_MAP.put((byte)2,(byte)31);        STATUS_NEW_OLD_MAP.put((byte)3,(byte)40);
        STATUS_NEW_OLD_MAP.put((byte)4,(byte)50);        STATUS_NEW_OLD_MAP.put((byte)5,(byte)60);
        STATUS_NEW_OLD_MAP.put((byte)6,(byte)70);        STATUS_NEW_OLD_MAP.put((byte)7,(byte)80);
        STATUS_NEW_OLD_MAP.put((byte)8,(byte)90);        STATUS_NEW_OLD_MAP.put((byte)9,(byte)100);

        STATUS_OLD_NEW_MAP.put((byte)20,  (byte)0);        STATUS_OLD_NEW_MAP.put((byte)30,(byte)1);
        STATUS_OLD_NEW_MAP.put((byte)31,(byte)2);        STATUS_OLD_NEW_MAP.put((byte)40,(byte)3);
        STATUS_OLD_NEW_MAP.put((byte)50,(byte)4);        STATUS_OLD_NEW_MAP.put((byte)60,(byte)5);
        STATUS_OLD_NEW_MAP.put((byte)70,(byte)6);        STATUS_OLD_NEW_MAP.put((byte)80,(byte)7);
        STATUS_OLD_NEW_MAP.put((byte)90,(byte)8);        STATUS_OLD_NEW_MAP.put((byte)100,(byte)9);
    }
    private static final Logger logger = LoggerFactory.getLogger(DubboexCodec.class);
    @Override
    public Object decode(Channel channel, ChannelBuffer buffer) throws IOException {
        int readable = buffer.readableBytes();
        byte[] header = new byte[Math.min(readable, HEADER_LENGTH-4)]; // read at most 12 bytes first, the length of the shortest header
        buffer.readBytes(header);
        return decode(channel, buffer, readable, header);
    }

    public Object decode(Channel channel, ChannelBuffer buffer, int readable, byte[] header) throws IOException {
        if (readable > 0 && header[0] != MAGIC_HIGH_EX
                 || readable > 1 && header[1] != MAGIC_LOW_EX) {
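            // If the magic number does not match, the data is simply discarded
            return DecodeResult.NEED_MORE_INPUT;
        }
        // Check the header length: the header is either 12 or 16 bytes long
        if (readable < (HEADER_LENGTH-4)) {
            return DecodeResult.NEED_MORE_INPUT;
        }

        // Read the body length.
        // The id is variable-length, so determine its size first; that tells us where the body length field is
        long id = Bytes.bytes2int(header, 4);
        int len=-1;
        if(id<0) {
            // Highest bit set: 8-byte id and 16-byte header. Fetch the remaining 4 header bytes
            // before reading the length field at offset 12.
            if (readable < HEADER_LENGTH) {
                return DecodeResult.NEED_MORE_INPUT;
            }
            byte[] temp=new byte[HEADER_LENGTH];
            System.arraycopy(header, 0, temp, 0, header.length);
            buffer.readBytes(temp,HEADER_LENGTH-4,4);
            header=temp;
            len = Bytes.bytes2int(header, 12);
        }else{
            len = Bytes.bytes2int(header, 8);
        }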
        // Check the body length: by default it must not exceed 8M; this can be changed through the payload field of ProtocolConfig
        checkPayload(channel, len);

        int tt = len + header.length;
        if (readable < tt) {
            return DecodeResult.NEED_MORE_INPUT;
        }

        // Build an input stream over the body; ChannelBufferInputStream acts as a facade that wraps the methods of ChannelBuffer
        ChannelBufferInputStream is = new ChannelBufferInputStream(buffer, len);

        try {
            // Process the message body
            return decodeBody(channel, is, header);
        } finally {
            // Skip the unread part of the body stream to get ready for reading the next message
            if (is.available() > 0) {
                try {
                    StreamUtils.skipUnusedStream(is);
                } catch (IOException e) {
                    logger.warn(e.getMessage(), e);
                }
            }
        }
    }
    public Object decodeBody(Channel channel, InputStream is, byte[] header) throws IOException {
        // proto is the serialization type
        byte flag = header[2], proto = (byte) (flag & SERIALIZATION_MASK);
        // get request id.
        long id = Bytes.bytes2int(header, 4);
        if(id<0){
            // 8-byte id: clear the marker bit to recover the real value
            id = Bytes.bytes2long(header, 4) & Long.MAX_VALUE;
        }

        if ((flag & FLAG_REQUEST) == 0) {
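            // Decode a response message and build the Response object
            Response res = new Response(id);
            if ((flag & FLAG_EVENT) != 0) {
                res.setEvent(true);
            }
            // Read the status: the low 4 bits of header[3] hold the status, the high 4 bits hold the compression format
            byte status =(byte)(header[3] & STATUS_MASK);
            // Map the 4-bit status back to the original dubbo status before storing it in the Response
            Byte oldStatus = STATUS_NEW_OLD_MAP.get(status);
            res.setStatus(oldStatus == null ? Response.BAD_RESPONSE : oldStatus);
            try {
                if (res.getStatus() == Response.OK) {
                    Object data;
                    // Decompress first, then deserialize
                    is=decompression(header,is);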
                    if (res.isEvent()) {
                        ObjectInput in = CodecSupport.deserialize(channel.getUrl(), is, proto);
                        data = decodeEventData(channel, in);
                    } else {
                        DecodeableRpcResult result;
                        if (channel.getUrl().getParameter(DECODE_IN_IO_THREAD_KEY, DEFAULT_DECODE_IN_IO_THREAD)) {
                            // Decode directly in the IO thread
                            result = new DecodeableRpcResult(channel, res, is,
                                     (Invocation) getRequestData(id), proto);
                            result.decode();
                        } else {
                            // Decode in an asynchronous thread
                            result = new DecodeableRpcResult(channel, res,
                                     new UnsafeByteArrayInputStream(readMessageData(is)),
                                     (Invocation) getRequestData(id), proto);
                        }
                        data = result;
                    }
                    res.setResult(data);
                } else {
                    is=decompression(header,is);
                    ObjectInput in = CodecSupport.deserialize(channel.getUrl(), is, proto);
                    res.setErrorMessage(in.readUTF());
                }
            } catch (Throwable t) {
                res.setStatus(Response.CLIENT_ERROR);
                res.setErrorMessage(StringUtils.toString(t));
            }
            return res;
        } else {
            // Decode a request message and build the Request object
            Request req = new Request(id);
            req.setVersion(Version.getProtocolVersion());
            req.setTwoWay((flag & FLAG_TWOWAY) != 0);
            if ((flag & FLAG_EVENT) != 0) {
                req.setEvent(true);
            }
            try {
                is=decompression(header,is);
                Object data;
                if (req.isEvent()) {
                    ObjectInput in = CodecSupport.deserialize(channel.getUrl(), is, proto);
                    data = decodeEventData(channel, in);
                } else {
                    DecodeableRpcInvocation inv;
                    if (channel.getUrl().getParameter(DECODE_IN_IO_THREAD_KEY, DEFAULT_DECODE_IN_IO_THREAD)) {
                        // Decode the request message in the IO thread
                        inv = new DecodeableRpcInvocation(channel, req, is, proto);
                        inv.decode();
                    } else {
                        inv = new DecodeableRpcInvocation(channel, req,
                                 new UnsafeByteArrayInputStream(readMessageData(is)), proto);
                    }
                    data = inv;
                }
                req.setData(data);
            } catch (Throwable t) {
                req.setBroken(true);
                req.setData(t);
            }
            return req;
        }
    }

    @Override
    public void encodeResponse(Channel channel, ChannelBuffer buffer, Response res) throws IOException{
        int savedWriteIndex = buffer.writerIndex();
        try {
            Serialization serialization = getSerialization(channel);
            // header.
            byte[] header = new byte[HEADER_LENGTH];
            // set magic number.
            Bytes.short2bytes(MAGIC_EX, header);
            // set request and serialization flag.
            header[2] = serialization.getContentTypeId();
            if (res.isHeartbeat()) {
                header[2] |= FLAG_EVENT;
            }
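            // set response status: the low 4 bits carry the optimized status,
            // the high 4 bits (compression type) stay 0 in this implementation.
            byte status =(byte) (STATUS_OLD_NEW_MAP.get(res.getStatus())&0xff);
            header[3] =status;
            // set request id.
            if(res.getId()>=1L<<31) {
                // 8-byte id: set the highest bit as the length marker
                Bytes.long2bytes(res.getId() | Long.MIN_VALUE, header, 4);
            }else{
                Bytes.int2bytes((int)res.getId(), header, 4);
                // shorten the header to 12 bytes
                byte[] temp=new byte[HEADER_LENGTH-4];
                System.arraycopy(header, 0, temp, 0, temp.length);
                header=temp;
            }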

            buffer.writerIndex(savedWriteIndex + header.length);
            ChannelBufferOutputStream bos = new ChannelBufferOutputStream(buffer);
            GZIPOutputStream dos=new GZIPOutputStream(bos);
            ObjectOutput out = serialization.serialize(channel.getUrl(), dos);

            // encode response data or error message.
            if (res.getStatus() == Response.OK) {
                if (res.isHeartbeat()) {
                    encodeEventData(channel, out, res.getResult());
                } else {
                    encodeResponseData(channel, out, res.getResult(), res.getVersion());
                }
            } else {
                out.writeUTF(res.getErrorMessage());
            }
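            out.flushBuffer();
            // Close the gzip stream so that all compressed data and the gzip trailer are written before the body length is measured
            dos.close();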
            if (out instanceof Cleanable) {
                ((Cleanable) out).cleanup();
            }
            bos.flush();
            bos.close();

            int len = bos.writtenBytes();
            checkPayload(channel, len);
            if(header.length==HEADER_LENGTH) {
                Bytes.int2bytes(len, header, 12);
            }else{
                Bytes.int2bytes(len, header, 8);
            }
            // write
            buffer.writerIndex(savedWriteIndex);
            buffer.writeBytes(header); // write header.
            buffer.writerIndex(savedWriteIndex + header.length + len);
        } catch (Throwable t) {
            // clear buffer
            buffer.writerIndex(savedWriteIndex);
            // send error message to Consumer, otherwise, Consumer will wait till timeout.
            if (!res.isEvent() && res.getStatus() != Response.BAD_RESPONSE) {
                Response r = new Response(res.getId(), res.getVersion());
                r.setStatus(Response.BAD_RESPONSE);

                if (t instanceof ExceedPayloadLimitException) {
                    logger.warn(t.getMessage(), t);
                    try {
                        r.setErrorMessage(t.getMessage());
                        channel.send(r);
                        return;
                    } catch (RemotingException e) {
                        logger.warn("Failed to send bad_response info back: " + t.getMessage() + ", cause: " + e.getMessage(), e);
                    }
                } else {
                    // FIXME log error message in Codec and handle in caught() of IoHanndler?
                    logger.warn("Fail to encode response: " + res + ", send bad_response info instead, cause: " + t.getMessage(), t);
                    try {
                        r.setErrorMessage("Failed to send response: " + res + ", cause: " + StringUtils.toString(t));
                        channel.send(r);
                        return;
                    } catch (RemotingException e) {
                        logger.warn("Failed to send bad_response info back: " + res + ", cause: " + e.getMessage(), e);
                    }
                }
            }

            // Rethrow exception
            if (t instanceof IOException) {
                throw (IOException) t;
            } else if (t instanceof RuntimeException) {
                throw (RuntimeException) t;
            } else if (t instanceof Error) {
                throw (Error) t;
            } else {
                throw new RuntimeException(t.getMessage(), t);
            }
        }
    }

    @Override
    public void encodeRequest(Channel channel, ChannelBuffer buffer, Request req) throws IOException {
        Serialization serialization = getSerialization(channel);
        // header.
        byte[] header = new byte[HEADER_LENGTH];
        // set magic number.
        Bytes.short2bytes(MAGIC_EX, header);
        // set request and serialization flag.
        header[2] = (byte) (FLAG_REQUEST | serialization.getContentTypeId());
        if (req.isTwoWay()) {
            header[2] |= FLAG_TWOWAY;
        }
        if (req.isEvent()) {
            header[2] |= FLAG_EVENT;
        }

        if(req.getId()>=1L<<31) {
            // 8-byte id: set the highest bit as the length marker
            Bytes.long2bytes(req.getId() | Long.MIN_VALUE, header, 4);
        }else{
            Bytes.int2bytes((int)req.getId(), header, 4);
            // shorten the header to 12 bytes
            byte[] temp=new byte[HEADER_LENGTH-4];
            System.arraycopy(header, 0, temp, 0, temp.length);
            header=temp;
        }

        // encode request data.
        int savedWriteIndex = buffer.writerIndex();
        buffer.writerIndex(savedWriteIndex + header.length);
        ChannelBufferOutputStream bos = new ChannelBufferOutputStream(buffer);
        GZIPOutputStream dos=new GZIPOutputStream(bos);
        ObjectOutput out = serialization.serialize(channel.getUrl(), dos);
        if (req.isEvent()) {
            encodeEventData(channel, out, req.getData());
        } else {
            encodeRequestData(channel, out, req.getData(), req.getVersion());
        }
        out.flushBuffer();
        dos.close();
        if (out instanceof Cleanable) {
            ((Cleanable) out).cleanup();
        }
        bos.flush();
        bos.close();
        int len = bos.writtenBytes();
        checkPayload(channel, len);
        if(header.length==HEADER_LENGTH) {
            Bytes.int2bytes(len, header, 12);
        }else{
            Bytes.int2bytes(len, header, 8);
        }

        // write
        buffer.writerIndex(savedWriteIndex);
        buffer.writeBytes(header); // write header.
        buffer.writerIndex(savedWriteIndex + header.length + len);
    }

    private InputStream decompression(byte[] header,InputStream is)throws IOException{
        byte decompression=(byte)(header[3] & COMPRESS_MASK);
        // Only gzip compression is supported for now
        GZIPInputStream dis=new GZIPInputStream(is);
        return dis;
    }
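    private byte[] readMessageData(InputStream is) throws IOException {
        // The stream passed in is a GZIPInputStream, whose available() only reports whether EOF
        // has been reached, so read until EOF instead of sizing the array from available()
        java.io.ByteArrayOutputStream out = new java.io.ByteArrayOutputStream();
        byte[] chunk = new byte[1024];
        int n;
        while ((n = is.read(chunk)) != -1) {
            out.write(chunk, 0, n);
        }
        return out.toByteArray();
    }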
    private void encodeEventData(Channel channel, ObjectOutput out, Object data) throws IOException {
        out.writeEvent(data);
    }
}
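
To summarize the header layout that the codec reads and writes, here is a plain-JDK model of it. This is an illustrative sketch, not dubbo code: the class DubboexHeader is hypothetical, the flag value used in main is just a sample, and it assumes the sender sets the highest bit of an 8-byte ID as described in section I.

```java
import java.nio.ByteBuffer;

// Hypothetical model of the dubboex header, independent of dubbo's classes.
final class DubboexHeader {
    static final short MAGIC = (short) 0xdabc;

    // Layout: magic(2) + flag(1) + compression/status(1) + id(4 or 8) + body length(4).
    static byte[] build(byte flag, byte status, long id, int bodyLength) {
        if (id < 0) {
            throw new IllegalArgumentException("id must be non-negative");
        }
        boolean longId = id > Integer.MAX_VALUE;
        ByteBuffer bb = ByteBuffer.allocate(longId ? 16 : 12);
        bb.putShort(MAGIC).put(flag).put(status);
        if (longId) {
            bb.putLong(id | Long.MIN_VALUE); // highest bit marks an 8-byte id
        } else {
            bb.putInt((int) id);
        }
        bb.putInt(bodyLength);
        return bb.array();
    }

    // Given at least the first 12 bytes, report the full header length.
    static int headerLength(byte[] first12) {
        return (first12[4] & 0x80) == 0 ? 12 : 16;
    }

    // Read the body length, which is always the last 4 bytes of a complete header.
    static int bodyLength(byte[] header) {
        return ByteBuffer.wrap(header).getInt(header.length - 4);
    }

    public static void main(String[] args) {
        byte[] shortHeader = build((byte) 0x82, (byte) 0, 7L, 100);
        byte[] longHeader = build((byte) 0x82, (byte) 0, 1L << 40, 100);
        System.out.println(shortHeader.length + " / " + longHeader.length);
    }
}
```

A decoder following this layout first needs 12 bytes to learn the header length, and only then knows whether 4 more header bytes must arrive before the body length can be read.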


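Because dubbo binds each protocol to its codec, making the code above take effect also requires extending the Protocol implementation. To do so I modified the DubboProtocol class to create the DubboExProtocol class, which specifies DubboexCodec as its codec.
The code of DubboExProtocol is shown below. The modified places are marked with "[number]", and part of the code has been omitted: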

package com.dubbo.protocol;

import org.apache.dubbo.common.URL;
import org.apache.dubbo.common.URLBuilder;
import org.apache.dubbo.common.config.ConfigurationUtils;
import org.apache.dubbo.common.extension.ExtensionLoader;
import org.apache.dubbo.common.serialize.support.SerializableClassRegistry;
import org.apache.dubbo.common.serialize.support.SerializationOptimizer;
import org.apache.dubbo.common.utils.*;
import org.apache.dubbo.remoting.Channel;
import org.apache.dubbo.remoting.RemotingException;
import org.apache.dubbo.remoting.RemotingServer;
import org.apache.dubbo.remoting.Transporter;
import org.apache.dubbo.remoting.exchange.*;
import org.apache.dubbo.remoting.exchange.support.ExchangeHandlerAdapter;
import org.apache.dubbo.rpc.*;
import org.apache.dubbo.rpc.protocol.AbstractProtocol;
import org.apache.dubbo.rpc.protocol.dubbo.DubboExporter;
import org.apache.dubbo.rpc.protocol.dubbo.DubboInvoker;
import org.apache.dubbo.rpc.protocol.dubbo.DubboProtocolServer;
import org.apache.dubbo.rpc.protocol.dubbo.ReferenceCountExchangeClient;

import java.net.InetSocketAddress;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Function;

import static org.apache.dubbo.common.constants.CommonConstants.*;
import static org.apache.dubbo.remoting.Constants.*;
import static org.apache.dubbo.rpc.Constants.*;
import static org.apache.dubbo.rpc.protocol.dubbo.Constants.*;


/**
 * dubbo protocol support.
 */
public class DubboExProtocol extends AbstractProtocol {

    public static final String NAME = "dubboex"; // [1]

    public static final int DEFAULT_PORT = 20880;
    private static final String IS_CALLBACK_SERVICE_INVOKE = "_isCallBackServiceInvoke";
    private static DubboExProtocol INSTANCE; // [2]

    /**
     * <host:port,Exchanger>
     */
    private final Map<String, List<ReferenceCountExchangeClient>> referenceClientMap = new ConcurrentHashMap<>();
    private final ConcurrentMap<String, Object> locks = new ConcurrentHashMap<>();
    private final Set<String> optimizers = new ConcurrentHashSet<>();
    /**
     * consumer side export a stub service for dispatching event
     * servicekey-stubmethods
     */
    private final ConcurrentMap<String, String> stubServiceMethodsMap = new ConcurrentHashMap<>();

    private ExchangeHandler requestHandler = new ExchangeHandlerAdapter() {

        @Override
        public CompletableFuture<Object> reply(ExchangeChannel channel, Object message) throws RemotingException {

            if (!(message instanceof Invocation)) {
                throw new RemotingException(channel, "Unsupported request: "
                         + (message == null ? null : (message.getClass().getName() + ": " + message))
                         + ", channel: consumer: " + channel.getRemoteAddress() + " --> provider: " + channel.getLocalAddress());
            }

            Invocation inv = (Invocation) message;
            Invoker<?> invoker = getInvoker(channel, inv);
            // need to consider backward-compatibility if it's a callback
            if (Boolean.TRUE.toString().equals(inv.getAttachments().get(IS_CALLBACK_SERVICE_INVOKE))) {
                String methodsStr = invoker.getUrl().getParameters().get("methods");
                boolean hasMethod = false;
                if (methodsStr == null || !methodsStr.contains(",")) {
                    hasMethod = inv.getMethodName().equals(methodsStr);
                } else {
                    String[] methods = methodsStr.split(",");
                    for (String method : methods) {
                        if (inv.getMethodName().equals(method)) {
                            hasMethod = true;
                            break;
                        }
                    }
                }
                if (!hasMethod) {
                    logger.warn(new IllegalStateException("The methodName " + inv.getMethodName()
                             + " not found in callback service interface ,invoke will be ignored."
                             + " please update the api interface. url is:"
                             + invoker.getUrl()) + " ,invocation is :" + inv);
                    return null;
                }
            }
            RpcContext.getContext().setRemoteAddress(channel.getRemoteAddress());
            Result result = invoker.invoke(inv);
            return result.thenApply(Function.identity());
        }

        @Override
        public void received(Channel channel, Object message) throws RemotingException {
            if (message instanceof Invocation) {
                reply((ExchangeChannel) channel, message);

            } else {
                super.received(channel, message);
            }
        }

        @Override
        public void connected(Channel channel) throws RemotingException {
            invoke(channel, ON_CONNECT_KEY);
        }

        @Override
        public void disconnected(Channel channel) throws RemotingException {
            if (logger.isDebugEnabled()) {
                logger.debug("disconnected from " + channel.getRemoteAddress() + ",url:" + channel.getUrl());
            }
            invoke(channel, ON_DISCONNECT_KEY);
        }

        private void invoke(Channel channel, String methodKey) {
            Invocation invocation = createInvocation(channel, channel.getUrl(), methodKey);
            if (invocation != null) {
                try {
                    received(channel, invocation);
                } catch (Throwable t) {
                    logger.warn("Failed to invoke event method " + invocation.getMethodName() + "(), cause: " + t.getMessage(), t);
                }
            }
        }

        private Invocation createInvocation(Channel channel, URL url, String methodKey) {
            String method = url.getParameter(methodKey);
            if (method == null || method.length() == 0) {
                return null;
            }

            RpcInvocation invocation = new RpcInvocation(method, url.getParameter(INTERFACE_KEY), new Class<?>[0], new Object[0]);
            invocation.setAttachment(PATH_KEY, url.getPath());
            invocation.setAttachment(GROUP_KEY, url.getParameter(GROUP_KEY));
            invocation.setAttachment(INTERFACE_KEY, url.getParameter(INTERFACE_KEY));
            invocation.setAttachment(VERSION_KEY, url.getParameter(VERSION_KEY));
            if (url.getParameter(STUB_EVENT_KEY, false)) {
                invocation.setAttachment(STUB_EVENT_KEY, Boolean.TRUE.toString());
            }

            return invocation;
        }
    };

    public DubboExProtocol() {
        INSTANCE = this;
    }

    public static DubboExProtocol getDubboProtocol() {
        if (INSTANCE == null) {
            ExtensionLoader.getExtensionLoader(Protocol.class).getExtension(DubboExProtocol.NAME); // [3]
        }

        return INSTANCE;
    }

    public Collection<Exporter<?>> getExporters() {
        return Collections.unmodifiableCollection(exporterMap.values());
    }

    private boolean isClientSide(Channel channel) {
        InetSocketAddress address = channel.getRemoteAddress();
        URL url = channel.getUrl();
        return url.getPort() == address.getPort() &&
                 NetUtils.filterLocalHost(channel.getUrl().getIp())
                          .equals(NetUtils.filterLocalHost(address.getAddress().getHostAddress()));
    }

    Invoker<?> getInvoker(Channel channel, Invocation inv) throws RemotingException {
        boolean isCallBackServiceInvoke = false;
        boolean isStubServiceInvoke = false;
        int port = channel.getLocalAddress().getPort();
        String path = inv.getAttachments().get(PATH_KEY);

        // if it's callback service on client side
        isStubServiceInvoke = Boolean.TRUE.toString().equals(inv.getAttachments().get(STUB_EVENT_KEY));
        if (isStubServiceInvoke) {
            port = channel.getRemoteAddress().getPort();
        }

        //callback
        isCallBackServiceInvoke = isClientSide(channel) && !isStubServiceInvoke;
        if (isCallBackServiceInvoke) {
            path += "." + inv.getAttachments().get(CALLBACK_SERVICE_KEY);
            inv.getAttachments().put(IS_CALLBACK_SERVICE_INVOKE, Boolean.TRUE.toString());
        }

        String serviceKey = serviceKey(port, path, inv.getAttachments().get(VERSION_KEY), inv.getAttachments().get(GROUP_KEY));
        DubboExporter<?> exporter = (DubboExporter<?>) exporterMap.get(serviceKey);

        if (exporter == null) {
            throw new RemotingException(channel, "Not found exported service: " + serviceKey + " in " + exporterMap.keySet() + ", may be version or group mismatch " +
                     ", channel: consumer: " + channel.getRemoteAddress() + " --> provider: " + channel.getLocalAddress() + ", message:" + getInvocationWithoutData(inv));
        }

        return exporter.getInvoker();
    }

    public Collection<Invoker<?>> getInvokers() {
        return Collections.unmodifiableCollection(invokers);
    }

    @Override
    public int getDefaultPort() {
        return DEFAULT_PORT;
    }

    @Override
    public <T> Exporter<T> export(Invoker<T> invoker) throws RpcException {
        URL url = invoker.getUrl();

        // export service.
        String key = serviceKey(url);
        DubboExporter<T> exporter = new DubboExporter<T>(invoker, key, exporterMap);
        exporterMap.put(key, exporter);

        //export an stub service for dispatching event
        Boolean isStubSupportEvent = url.getParameter(STUB_EVENT_KEY, DEFAULT_STUB_EVENT);
        Boolean isCallbackservice = url.getParameter(IS_CALLBACK_SERVICE, false);
        if (isStubSupportEvent && !isCallbackservice) {
            String stubServiceMethods = url.getParameter(STUB_EVENT_METHODS_KEY);
            if (stubServiceMethods == null || stubServiceMethods.length() == 0) {
                if (logger.isWarnEnabled()) {
                    logger.warn(new IllegalStateException("consumer [" + url.getParameter(INTERFACE_KEY) +
                             "], has set stubproxy support event ,but no stub methods founded."));
                }

            } else {
                stubServiceMethodsMap.put(url.getServiceKey(), stubServiceMethods);
            }
        }

        openServer(url);
        optimizeSerialization(url);

        return exporter;
    }

    private void openServer(URL url) {
        // find server.
        String key = url.getAddress();
        //client can export a service which's only for server to invoke
        boolean isServer = url.getParameter(IS_SERVER_KEY, true);
        if (isServer) {
            ProtocolServer server = serverMap.get(key);
            if (server == null) {
                synchronized (this) {
                    server = serverMap.get(key);
                    if (server == null) {
                        serverMap.put(key, createServer(url));
                    }
                }
            } else {
                // server supports reset, use together with override
                server.reset(url);
            }
        }
    }

    private ProtocolServer createServer(URL url) {
        url = URLBuilder.from(url)
                 // send readonly event when server closes, it's enabled by default
                 .addParameterIfAbsent(CHANNEL_READONLYEVENT_SENT_KEY, Boolean.TRUE.toString())
                 // enable heartbeat by default
                 .addParameterIfAbsent(HEARTBEAT_KEY, String.valueOf(DEFAULT_HEARTBEAT))
                 .addParameter(CODEC_KEY, "dubboex") // [4]
                 .build();
        String str = url.getParameter(SERVER_KEY, DEFAULT_REMOTING_SERVER);

        if (str != null && str.length() > 0 && !ExtensionLoader.getExtensionLoader(Transporter.class).hasExtension(str)) {
            throw new RpcException("Unsupported server type: " + str + ", url: " + url);
        }

        ExchangeServer server;
        try {
            server = Exchangers.bind(url, requestHandler);
        } catch (RemotingException e) {
            throw new RpcException("Fail to start server(url: " + url + ") " + e.getMessage(), e);
        }

        str = url.getParameter(CLIENT_KEY);
        if (str != null && str.length() > 0) {
            Set<String> supportedTypes = ExtensionLoader.getExtensionLoader(Transporter.class).getSupportedExtensions();
            if (!supportedTypes.contains(str)) {
                throw new RpcException("Unsupported client type: " + str);
            }
        }

        return new DubboProtocolServer(server);
    }

    private void optimizeSerialization(URL url) throws RpcException {
        String className = url.getParameter(OPTIMIZER_KEY, "");
        if (StringUtils.isEmpty(className) || optimizers.contains(className)) {
            return;
        }

        logger.info("Optimizing the serialization process for Kryo, FST, etc...");

        try {
            Class clazz = Thread.currentThread().getContextClassLoader().loadClass(className);
            if (!SerializationOptimizer.class.isAssignableFrom(clazz)) {
                throw new RpcException("The serialization optimizer " + className + " isn't an instance of " + SerializationOptimizer.class.getName());
            }

            SerializationOptimizer optimizer = (SerializationOptimizer) clazz.newInstance();

            if (optimizer.getSerializableClasses() == null) {
                return;
            }

            for (Class c : optimizer.getSerializableClasses()) {
                SerializableClassRegistry.registerClass(c);
            }

            optimizers.add(className);

        } catch (ClassNotFoundException e) {
            throw new RpcException("Cannot find the serialization optimizer class: " + className, e);

        } catch (InstantiationException e) {
            throw new RpcException("Cannot instantiate the serialization optimizer class: " + className, e);

        } catch (IllegalAccessException e) {
            throw new RpcException("Cannot instantiate the serialization optimizer class: " + className, e);
        }
    }

    @Override
    public <T> Invoker<T> protocolBindingRefer(Class<T> serviceType, URL url) throws RpcException {
        optimizeSerialization(url);

        // create rpc invoker.
        DubboInvoker<T> invoker = new DubboInvoker<T>(serviceType, url, getClients(url), invokers);
        invokers.add(invoker);

        return invoker;
    }

    private ExchangeClient[] getClients(URL url) {
        // whether to share connection

        boolean useShareConnect = false;

        int connections = url.getParameter(CONNECTIONS_KEY, 0);
        List<ReferenceCountExchangeClient> shareClients = null;
        // if not configured, connection is shared, otherwise, one connection for one service
        if (connections == 0) {
            useShareConnect = true;

            /**
             * The xml configuration should have a higher priority than properties.
             */
            String shareConnectionsStr = url.getParameter(SHARE_CONNECTIONS_KEY, (String) null);
            connections = Integer.parseInt(StringUtils.isBlank(shareConnectionsStr) ? ConfigUtils.getProperty(SHARE_CONNECTIONS_KEY,
                     DEFAULT_SHARE_CONNECTIONS) : shareConnectionsStr);
            shareClients = getSharedClient(url, connections);
        }

        ExchangeClient[] clients = new ExchangeClient[connections];
        for (int i = 0; i < clients.length; i++) {
            if (useShareConnect) {
                clients[i] = shareClients.get(i);

            } else {
                clients[i] = initClient(url);
            }
        }

        return clients;
    }

    /**
     * Get shared connection
     *
     * @param url
     * @param connectNum connectNum must be greater than or equal to 1
     */
    private List<ReferenceCountExchangeClient> getSharedClient(URL url, int connectNum) {
        String key = url.getAddress();
        List<ReferenceCountExchangeClient> clients = referenceClientMap.get(key);

        if (checkClientCanUse(clients)) {
            batchClientRefIncr(clients);
            return clients;
        }

        locks.putIfAbsent(key, new Object());
        synchronized (locks.get(key)) {
            clients = referenceClientMap.get(key);
            // double check
            if (checkClientCanUse(clients)) {
                batchClientRefIncr(clients);
                return clients;
            }

            // connectNum must be greater than or equal to 1
            connectNum = Math.max(connectNum, 1);

            // If the client list is empty, this is the first initialization
            if (CollectionUtils.isEmpty(clients)) {
                clients = buildReferenceCountExchangeClientList(url, connectNum);
                referenceClientMap.put(key, clients);

            } else {
                for (int i = 0; i < clients.size(); i++) {
                    ReferenceCountExchangeClient referenceCountExchangeClient = clients.get(i);
                    // If a client in the list is no longer available, create a new one to replace it.
                    if (referenceCountExchangeClient == null || referenceCountExchangeClient.isClosed()) {
                        clients.set(i, buildReferenceCountExchangeClient(url));
                        continue;
                    }

                    referenceCountExchangeClient.incrementAndGetCount();
                }
            }

            /**
             * I understand that the purpose of the remove operation here is to avoid the expired url key
             * always occupying this memory space.
             */
            locks.remove(key);

            return clients;
        }
    }

    /**
     * Check if the client list is all available
     *
     * @param referenceCountExchangeClients
     * @return true-available,false-unavailable
     */
    private boolean checkClientCanUse(List<ReferenceCountExchangeClient> referenceCountExchangeClients) {
        if (CollectionUtils.isEmpty(referenceCountExchangeClients)) {
            return false;
        }

        for (ReferenceCountExchangeClient referenceCountExchangeClient : referenceCountExchangeClients) {
            // As long as one client is not available, you need to replace the unavailable client with the available one.
            if (referenceCountExchangeClient == null || referenceCountExchangeClient.isClosed()) {
                return false;
            }
        }

        return true;
    }

    /**
     * Increase the reference count when a new invoker shares the same connection; the connection is closed once no references remain.
     *
     * @param referenceCountExchangeClients
     */
    private void batchClientRefIncr(List<ReferenceCountExchangeClient> referenceCountExchangeClients) {
        if (CollectionUtils.isEmpty(referenceCountExchangeClients)) {
            return;
        }

        for (ReferenceCountExchangeClient referenceCountExchangeClient : referenceCountExchangeClients) {
            if (referenceCountExchangeClient != null) {
                referenceCountExchangeClient.incrementAndGetCount();
            }
        }
    }

    /**
     * Bulk build client
     *
     * @param url
     * @param connectNum
     * @return
     */
    private List<ReferenceCountExchangeClient> buildReferenceCountExchangeClientList(URL url, int connectNum) {
        List<ReferenceCountExchangeClient> clients = new ArrayList<>();

        for (int i = 0; i < connectNum; i++) {
            clients.add(buildReferenceCountExchangeClient(url));
        }

        return clients;
    }

    /**
     * Build a single client
     *
     * @param url
     * @return
     */
    private ReferenceCountExchangeClient buildReferenceCountExchangeClient(URL url) {
        ExchangeClient exchangeClient = initClient(url);

        return new ReferenceCountExchangeClient(exchangeClient);
    }

    /**
     * Create new connection
     *
     * @param url
     */
    private ExchangeClient initClient(URL url) {

        // client type setting.
        String str = url.getParameter(CLIENT_KEY, url.getParameter(SERVER_KEY, DEFAULT_REMOTING_CLIENT));

        url = url.addParameter(CODEC_KEY, "dubboex"); // [5]
        // enable heartbeat by default
        url = url.addParameterIfAbsent(HEARTBEAT_KEY, String.valueOf(DEFAULT_HEARTBEAT));

        // BIO is not allowed since it has severe performance issue.
        if (str != null && str.length() > 0 && !ExtensionLoader.getExtensionLoader(Transporter.class).hasExtension(str)) {
            throw new RpcException("Unsupported client type: " + str + "," +
                     " supported client type is " + StringUtils.join(ExtensionLoader.getExtensionLoader(Transporter.class).getSupportedExtensions(), " "));
        }

        ExchangeClient client;
        try {
            client = Exchangers.connect(url, requestHandler);
        } catch (RemotingException e) {
            throw new RpcException("Fail to create remoting client for service(" + url + "): " + e.getMessage(), e);
        }

        return client;
    }

    @Override
    public void destroy() {
        for (String key : new ArrayList<>(serverMap.keySet())) {
            ProtocolServer protocolServer = serverMap.remove(key);

            if (protocolServer == null) {
                continue;
            }

            RemotingServer server = protocolServer.getRemotingServer();

            try {
                if (logger.isInfoEnabled()) {
                    logger.info("Close dubbo server: " + server.getLocalAddress());
                }

                server.close(ConfigurationUtils.getServerShutdownTimeout());

            } catch (Throwable t) {
                logger.warn(t.getMessage(), t);
            }
        }

        for (String key : new ArrayList<>(referenceClientMap.keySet())) {
            List<ReferenceCountExchangeClient> clients = referenceClientMap.remove(key);

            if (CollectionUtils.isEmpty(clients)) {
                continue;
            }

            for (ReferenceCountExchangeClient client : clients) {
                closeReferenceCountExchangeClient(client);
            }
        }

        stubServiceMethodsMap.clear();
        super.destroy();
    }

    /**
     * close ReferenceCountExchangeClient
     *
     * @param client
     */
    private void closeReferenceCountExchangeClient(ReferenceCountExchangeClient client) {
        if (client == null) {
            return;
        }

        try {
            if (logger.isInfoEnabled()) {
                logger.info("Close dubbo connect: " + client.getLocalAddress() + "-->" + client.getRemoteAddress());
            }

            client.close(ConfigurationUtils.getServerShutdownTimeout());

            // TODO
            /**
             * At this time, ReferenceCountExchangeClient#client has been replaced with LazyConnectExchangeClient.
             * Do you need to call client.close again to ensure that LazyConnectExchangeClient is also closed?
             */

        } catch (Throwable t) {
            logger.warn(t.getMessage(), t);
        }
    }

    /**
     * Only log the body in debug mode, for size and security considerations.
     *
     * @param invocation
     * @return
     */
    private Invocation getInvocationWithoutData(Invocation invocation) {
        if (logger.isDebugEnabled()) {
            return invocation;
        }
        if (invocation instanceof RpcInvocation) {
            RpcInvocation rpcInvocation = (RpcInvocation) invocation;
            rpcInvocation.setArguments(null);
            return rpcInvocation;
        }
        return invocation;
    }
}
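
The lines marked [4] and [5] are where the codec is pinned to dubboex: the codec parameter is always overwritten, while the heartbeat parameter only receives a default when none is configured. The following is a minimal sketch of that override-versus-default behaviour. `UrlParams` and `CodecOverrideDemo` are hypothetical classes written for illustration, not Dubbo's URL API, and the key strings and the 60000 default are assumptions.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in for an immutable URL parameter set; not Dubbo's URL class.
final class UrlParams {
    private final Map<String, String> params;

    UrlParams(Map<String, String> params) {
        this.params = Collections.unmodifiableMap(new LinkedHashMap<>(params));
    }

    // Always sets the value, replacing any existing one (as the codec lines do).
    UrlParams addParameter(String key, String value) {
        Map<String, String> copy = new LinkedHashMap<>(params);
        copy.put(key, value);
        return new UrlParams(copy);
    }

    // Sets the value only when the key is missing (as the heartbeat lines do).
    UrlParams addParameterIfAbsent(String key, String value) {
        if (params.containsKey(key)) {
            return this;
        }
        return addParameter(key, value);
    }

    String get(String key) {
        return params.get(key);
    }
}

class CodecOverrideDemo {
    public static void main(String[] args) {
        Map<String, String> configured = new LinkedHashMap<>();
        configured.put("codec", "dubbo");      // assumed user-supplied value
        configured.put("heartbeat", "5000");   // assumed user-supplied value

        UrlParams url = new UrlParams(configured)
                .addParameterIfAbsent("heartbeat", "60000")
                .addParameter("codec", "dubboex");

        // The codec is forced to dubboex; the configured heartbeat is kept.
        System.out.println(url.get("codec") + " " + url.get("heartbeat"));
    }
}
```

Under these assumptions, a codec configured by the user cannot displace dubboex, which is why the copied protocol class always uses the new codec.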

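With the two classes above in place, two more files are needed before the program can run. They tell Dubbo's SPI mechanism which classes implement the protocol and the codec. Each file is named after the interface it extends and is placed in a directory that Dubbo's SPI loader scans, for example META-INF/dubbo/.
The file org.apache.dubbo.rpc.Protocol contains: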

dubboex=com.dubbo.protocol.DubboExProtocol

The file org.apache.dubbo.remoting.Codec2 contains:

dubboex=com.dubbo.protocol.DubboexCodec
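
Each of these files maps an extension name to its implementation class, one `name=class` pair per line. The following is a minimal sketch of how such a file can be read into a lookup table. `SpiFileSketch` and `parse` are hypothetical names for illustration only; in a running application the lookup goes through Dubbo's own extension loader, not through this code.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper for illustration only; not part of Dubbo.
class SpiFileSketch {

    // Parses "name=implementationClass" lines, skipping blank lines and '#' comments.
    static Map<String, String> parse(String fileContent) {
        Map<String, String> extensions = new LinkedHashMap<>();
        for (String rawLine : fileContent.split("\n")) {
            String line = rawLine;
            int hash = line.indexOf('#');
            if (hash >= 0) {
                line = line.substring(0, hash); // drop trailing comment
            }
            line = line.trim();
            if (line.isEmpty()) {
                continue;
            }
            int eq = line.indexOf('=');
            if (eq <= 0) {
                continue; // this sketch ignores entries without a name
            }
            extensions.put(line.substring(0, eq).trim(), line.substring(eq + 1).trim());
        }
        return extensions;
    }

    public static void main(String[] args) {
        Map<String, String> codecs = parse("dubboex=com.dubbo.protocol.DubboexCodec\n");
        // Looks up the implementation class registered under the name "dubboex".
        System.out.println(codecs.get("dubboex"));
    }
}
```

The name on the left side of `=` is the value that the codec parameter and dubbo.protocol.name refer to, so it has to match in both files and in the configuration.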

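Because DubboExProtocol is a copy of DubboProtocol, and DubboProtocol uses the ReferenceCountExchangeClient class (which is package-private in Dubbo and therefore not visible from com.dubbo.protocol), ReferenceCountExchangeClient must also be copied into the application.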
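
The copied class matters because the protocol shares one connection among several invokers and only closes it when the last user releases it. The following is a simplified sketch of that reference-counting idea. `RefCountedConnection` is a hypothetical class for illustration and does not reproduce the real ReferenceCountExchangeClient.

```java
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical, simplified illustration of a reference-counted connection wrapper.
class RefCountedConnection {
    private final AtomicInteger refCount = new AtomicInteger(0);
    private volatile boolean closed = false;

    // The creator holds the first reference.
    RefCountedConnection() {
        refCount.incrementAndGet();
    }

    // Called when another invoker starts sharing this connection.
    void retain() {
        refCount.incrementAndGet();
    }

    // Releases one reference; the connection is closed only when none remain.
    void close() {
        if (refCount.decrementAndGet() <= 0) {
            closed = true;
        }
    }

    boolean isClosed() {
        return closed;
    }

    public static void main(String[] args) {
        RefCountedConnection conn = new RefCountedConnection();
        conn.retain();                       // a second invoker shares the connection
        conn.close();                        // first invoker is done; still one user left
        System.out.println(conn.isClosed());
        conn.close();                        // last user is done; connection closes
        System.out.println(conn.isClosed());
    }
}
```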
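Finally, set dubbo.protocol.name=dubboex in the Spring configuration file.
Place the classes and files above in both the consumer and the provider. Dubbo can then be started, and service calls will use the dubboex protocol we created.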

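The program above still has a problem: messages are sent and parsed correctly, but deserialization fails, so service calls do not succeed. This will be fixed in a later revision.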

Reference:
https://blog.csdn.net/u014269285/article/details/86360462
