手写一个简单rpc框架(二)

        扑街前言:继续上篇文章的内容,上篇文章说了rpc框架的架构和服务端的代码编写,那么本篇实例客户端的代码编写。(认识到自己是菜鸟的51天)


        老规矩先说下客户端的大致内容和流程,1、同服务端一样,需要一个引导类用于运行启动器;2、一个启动器,进行服务的发现,从zk中将所有服务信息拉取下来,存入缓存,同时监听每个接口;3、服务发现,用于拉取zookeeper中的注册信息,并放入缓存,如果缓存中有,那么直接取缓存中的值;4、动态代理,用于对远程接口的简单实例;5、将请求的信息封装为RPC定义的请求;6、请求管理器,用于调用对应请求;7、客户端一、二次编解码,同服务端原理一致,但是需要注意二次编解码的对象与服务端不一样;8、RpcResponseHandler,用于对服务端返回的数据进行处理;9、异步获取结果,对于RpcResponseHandler的处理返回进行异步操作;10、连接复用,建立长连接;11、负载均衡,对于集群的管理。(这里其实还有断熔之类的组件,后续关于dubbo 的文章再详细说)


代码示例

        那么下面就开始正式的代码流程。

引导类

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;

import javax.annotation.PostConstruct;

@Configuration
public class RpcBootstrap {
    
    @Autowired
    private RpcClientRunner rpcClientRunner;
    
    @PostConstruct
    public void init () {
        rpcClientRunner.run();
    }
}

启动器

import org.springframework.stereotype.Component;

@Component
public class RpcClientRunner {

    public void run () {
        // 进行服务发现,拉取信息进缓存,并生成动态代理

    }
}

服务发现

        上面基本的准备工作完成之后,下面开始服务发现的代码内容。

客户端的zookeeper封装

import com.google.common.collect.Lists;
import com.rpc.cache.ServiceProviderCache;
import com.rpc.client.config.RpcClientConfiguration;
import com.rpc.provider.ServiceProvider;
import org.I0Itec.zkclient.IZkChildListener;
import org.I0Itec.zkclient.ZkClient;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.List;
import java.util.stream.Collectors;

@Component
public class ClientZKit {

    @Autowired
    private RpcClientConfiguration configuration;

    @Autowired
    private ZkClient zkClient;

    @Autowired
    private ServiceProviderCache cache;

    /**
     * 服务订阅接口
     * @param serviceName
     */
    public void subscribeZKEvent(String serviceName) {
        // 1. 组装服务节点信息
        String path = configuration.getZkRoot() + "/" + serviceName;
        // 2. 订阅服务节点(监听节点变化)
        zkClient.subscribeChildChanges(path, new IZkChildListener() {
            @Override
            public void handleChildChange(String parentPath, List<String> list) throws Exception {
                // 3. 判断获取的节点信息,是否为空
                if (CollectionUtils.isNotEmpty(list)) {
                    // 4. 将服务端获取的信息, 转换为服务记录对象
                    List<ServiceProvider> providerServices = convertToProviderService(serviceName, list);
                    // 5. 更新缓存信息 记得要改
                    cache.update(serviceName,providerServices);
                }
            }
        });
    }


    /**
     * 获取所有服务列表:所有的服务接口信息
     * @return
     */
    public List<String> getServiceList() {
        String path = configuration.getZkRoot();
        List<String> children = zkClient.getChildren(path);
        return children;
    }

    /**
     *  根据服务名称获取服务节点完整信息
     * @param serviceName
     * @return
     */
    public List<ServiceProvider> getServiceInfos(String serviceName) {
        String path = configuration.getZkRoot() + "/" + serviceName;
        List<String> children = zkClient.getChildren(path);
        List<ServiceProvider> providerServices = convertToProviderService(serviceName,children);
        return providerServices;
    }

    /**
     * 将拉取的服务节点信息转换为服务记录对象
     *
     * @param serviceName
     * @param list
     * @return
     */
    private List<ServiceProvider> convertToProviderService(String serviceName, List<String> list) {
        if (CollectionUtils.isEmpty(list)) {
            return Lists.newArrayListWithCapacity(0);
        }
        // 将服务节点信息转换为服务记录对象
        List<ServiceProvider> providerServices = list.stream().map(v -> {
            String[] serviceInfos = v.split(":");
            return ServiceProvider.builder()
                    .serviceName(serviceName)
                    .serverIp(serviceInfos[0])
                    .rpcPort(Integer.parseInt(serviceInfos[1]))
                    .build();
        }).collect(Collectors.toList());
        return providerServices;
    }
}

连接缓存的接口

import com.itheima.rpc.provider.ServiceProvider;

import java.util.List;

/**
 * @description
 * @author: ts
 * @create:2021-05-11 15:26
 */
public interface ServiceProviderCache {
    /**
     * 向缓存中添加数据
     * @param key
     * @param value
     */
    void put(String key, List<ServiceProvider> value);

    /**
     * 获取缓存
     * @param key
     * @return
     */
    List<ServiceProvider> get(String key);

    /**
     * 缓存清除
     * @param key
     */
    void evict(String key);


    /**
     * 缓存更新
     * @param key
     * @param value
     */
    void update(String key,List<ServiceProvider> value);
}

缓存连接实现类

import com.google.common.cache.LoadingCache;
import com.google.common.collect.Lists;
import com.rpc.cache.ServiceProviderCache;
import com.rpc.provider.ServiceProvider;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.List;
import java.util.concurrent.ExecutionException;

/**
 * @description
 * @author: ts
 * @create:2021-05-11 15:28
 */
@Component
public class DefaultServiceProviderCache implements ServiceProviderCache {

    @Autowired
    private LoadingCache<String, List<ServiceProvider>> cache;

    @Override
    public void put(String key, List<ServiceProvider> value) {
        cache.put(key,value);
    }

    @Override
    public List<ServiceProvider> get(String key) {
        try {
            return cache.get(key);
        } catch (ExecutionException e) {
            return Lists.newArrayListWithCapacity(0);
        }
    }

    @Override
    public void evict(String key) {
        cache.invalidate(key);
    }

    @Override
    public void update(String key, List<ServiceProvider> value) {
        evict(key);
        put(key,value);
    }
}

缓存服务的提供者

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.io.Serializable;

@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ServiceProvider implements Serializable {
    private String serviceName;
    private String serverIp;
    private int rpcPort;
    private int networkPort;
    private long timeout;
    // the weight of service provider
    private int weight;
}

服务发现接口

public interface RpcServiceDiscovery {
    /**
     * 完成服务发现逻辑
     */
    void serviceDiscovery();
}

服务发现的zookeeper实现

import com.rpc.cache.ServiceProviderCache;
import com.rpc.client.discovery.RpcServiceDiscovery;
import com.rpc.provider.ServiceProvider;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.List;

@Component
public class ZkServiceDiscovery implements RpcServiceDiscovery {
    @Autowired
    private ServiceProviderCache cache;

    @Autowired
    private ClientZKit clientZKit;

    @Override
    public void serviceDiscovery() {
        /*
         * 获取配置文件中父节点下的所有下级子节点
         * 根据每一个子节点获取它们对应的服务信息
         * 存入缓存
         * 订阅监听该子节点的服务变更
         */
        // 获取配置文件中父节点下的所有下级子节点
        List<String> serviceList = clientZKit.getServiceList();

        // 如果不存在节点信息直接结束
        if (serviceList == null || serviceList.size() <= 0) {
            return;
        }

        // 获取集合大小
        int size = serviceList.size();
        // 循环遍历
        for (int i = 0; i < size; i++) {
            // 获取到相关子节点
            String service = serviceList.get(i);

            // 获取子节点信息
            List<ServiceProvider> serviceInfos = clientZKit.getServiceInfos(service);

            // 存入缓存
            cache.put(service, serviceInfos);

            // 订阅该子节点的服务变更信息
            clientZKit.subscribeZKEvent(service);
        }
    }
}

动态代理

        先说下为什么要有动态代理,客户端调用远程接口,也就是调动服务端接口的时候,是没有相应的实现类在客户端的代码中的,那也就意味着当接口被注入IOC容器的时候是没有具体实现的,所以需要有动态代理来封装调用远程的请求。

        这里还要提一个由Spring 提供的接口 BeanPostProcessor ,这个就是为了在bean对象初始化之前或者之后,进行拦截并执行自定义的方法内容。至于决定是初始化之前还是之后执行具体内容,是由BeanPostProcessor 接口提供的postProcessBeforeInitialization(初始化之前)和postProcessAfterInitialization (初始化之后)方法实现的。

对象初始化后拦截

import com.rpc.annotation.HrpcRemote;
import com.rpc.client.proxy.RequestProxyFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;

import java.lang.reflect.Field;

@Slf4j
public class RpcAnnotationProcessor implements BeanPostProcessor, ApplicationContextAware {

    private RequestProxyFactory proxyFactory;

    /**
     * 初始化之后的方法
     *
     * @param bean     容器中的bean
     * @param beanName bean的名称
     * @return 返回的也是bean
     * @throws BeansException
     */
    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        /*
         * 获取bean的所有成员属性
         * 获取成员属性的注解
         * 判断是否获取到了指定注解
         * 否则跳过,是则为注解修饰的成员属性添加代理
         * 添加代理人,首先需要获取成员属性的类型(也就是接口,比如String a,a就是成员属性,String就a的类型),然后通过自定义的对象封装,然后回写给bean
         */
        // 获取bean 上的所有私有方法(也就是private)
        Field[] fields = bean.getClass().getDeclaredFields();

        // 循环所有私有方法
        for (Field field : fields) {
            // 如果属性不可访问,则设置为可访问
            if (!field.isAccessible()) {
                // 设置可访问
                field.setAccessible(true);
            }

            // 获取每一个方法上的HrpcRemote 注解,如果该方法没有这个注解,返回则是null
            HrpcRemote hrpcRemote = field.getAnnotation(HrpcRemote.class);

            // 如果没有该注解修饰,则不是远程接口
            if (hrpcRemote == null) {
                // 跳出本次循环
                continue;
            }

            // 获取到成员属性的类型,也就是接口的类
            Class<?> type = field.getType();

            // 为该接口生成自定义的动态代理(注意这里的动态代理生成是用的 cglib)
            Object proxyInstance = proxyFactory.newProxyInstance(type);

            // 如果没有生成动态代理,则跳出
            if (proxyInstance == null) {
                // 跳出本次循环
                continue;
            }

            try {
                log.info("为{}生成的代理为:{}",type,proxyInstance);
                // 为成员数据注入动态代理,并添加到bean上
                field.set(bean, proxyInstance);
            } catch (Exception e) {
                log.error("proxyInstance exception ,msg={}",e.getMessage());
            }
        }

        // 返回bean
        return bean;
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        // 从容器中获取动态代理接口的注册bean
        this.proxyFactory = applicationContext.getBean(RequestProxyFactory.class);
    }
}

Cglib动态代理

import com.rpc.proxy.ProxyFactory;
import lombok.extern.slf4j.Slf4j;
import net.sf.cglib.proxy.Enhancer;
import org.springframework.stereotype.Component;


/**
 * @description
 * @author: ts
 * @create:2021-05-10 09:42
 */
@Component
@Slf4j
public class RequestProxyFactory implements ProxyFactory{

    /**
     * 创建新的代理实例-CGLib动态代理
     * @param cls
     * @param <T>
     * @return
     */
    public  <T> T newProxyInstance(Class<T> cls) {
        Enhancer enhancer = new Enhancer();
        enhancer.setSuperclass(cls);
        enhancer.setCallback(new CglibProxyCallBackHandler());
        return (T) enhancer.create();
    }
}

import com.rpc.client.request.RpcRequestManager;
import com.rpc.data.RpcRequest;
import com.rpc.data.RpcResponse;
import com.rpc.exception.RpcException;
import com.rpc.spring.SpringBeanFactory;
import com.rpc.util.RequestIdUtil;
import net.sf.cglib.proxy.MethodInterceptor;
import net.sf.cglib.proxy.MethodProxy;
import org.springframework.context.ApplicationContext;
import org.springframework.util.ReflectionUtils;

import java.lang.reflect.Method;

/**
 * @description
 * @author: ts
 * @create:2021-05-12 00:11
 */
public class CglibProxyCallBackHandler implements MethodInterceptor {


    public Object intercept(Object o, Method method, Object[] parameters, MethodProxy methodProxy) throws Throwable {
        /*
         * 首先使用ReflectionUtils 来过滤掉所有来自object方法,将这些方法进行初始化调用
         * 然后再获取方法对应的接口名称,获取方法的名称,获取方法的参数类型
         * 然后构造request 对象
         * 从容器中获取netty调用的封装
         * 调用netty 传入request,返回response
         * 返回response
         */
        // 判断该方式是不是来自object的方法
        if (ReflectionUtils.isObjectMethod(method)) {
            // 是则方法进行调整,传入方法所在类的实例和方法的传入参数
            return method.invoke(method.getDeclaringClass().newInstance(), parameters);
        }

        // 获取方法对应的类名
        String interfaceName = method.getDeclaringClass().getName();
        // 获取方法名称
        String methodName = method.getName();
        // 获取方法参数类型
        Class<?>[] parameterTypes = method.getParameterTypes();

        // 获取当前请求的主键
        String requestId = RequestIdUtil.requestId();

        // 封装请求对象
        RpcRequest rpcRequest = RpcRequest.builder()
                .className(interfaceName)
                .methodName(methodName)
                .parameterTypes(parameterTypes)
                .parameters(parameters)
                .requestId(requestId)
                .build();

        // 获取netty调用的封装
        RpcRequestManager rpcRequestManager = SpringBeanFactory.getBean(RpcRequestManager.class);

        if (rpcRequestManager == null) {
            throw new RpcException("spring ioc exception");
        }

        RpcResponse response = rpcRequestManager.sendRequest(rpcRequest);

        // 返回结果
        return response;
    }
}

基于netty的网络连接

import com.rpc.cache.ServiceProviderCache;
import com.rpc.client.cluster.LoadBalanceStrategy;
import com.rpc.client.cluster.StartegyProvider;
import com.rpc.data.RpcRequest;
import com.rpc.data.RpcResponse;
import com.rpc.netty.codec.*;
import com.rpc.netty.handler.RpcResponseHandler;
import com.rpc.netty.request.ChannelMapping;
import com.rpc.netty.request.RequestPromise;
import com.rpc.netty.request.RpcRequestHolder;
import com.rpc.provider.ServiceProvider;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.concurrent.DefaultThreadFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.List;

@Component
@Slf4j
public class RpcRequestManager {
    @Autowired
    private ServiceProviderCache cache;

    @Autowired
    private StartegyProvider provider;

    public RpcResponse sendRequest (RpcRequest request) {
        /*
         * 选择负载均衡的策略对象,根据配置信息
         * 然后从缓存中获取完整的服务列表
         * 然后根据负载均衡对象从服务列表中选出一个服务
         * 然后通过这个服务基于netty获取连接,请求访问
         */
        // 首先获取服务对象集合,如果没有服务,那么直接结束
        List<ServiceProvider> serviceProviders = cache.get(request.getClassName());
        // 服务列表不能为空
        if (serviceProviders == null && serviceProviders.size() <= 0){
            // 返回为空的响应对象
            return new RpcResponse();
        }

        // 获取到策略对象
        LoadBalanceStrategy strategy = provider.getStrategy();
        // 根据策略获取具体的服务
        ServiceProvider serviceProvider = strategy.select(serviceProviders);

        return null;
    }

    private RpcResponse requestByNetty (ServiceProvider serviceProvider, RpcRequest request) {
        try {
            /*
             * 因为需要长连接,不能每一次请求就连接一下服务端,所以用一个自定义的工具类来实现 RpcRequestHolder
             * 如果当前连接,也就是ip和端口有在RpcRequestHolder 对象的属性中时,就不再连接一次了
             * 当netty 连接完成之后,使用channel 对服务器进行传输,并获取返回对象
             * 结束返回
             */
            // 获取ip和端口
            String ip = serviceProvider.getServerIp();
            int port = serviceProvider.getRpcPort();

            // 判断当前ip和端口是否在map中存在
            if (!RpcRequestHolder.channelExist(ip, port)) {
                // 建立连接
                this.nettyConnect(ip, port);
            }

            // 获取到连接成功后的channel,只有第一次请求的时候才会去创建channel
            Channel channel = RpcRequestHolder.getChannel(ip, port);

            // 获取Promise 用于阻塞对消息的传输时的主线程
            RequestPromise promise = new RequestPromise(channel.eventLoop());

            // 将Promise 存入全局映射,用于当消息传输、阻塞时的第二次请求处理,避免获取到对应的返回对象
            RpcRequestHolder.addRequestPromise(request.getRequestId(), promise);

            // 对数据进行传输
            channel.writeAndFlush(request);

            // 进行阻塞
            RpcResponse response = (RpcResponse) promise.get();

            // 结束,返回
            return response;
        } catch (Exception e) {
            log.error("rpc request exception,msg={}",e.getMessage());
        }
        return new RpcResponse();
    }

    private void nettyConnect (String ip, int port) throws Exception {
        try {
            /*
             * 先创建一个请求发起的线程池,
             * 然后构建引导类
             * 然后启动引导类
             * 最后监听线程启动,并为工具类中的map添加当前ip端口对应的channel,用于后面的操作使用
             */
            // 创建基于Nio实现的线程池
            NioEventLoopGroup worker = new NioEventLoopGroup(0, new DefaultThreadFactory("worker"));

            // 创建inbuond对应的 Handler
            RpcResponseHandler rpcResponseHandler = new RpcResponseHandler();

            // 创建引导类
            Bootstrap bootstrap = new Bootstrap();

            // 构建引导类
            bootstrap.group(worker)
                    .channel(NioServerSocketChannel.class)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel socketChannel) throws Exception {
                            ChannelPipeline pipeline = socketChannel.pipeline();
                            // 一次编码
                            pipeline.addLast(new FrameEncoder());
                            // 二次编码
                            pipeline.addLast(new RpcRequestEncoder());

                            // 一次解码
                            pipeline.addLast(new FrameDecoder());
                            // 二次解码
                            pipeline.addLast(new RpcResponseDecoder());

                            // response阻塞获取
                            pipeline.addLast(rpcResponseHandler);
                        }
                    });

            // 同步启动引导类,根据ip端口连接服务端,并监听
            ChannelFuture future = bootstrap.connect(ip, port).sync();

            // 当启动连接成功
            if (future.isSuccess()){
                // 获取channel
                Channel channel = future.channel();

                // 将channel放入工具类的map里面,做为全局映射
                RpcRequestHolder.addChannelMapping(new ChannelMapping(ip, port, channel));
            }
        } catch (Exception e) {
            throw new Exception(e);
        }
    }
}

response返回对象的接收

import com.rpc.data.RpcResponse;
import com.rpc.netty.request.RequestPromise;
import com.rpc.netty.request.RpcRequestHolder;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import lombok.extern.slf4j.Slf4j;

@Slf4j
@ChannelHandler.Sharable
public class RpcResponseHandler extends SimpleChannelInboundHandler<RpcResponse> {
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcResponse response) throws Exception {
        log.info("客户端收到结果为:{}",response);
        RequestPromise requestPromise = RpcRequestHolder.getRequestPromise(response.getRequestId());
        if (requestPromise!=null) {
            //通知结束阻塞
            requestPromise.setSuccess(response);
        }
    }
}

        上述就是手写rpc框架的全部内容了,客户端的一次、二次编码和一次、二次解码,就不再多说了,原理和服务端的是一样的,区别就是二次编码和解码的对象不一样,一个是请求、一个是响应。后续文章再讨论dubbo,终于结束了。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值