手写RPC框架--基于Netty和ZooKeeper

1.主要功能

  • 注册中心:使用ZooKeeper作为注册中心,供服务提供者注册(暴露)服务,供服务消费者发现(拉取)服务地址。
  • 服务提供者:将Server端的服务发布到注册中心,为消费者提供服务。
  • 服务消费者:通过代理对象(JDK动态代理)代理服务,封装请求信息,从注册中心拉取服务列表,使
    用负载均衡算法选取服务。
  • 网络传输:基于Netty的NIO网络通信,代替传统BIO模型以提升传输性能,并提供心跳机制确保长连接的有
    效性(注:当前代码中客户端每次调用都会新建连接,收到响应后即关闭;客户端的IdleStateHandler已被注释掉,服务端仅在5秒读空闲时打印日志)。
  • 负载均衡:实现了随机法、轮询法、一致性Hash算法。采用FNV1_32_HASH算法代替JDK的hashCode
    方法,使哈希散列更均匀。
  • 编码解码:自定义编码解码器,向消息头中添加消息长度以解决粘包/半包问题,实现了JDK、Json两种序列化方式。

2.架构图

(此处原为架构图图片,未能保留)

3.目录结构

(此处原为目录结构截图,共三张,未能保留)

4.代码

客户端

package client;

import common.RPCRequest;
import common.RPCResponse;
import lombok.AllArgsConstructor;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

/**
 * JDK dynamic-proxy handler that turns interface method calls into RPC requests.
 *
 * <p>Every call on a proxy created by {@link #getProxy(Class)} is packaged into an
 * {@link RPCRequest} (declaring interface name, method name, parameter types, arguments)
 * and sent through the wrapped {@link RPCClient}. The {@code data} field of the response
 * is returned to the caller.
 */
@AllArgsConstructor
public class ClientProxy implements InvocationHandler {
    private RPCClient client;

    /**
     * Handles one invocation on the proxy.
     *
     * <p>{@code equals}, {@code hashCode} and {@code toString} (the {@link Object} methods a
     * JDK proxy dispatches) are answered locally. Forwarding them would produce a request for
     * {@code java.lang.Object}, which no provider registers.
     *
     * @return the {@code data} field of the RPC response
     * @throws IllegalStateException if the client returned no response
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        if (method.getDeclaringClass() == Object.class) {
            return invokeObjectMethod(proxy, method, args);
        }
        RPCRequest request = RPCRequest.builder()
                .interfaceName(method.getDeclaringClass().getName())
                .methodName(method.getName())
                .paramsTypes(method.getParameterTypes())
                .params(args).build();
        RPCResponse response = client.sendRequest(request);
        if (response == null) {
            // The transport signals failure with null. Fail with context instead of an NPE.
            throw new IllegalStateException("No RPC response for "
                    + method.getDeclaringClass().getName() + "#" + method.getName());
        }
        return response.getData();
    }

    /** Local, identity-based implementations of the Object methods a proxy dispatches. */
    private Object invokeObjectMethod(Object proxy, Method method, Object[] args) {
        switch (method.getName()) {
            case "equals":
                return proxy == args[0];
            case "hashCode":
                return System.identityHashCode(proxy);
            case "toString":
                return "RPCProxy@" + Integer.toHexString(System.identityHashCode(proxy));
            default:
                throw new UnsupportedOperationException(method.getName());
        }
    }

    /**
     * Creates a proxy implementing {@code tClass} whose calls are executed remotely.
     *
     * @param tClass the service interface
     * @return a proxy instance of {@code tClass}
     */
    <T> T getProxy(Class<T> tClass) {
        Object proxy = Proxy.newProxyInstance(tClass.getClassLoader(), new Class<?>[]{tClass}, this);
        return tClass.cast(proxy);
    }
}

package client;

import common.RPCRequest;
import common.RPCResponse;

/**
 * Transport abstraction used by the client-side proxy. It ships one request to a remote
 * provider and hands back the reply.
 */
public interface RPCClient {

    /**
     * Sends {@code request} and waits for its reply.
     *
     * @param request the populated RPC call description
     * @return the provider's response; may be {@code null} if the call could not be completed
     */
    RPCResponse sendRequest(RPCRequest request);
}

package client.netty;

import client.RPCClient;
import common.RPCRequest;
import common.RPCResponse;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.AttributeKey;
import register.ServiceRegister;
import register.ZkServiceRegister;

import java.net.InetSocketAddress;

/**
 * {@link RPCClient} that sends each request over a fresh Netty connection.
 *
 * <p>For every call, the provider address is looked up in the service registry, a new
 * connection is opened and the request is written. The method then blocks until the
 * channel is closed. The client handler stores the decoded response under the channel
 * attribute {@code "RPCResponse"} and closes the channel.
 *
 * <p>The {@link Bootstrap} and its event-loop group are static and shared by all instances.
 * The group is never shut down here.
 */
public class NettyClient implements RPCClient {
    private static final Bootstrap BOOTSTRAP;
    private static final EventLoopGroup EVENT_LOOP_GROUP;
    /** Attribute under which the client handler publishes the decoded response. */
    private static final AttributeKey<RPCResponse> RESPONSE_KEY = AttributeKey.valueOf("RPCResponse");
    private ServiceRegister serviceRegister;

    public NettyClient(){
        serviceRegister = new ZkServiceRegister();
    }

    static {
        EVENT_LOOP_GROUP = new NioEventLoopGroup();
        BOOTSTRAP = new Bootstrap();
        BOOTSTRAP.group(EVENT_LOOP_GROUP).channel(NioSocketChannel.class)
                .handler(new NettyClientInitializer());
    }

    /**
     * Discovers a provider for the request's interface, sends the request and waits for the reply.
     *
     * @return the response, or {@code null} if no provider was found, the write failed, or the
     *     thread was interrupted
     */
    @Override
    public RPCResponse sendRequest(RPCRequest request) {
        InetSocketAddress address = serviceRegister.serviceDiscovery(request.getInterfaceName());
        if (address == null) {
            // Discovery failed. Report it like the other failure path instead of an NPE.
            System.out.println("no provider found for " + request.getInterfaceName());
            return null;
        }
        // Locals, not fields, so concurrent calls on one client cannot overwrite each other's target.
        // getHostString() returns the registered host text without a reverse-DNS lookup.
        String host = address.getHostString();
        int port = address.getPort();
        try {
            ChannelFuture channelFuture = BOOTSTRAP.connect(host, port).sync();
            Channel channel = channelFuture.channel();
            channel.writeAndFlush(request).addListener(future -> {
                // A failed write would otherwise leave closeFuture().sync() waiting forever.
                if (!future.isSuccess()) {
                    future.cause().printStackTrace();
                    channel.close();
                }
            });
            // NOTE(review): blocks until the channel closes; there is no timeout if the
            // provider never answers.
            channel.closeFuture().sync();
            RPCResponse response = channel.attr(RESPONSE_KEY).get();
            System.out.println(response);
            return response;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
            return null;
        }
    }
}

package client.netty;

import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import serializer.JsonSerializer;
import serializer.MyDecode;
import serializer.MyEncode;

/**
 * Builds the client-side pipeline for each new connection.
 *
 * <p>Inbound bytes are decoded by {@link MyDecode} and delivered to
 * {@link NettyRPCClientHandler}. Outbound messages are framed by {@link MyEncode} with JSON
 * serialization. No {@code IdleStateHandler} is installed, so this pipeline produces no
 * idle-state events.
 */
public class NettyClientInitializer extends ChannelInitializer<SocketChannel> {
    @Override
    protected void initChannel(SocketChannel ch) throws Exception {
        ch.pipeline()
                .addLast(new MyDecode())
                .addLast(new MyEncode(new JsonSerializer()))
                .addLast(new NettyRPCClientHandler());
    }
}

package client.netty;

import common.RPCResponse;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.ChannelInputShutdownEvent;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.util.AttributeKey;

/**
 * Final inbound handler of the client pipeline.
 *
 * <p>Stores each decoded {@link RPCResponse} in the channel attribute {@code "RPCResponse"}
 * and then closes the channel.
 */
public class NettyRPCClientHandler extends SimpleChannelInboundHandler<RPCResponse> {
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RPCResponse response) throws Exception {
        AttributeKey<RPCResponse> responseKey = AttributeKey.valueOf("RPCResponse");
        ctx.channel().attr(responseKey).set(response);
        // Closing the channel signals the waiting caller that the response is available.
        ctx.channel().close();
    }

    /** Prints the failure and closes the connection. */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        cause.printStackTrace();
        ctx.close();
    }

    /**
     * Handles writer-idle events by sending a heartbeat, logs input shutdown, and forwards any
     * other event.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent) {
            IdleState idleState = ((IdleStateEvent) evt).state();
            if (idleState == IdleState.WRITER_IDLE) {
                System.out.println("写空闲");
                // NOTE(review): a String is neither RPCRequest nor RPCResponse; confirm the
                // encoder in this pipeline can frame it before enabling idle detection.
                ctx.writeAndFlush("heart beat");
            }
            return;
        }
        if (evt instanceof ChannelInputShutdownEvent) {
            System.out.println("连接关闭");
            return;
        }
        super.userEventTriggered(ctx, evt);
    }
}

服务端

package server;

/**
 * Lifecycle contract of an RPC server.
 */
public interface RPCServer {

    /**
     * Starts the server on the given port.
     *
     * @param port local port to serve on
     */
    void start(int port);

    /** Stops the server. */
    void stop();
}

package server.netty;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import lombok.AllArgsConstructor;
import server.RPCServer;
import server.ServiceProvider;

/**
 * {@link RPCServer} backed by a Netty {@link ServerBootstrap}.
 *
 * <p>{@link #start(int)} binds the port and blocks the calling thread until the server channel
 * closes. It then shuts down both event-loop groups. {@link #stop()}, called from another
 * thread, closes that channel and so makes {@code start} return.
 */
public class NettyServer implements RPCServer {
    private final ServiceProvider serviceProvider;
    /** Bind result while {@link #start(int)} is serving; {@code null} otherwise. */
    private volatile ChannelFuture bindFuture;

    /**
     * @param serviceProvider source of the service implementations requests are dispatched to
     */
    public NettyServer(ServiceProvider serviceProvider) {
        this.serviceProvider = serviceProvider;
    }

    @Override
    public void start(int port) {
        NioEventLoopGroup bossGroup = new NioEventLoopGroup();
        NioEventLoopGroup workGroup = new NioEventLoopGroup();
        System.out.println("netty服务器启动=====>");
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(bossGroup, workGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new NettyServerInitializer(serviceProvider));
            ChannelFuture channelFuture = bootstrap.bind(port).sync();
            bindFuture = channelFuture;
            channelFuture.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } finally {
            bindFuture = null;
            bossGroup.shutdownGracefully();
            workGroup.shutdownGracefully();
        }
    }

    /**
     * Closes the bound server channel, which lets {@link #start(int)} release its resources and
     * return. Does nothing if the server is not currently bound.
     */
    @Override
    public void stop() {
        ChannelFuture current = bindFuture;
        if (current != null) {
            current.channel().close();
        }
    }
}

package server.netty;

import common.RPCRequest;
import common.RPCResponse;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import lombok.AllArgsConstructor;
import server.ServiceProvider;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

/**
 * Final inbound handler of the server pipeline. It executes each decoded {@link RPCRequest}
 * against the locally registered service and writes back an {@link RPCResponse}.
 */
@AllArgsConstructor
public class NettyRPCServerHandler extends SimpleChannelInboundHandler<RPCRequest> {
    private ServiceProvider serviceProvider;

    /** Invokes the requested method and writes the response; the connection is left open. */
    @Override
    protected void channelRead0(ChannelHandlerContext channelHandlerContext, RPCRequest request) throws Exception {
        RPCResponse response = getResponse(request);
        channelHandlerContext.writeAndFlush(response);
    }

    /** Prints the failure and closes the connection. */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        cause.printStackTrace();
        ctx.close();
    }

    /** Logs reader-idle events; other user events are forwarded. */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if(evt instanceof IdleStateEvent){
            IdleStateEvent event = (IdleStateEvent) evt;
            if(event.state() == IdleState.READER_IDLE){
                System.out.println("server已经5秒没读到数据");
            }
        }else{
            super.userEventTriggered(ctx, evt);
        }
    }

    /**
     * Looks up the service registered under the request's interface name and reflectively
     * invokes the requested method on it.
     *
     * @return {@code RPCResponse.success(result)} on success. Returns {@code RPCResponse.error()}
     *     when no service is registered, the method cannot be resolved, the arguments do not
     *     match, or the service method throws.
     */
    RPCResponse getResponse(RPCRequest request){
        Object service = serviceProvider.getService(request.getInterfaceName());
        if (service == null) {
            // An NPE here would reach exceptionCaught, which closes the connection without any
            // reply. Answer with an error response instead.
            System.out.println("no service registered for " + request.getInterfaceName());
            return RPCResponse.error();
        }
        try {
            Method method = service.getClass().getMethod(request.getMethodName(), request.getParamsTypes());
            Object data = method.invoke(service, request.getParams());
            return RPCResponse.success(data);
        } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException
                 | IllegalArgumentException e) {
            // IllegalArgumentException: deserialized arguments do not fit the parameter types.
            e.printStackTrace();
            return RPCResponse.error();
        }
    }
}

package server.netty;

import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.timeout.IdleStateHandler;
import lombok.AllArgsConstructor;
import serializer.JsonSerializer;
import serializer.MyDecode;
import serializer.MyEncode;
import server.ServiceProvider;

/**
 * Builds the server-side pipeline for every accepted connection. The pipeline contains the
 * frame decoder, the JSON frame encoder, a reader-idle detector, and the request handler
 * bound to {@code serviceProvider}.
 */
@AllArgsConstructor
public class NettyServerInitializer extends ChannelInitializer<SocketChannel> {
    private ServiceProvider serviceProvider;

    @Override
    protected void initChannel(SocketChannel ch) throws Exception {
        ch.pipeline()
                .addLast(new MyDecode())
                .addLast(new MyEncode(new JsonSerializer()))
                // reader idle after 5 seconds; writer-idle and all-idle detection disabled (0)
                .addLast(new IdleStateHandler(5, 0, 0))
                .addLast(new NettyRPCServerHandler(serviceProvider));
    }
}

package server;

import register.ServiceRegister;
import register.ZkServiceRegister;

import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.Map;

/**
 * Server-side table of local service implementations.
 *
 * <p>A service is indexed by the fully qualified names of the interfaces its class directly
 * implements. Each such interface is also registered in ZooKeeper with this provider's
 * {@code ip} and {@code port}.
 */
public class ServiceProvider {
    private final Map<String, Object> interfaceProvider = new HashMap<>();
    private final ServiceRegister serviceRegister;
    private final String ip;
    private final Integer port;

    /**
     * @param ip   address advertised to consumers
     * @param port port advertised to consumers
     */
    public ServiceProvider(String ip, Integer port) {
        this.ip = ip;
        this.port = port;
        this.serviceRegister = new ZkServiceRegister();
    }

    /**
     * Registers {@code service} locally and in the registry under every interface its class
     * directly implements.
     */
    public void addService(Object service) {
        for (Class<?> contract : service.getClass().getInterfaces()) {
            String contractName = contract.getName();
            interfaceProvider.put(contractName, service);
            serviceRegister.register(contractName, new InetSocketAddress(ip, port));
        }
    }

    /** Returns the implementation registered for {@code interfaceName}, or {@code null}. */
    public Object getService(String interfaceName) {
        return interfaceProvider.get(interfaceName);
    }
}

注册中心

package register;

import java.net.InetSocketAddress;

/**
 * Service registry abstraction. Providers publish their addresses; consumers ask for one.
 */
public interface ServiceRegister {

    /**
     * Publishes {@code address} as a provider of {@code serviceName}.
     *
     * @param serviceName service identifier (the interface name in this project)
     * @param address     provider address to advertise
     */
    void register(String serviceName, InetSocketAddress address);

    /**
     * Chooses one provider address for {@code serviceName}.
     *
     * @param serviceName service identifier
     * @return the chosen provider address
     */
    InetSocketAddress serviceDiscovery(String serviceName);
}

package register;

import loadbalance.ConsistentHashLoadBalance;
import loadbalance.LoadBalance;
import loadbalance.RandomLoadBalance;
import org.apache.curator.RetryPolicy;
import org.apache.curator.framework.CuratorFramework;
import org.apache.curator.framework.CuratorFrameworkFactory;
import org.apache.curator.retry.ExponentialBackoffRetry;
import org.apache.zookeeper.CreateMode;

import java.net.InetSocketAddress;
import java.util.List;

public class ZkServiceRegister implements ServiceRegister{
    private CuratorFramework client;
    private static final String ROOT_PATH = "EgoRPC";
    private LoadBalance loadBalance;

    public ZkServiceRegister(){
        loadBalance = new ConsistentHashLoadBalance();
        RetryPolicy policy = new ExponentialBackoffRetry(1000, 3);
        this.client = CuratorFrameworkFactory.builder()
                .connectString("127.0.0.1:2181")
                .sessionTimeoutMs(40000)
                .retryPolicy(policy)
                .namespace(ROOT_PATH)
                .build();
        this.client.start();
        System.out.println("zookeeper 连接成功");
    }
    @Override
    public void register(String serviceName, InetSocketAddress address) {
        try {
            if(client.checkExists().forPath("/" + serviceName) == null){
                client.create().creatingParentsIfNeeded().withMode(CreateMode.PERSISTENT).forPath("/" + serviceName);
            }
            String path = "/" + serviceName + "/" + getServiceAddress(address);
            client.create().creatingParentsIfNeeded().withMode(CreateMode.EPHEMERAL).forPath(path);
        } catch (Exception e) {
            e.printStackTrace();
            System.out.println("此服务已存在");
        }
    }

    @Override
    public InetSocketAddress serviceDiscovery(String serviceName) {
        try {
            List<String> paths = client.getChildren().forPath("/" + serviceName);
            String path = loadBalance.balance(paths, serviceName);
            return parseAddress(path);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }
    private String getServiceAddress(InetSocketAddress address){
        return address.getHostName() + ":" + address.getPort();
    }
    private InetSocketAddress parseAddress(String address){
        String[] split = address.split(":");
        return new InetSocketAddress(split[0], Integer.parseInt(split[1]));
    }
}

编码解码器与序列化器

package serializer;

import common.RPCRequest;
import common.RPCResponse;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import lombok.AllArgsConstructor;

/**
 * Encodes {@link RPCRequest} and {@link RPCResponse} objects into the frame read by the
 * matching decoder:
 * <pre>
 * | message type (2 bytes) | serializer type (2 bytes) | body length (4 bytes) | body |
 * </pre>
 */
@AllArgsConstructor
public class MyEncode extends MessageToByteEncoder<Object> {
    private Serializer serializer;

    /**
     * @throws IllegalArgumentException if {@code o} is neither an RPCRequest nor an RPCResponse.
     *     Such a frame would have no message-type header and would corrupt the stream.
     * @throws IllegalStateException if the serializer produced no bytes
     */
    @Override
    protected void encode(ChannelHandlerContext channelHandlerContext, Object o, ByteBuf byteBuf) throws Exception {
        boolean isRequest = o instanceof RPCRequest;
        if (!isRequest && !(o instanceof RPCResponse)) {
            throw new IllegalArgumentException("暂不支持此类型的数据: " + o.getClass().getName());
        }
        // Serialize before writing anything so a failure never leaves a half-written header.
        byte[] bytes = serializer.serialize(o);
        if (bytes == null) {
            throw new IllegalStateException("serializer produced no bytes for " + o.getClass().getName());
        }
        byteBuf.writeShort(isRequest ? MessageType.REQUEST.getCode() : MessageType.RESPONSE.getCode());
        byteBuf.writeShort(serializer.getType());
        byteBuf.writeInt(bytes.length);
        byteBuf.writeBytes(bytes);
    }
}

package serializer;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;

import java.util.List;

/**
 * Decodes frames of the form
 * <pre>
 * | message type (2 bytes) | serializer type (2 bytes) | body length (4 bytes) | body |
 * </pre>
 * When a header or body has not fully arrived yet, nothing is consumed.
 * {@link ByteToMessageDecoder} keeps the bytes and calls {@code decode} again when more data
 * arrives, so TCP segmentation ("half packets") does not break decoding.
 */
public class MyDecode extends ByteToMessageDecoder {
    /** messageType (2) + serializeType (2) + length (4). */
    private static final int HEADER_LENGTH = 8;

    @Override
    protected void decode(ChannelHandlerContext channelHandlerContext, ByteBuf byteBuf, List<Object> list) throws Exception {
        if (byteBuf.readableBytes() < HEADER_LENGTH) {
            return;  // header incomplete: wait for more data
        }
        byteBuf.markReaderIndex();
        short messageType = byteBuf.readShort();
        if(messageType != MessageType.REQUEST.getCode() && messageType != MessageType.RESPONSE.getCode()){
            throw new RuntimeException("暂不支持此类型的数据");
        }
        short serializeType = byteBuf.readShort();
        Serializer serializer = Serializer.getSerializer(serializeType);
        if(serializer == null){
            throw new RuntimeException("不存在对应的序列化器");
        }
        int length = byteBuf.readInt();
        if (length < 0) {
            throw new RuntimeException("invalid body length: " + length);
        }
        if (byteBuf.readableBytes() < length) {
            // Body incomplete: rewind to the frame start and wait for the rest.
            byteBuf.resetReaderIndex();
            return;
        }
        byte[] bytes = new byte[length];
        byteBuf.readBytes(bytes);
        Object obj = serializer.deserialize(bytes, messageType);
        list.add(obj);
    }
}

package serializer;

/**
 * Pluggable (de)serialization strategy. {@link #getType()} is the code written into each frame
 * header so the receiving side can select the same implementation via
 * {@link #getSerializer(int)}.
 */
public interface Serializer {

    /** Converts {@code obj} to bytes. */
    byte[] serialize(Object obj);

    /**
     * Rebuilds an object from {@code bytes}.
     *
     * @param messageType message-type code read from the frame header
     */
    Object deserialize(byte[] bytes, int messageType);

    /** Code identifying this serializer in the frame header. */
    int getType();

    /**
     * Returns a new serializer for a wire code: 0 selects {@link ObjectSerializer}, 1 selects
     * {@link JsonSerializer}. Any other code yields {@code null}.
     */
    static Serializer getSerializer(int serializeType) {
        if (serializeType == 0) {
            return new ObjectSerializer();
        }
        if (serializeType == 1) {
            return new JsonSerializer();
        }
        return null;
    }
}

package serializer;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import common.RPCRequest;
import common.RPCResponse;

public class JsonSerializer implements Serializer{
    @Override
    public byte[] serialize(Object obj) {
        return JSONObject.toJSONBytes(obj);
    }

    @Override
    public Object deserialize(byte[] bytes, int messageType) {
        Object obj = null;
        switch (messageType){
            case 0:
                RPCRequest request = JSON.parseObject(bytes, RPCRequest.class);
                Object[] objects = new Object[request.getParams().length];
                for(int i = 0; i < objects.length; i++){
                    Class<?> paramsType = request.getParamsTypes()[i];
                    if(!paramsType.isAssignableFrom(request.getParams()[i].getClass())){
                        objects[i] = JSONObject.toJavaObject((JSONObject)request.getParams()[i], request.getParamsTypes()[i]);
                    }else {
                        objects[i] = request.getParams()[i];
                    }
                }
                request.setParams(objects);
                obj = request;
                break;
            case 1:
                RPCResponse response = JSON.parseObject(bytes, RPCResponse.class);
                Class<?> dataType = response.getDataType();
                if(!dataType.isAssignableFrom(response.getData().getClass())){
                    response.setData(JSONObject.toJavaObject((JSONObject)response.getData(), dataType));
                }
                obj = response;
                break;
            default:
                throw new RuntimeException("暂不支持此类型的数据");
        }
        return obj;
    }

    @Override
    public int getType() {
        return 1;
    }
}

package serializer;

import java.io.*;

/**
 * Serializer based on JDK object serialization (wire code 0).
 *
 * <p>Both conversions return {@code null} on failure after printing the stack trace.
 *
 * <p>Security: {@code ObjectInputStream.readObject()} on bytes from an untrusted peer can
 * instantiate arbitrary serializable classes. Use this only between trusted endpoints, or
 * install an {@code ObjectInputFilter} (Java 9+).
 */
public class ObjectSerializer implements Serializer{
    @Override
    public byte[] serialize(Object obj) {
        try (ByteArrayOutputStream bos = new ByteArrayOutputStream();
             ObjectOutputStream oos = new ObjectOutputStream(bos)) {
            oos.writeObject(obj);
            oos.flush();
            return bos.toByteArray();
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    /** {@code messageType} is ignored; the stream itself records the object's class. */
    @Override
    public Object deserialize(byte[] bytes, int messageType) {
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return ois.readObject();
        } catch (IOException | ClassNotFoundException e) {
            e.printStackTrace();
            return null;
        }
    }

    @Override
    public int getType() {
        return 0;
    }
}

负载均衡

package loadbalance;

import java.util.List;

/**
 * Strategy for choosing one provider address among those registered for a service.
 *
 * <p>Implemented by RandomLoadBalance, RoundLoadBalance and ConsistentHashLoadBalance, and
 * called by ZkServiceRegister during discovery.
 */
public interface LoadBalance {

    /**
     * Picks one provider.
     *
     * @param addressList provider node names from the registry (ZkServiceRegister uses
     *                    {@code host:port})
     * @param serviceName the service being called
     * @return one element of {@code addressList}
     */
    String balance(List<String> addressList, String serviceName);
}

package loadbalance;

import java.util.List;
import java.util.Random;

/**
 * Picks a uniformly random provider on every call.
 */
public class RandomLoadBalance implements LoadBalance{
    /**
     * @throws IllegalArgumentException if {@code addressList} is null or empty
     */
    @Override
    public String balance(List<String> addressList, String serviceName) {
        if (addressList == null || addressList.isEmpty()) {
            throw new IllegalArgumentException("no provider address available for " + serviceName);
        }
        // ThreadLocalRandom avoids allocating a new Random per call and has no contention when
        // several threads balance at once.
        int choose = java.util.concurrent.ThreadLocalRandom.current().nextInt(addressList.size());
        return addressList.get(choose);
    }
}

package loadbalance;

import java.util.List;

/**
 * Round-robin: successive calls walk through the address list in order and wrap around.
 * The first call returns index 0.
 */
public class RoundLoadBalance implements LoadBalance{
    /** Number of picks made so far; shared by all services balanced by this instance. */
    private final java.util.concurrent.atomic.AtomicInteger counter =
            new java.util.concurrent.atomic.AtomicInteger();

    /**
     * @throws IllegalArgumentException if {@code addressList} is null or empty
     */
    @Override
    public String balance(List<String> addressList, String serviceName) {
        if (addressList == null || addressList.isEmpty()) {
            throw new IllegalArgumentException("no provider address available for " + serviceName);
        }
        // The index is computed from a single atomic read into a local, so it stays within
        // [0, size) even under concurrent calls. floorMod keeps it non-negative after the
        // counter overflows.
        int choose = Math.floorMod(counter.getAndIncrement(), addressList.size());
        return addressList.get(choose);
    }
}

package loadbalance;

import java.util.*;

/**
 * Consistent-hash load balancing keyed by service name.
 *
 * <p>Each provider address is placed on a hash ring under its own hash, plus 100 virtual nodes
 * ({@code address-1} ... {@code address-100}) to spread it around the ring. The service name
 * is hashed, and the first node at or after that hash is returned, wrapping to the start of
 * the ring. A service name therefore maps to the same provider while the provider set is
 * unchanged.
 *
 * <p>The ring is rebuilt from the {@code addressList} of each call. A ring cached from an
 * earlier call could belong to a different service or to providers that have since gone away.
 * Rebuilding costs O(n * 101 * log) per call, which is small for typical provider counts.
 */
public class ConsistentHashLoadBalance implements LoadBalance{
    /** Virtual nodes added per real address. */
    private static final int VIRTUAL_NODES = 100;

    /**
     * @throws IllegalArgumentException if {@code addressList} is null or empty
     */
    @Override
    public String balance(List<String> addressList, String serviceName) {
        if (addressList == null || addressList.isEmpty()) {
            throw new IllegalArgumentException("no provider address available for " + serviceName);
        }
        SortedMap<Integer, String> hashRing = buildRing(addressList);
        SortedMap<Integer, String> tail = hashRing.tailMap(getHash(serviceName));
        Integer key = tail.isEmpty() ? hashRing.firstKey() : tail.firstKey();
        return hashRing.get(key);
    }

    /** Places every address and its virtual nodes on a fresh ring (later collisions overwrite). */
    private static SortedMap<Integer, String> buildRing(List<String> addressList) {
        SortedMap<Integer, String> ring = new TreeMap<>();
        for (String address : addressList) {
            ring.put(getHash(address), address);
            for (int i = 1; i <= VIRTUAL_NODES; i++) {
                ring.put(getHash(address + "-" + i), address);
            }
        }
        return ring;
    }

    /**
     * FNV1_32 hash with the usual final avalanche steps, applied once after all characters
     * are mixed in. Negative results are folded with Math.abs.
     */
    private static int getHash(String str) {
        final int p = 16777619;
        int hash = (int) 2166136261L;
        for (int i = 0; i < str.length(); i++) {
            hash = (hash ^ str.charAt(i)) * p;
        }
        hash += hash << 13;
        hash ^= hash >> 7;
        hash += hash << 3;
        hash ^= hash >> 17;
        hash += hash << 5;
        return hash < 0 ? Math.abs(hash) : hash;
    }
}

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值