基于netty的RPC(模拟实现远程服务调用)

2 篇文章 0 订阅
package rpc;


import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.ByteToMessageDecoder;
import org.junit.Test;

import java.io.*;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.InetSocketAddress;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * A minimal Netty-based RPC demo. RPC (remote procedure call) means calling a remote method as if
 * it were a local one: the client calls an interface method on a JDK dynamic proxy, the call is
 * Java-serialized and sent over TCP, and the server dispatches it to a registered implementation
 * and writes the result back.
 *
 * <p>Change log (from the original author):
 * <ul>
 *   <li>2022-03-21: fixed encoding/decoding issues</li>
 *   <li>2022-03-22: reworked how the client receives and handles responses</li>
 *   <li>2022-03-31: fixed thread-safety issues caused by threads contending for shared state</li>
 * </ul>
 *
 * @author x y
 * @date 2022-03-16 10:08
 */
public class MyRPC {

    /** Header flag marking a client request (the server answers with a different flag). */
    private static final int REQUEST_FLAG = 0x14141414;

    private static final String HOST = "localhost";
    private static final int PORT = 9003;

    /**
     * Starts the RPC provider on {@code localhost:9003} and blocks until the server channel closes.
     * The event loop group is shut down when the server stops or fails to bind.
     */
    @Test
    public void provider() {
        // One group serves as both the boss (accept) and worker (I/O) group.
        NioEventLoopGroup group = new NioEventLoopGroup();
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            ChannelFuture future = bootstrap.group(group, group)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<NioSocketChannel>() {
                        @Override
                        protected void initChannel(NioSocketChannel ch) throws Exception {
                            ChannelPipeline pipeline = ch.pipeline();
                            pipeline.addLast(new ServerEncodeHandler());
                            pipeline.addLast(new ServerResponseHandler());
                        }
                    })
                    .bind(new InetSocketAddress(HOST, PORT));
            Channel server = future.sync().channel();
            server.closeFuture().sync();
        } catch (InterruptedException e) {
            // Preserve the interrupt status for whoever owns this thread.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } finally {
            group.shutdownGracefully();
        }
    }

    /**
     * Demo consumer: registers {@link CarImpl} under the {@link Car} interface name, starts the
     * provider on a background thread, then fires 20 concurrent remote calls through the proxy.
     * Blocks on {@code System.in} so the JVM stays alive until input arrives.
     */
    @Test
    public void consumer() {
        CarImpl carImpl = new CarImpl();
        Dispatcher.register(Car.class.getName(), carImpl);
        new Thread(this::provider).start();
        System.err.println("server start....");
        AtomicInteger integer = new AtomicInteger(0);
        for (int i = 0; i < 20; i++) {
            new Thread(() -> {
                // Obtain a remote-backed implementation through a dynamic proxy.
                Car car = proxyCreateInterface(Car.class);
                String str = "param :" + integer.incrementAndGet();
                System.err.println(str + "client response " + car.carName(str));
            }).start();
        }
        try {
            System.in.read();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Creates a client-side proxy for {@code interfaceInfo} whose interface methods are executed
     * remotely on the provider at {@code localhost:9003}.
     *
     * <p>Each call sends a serialized {@link MyHeader} followed by a serialized {@link MyContent}.
     * The calling thread then blocks until {@link ClientResponseHandler} completes the future
     * registered for the request id. {@code toString}, {@code hashCode} and {@code equals} are
     * answered locally and never sent over the wire.
     *
     * @param interfaceInfo the service interface; its name is used as the server-side lookup key
     * @param <T>           the service type
     * @return a proxy implementing {@code interfaceInfo}
     */
    @SuppressWarnings("unchecked")
    public static <T> T proxyCreateInterface(Class<T> interfaceInfo) {
        ClassLoader classLoader = interfaceInfo.getClassLoader();
        Class<?>[] interfaces = {interfaceInfo};

        return (T) Proxy.newProxyInstance(classLoader, interfaces, new InvocationHandler() {
            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                // java.lang.Object methods are handled locally. Forwarding them would run them on the
                // remote implementation, and non-String results (hashCode/equals) cannot survive
                // the String-only response path.
                if (method.getDeclaringClass() == Object.class) {
                    return invokeObjectMethod(proxy, method, args);
                }

                // 1. service name + method name + parameters -> message body
                MyContent messageBody = MyContent.valueOf(
                        interfaceInfo.getName(), method.getName(), method.getParameterTypes(), args);
                byte[] messageArray = serialize(messageBody);

                // 2. header carrying the request id and the body length
                MyHeader header = createMyHeader(messageArray);
                byte[] headerArray = serialize(header);
                System.err.println("header size:" + headerArray.length);

                long requestId = header.getRequestId();
                CompletableFuture<String> completableFuture = new CompletableFuture<>();
                // Register before sending so that a fast response always finds its future.
                ClientCache.addMap(requestId, completableFuture);
                try {
                    // 3. obtain a pooled connection
                    NioSocketChannel client = ClientFactory.getInstance()
                            .getNioSocketChannel(new InetSocketAddress(HOST, PORT));
                    // 4. send header + body in one buffer; Netty releases it once written
                    ByteBuf buffer = PooledByteBufAllocator.DEFAULT
                            .buffer(headerArray.length + messageArray.length);
                    buffer.writeBytes(headerArray);
                    buffer.writeBytes(messageArray);
                    client.writeAndFlush(buffer).sync();
                } catch (Throwable t) {
                    // No response can arrive for a request that was never sent; drop the entry.
                    ClientCache.remove(requestId);
                    throw t;
                }
                // 5. block until the response handler completes the future for this request id
                return completableFuture.get();
            }

            /** Local implementations of the Object methods a proxy dispatches to its handler. */
            private Object invokeObjectMethod(Object proxy, Method method, Object[] args) {
                switch (method.getName()) {
                    case "equals":
                        return proxy == args[0];
                    case "hashCode":
                        return System.identityHashCode(proxy);
                    case "toString":
                        return "RPC proxy for " + interfaceInfo.getName();
                    default:
                        throw new UnsupportedOperationException(method.toString());
                }
            }

            /** Builds a request header with a random non-negative id and the body length. */
            private MyHeader createMyHeader(byte[] messageArray) {
                MyHeader myHeader = new MyHeader();
                myHeader.setDataLength(messageArray.length);
                myHeader.setFlag(REQUEST_FLAG);
                // Masking (rather than Math.abs) keeps the id non-negative even for Long.MIN_VALUE.
                long requestId = UUID.randomUUID().getLeastSignificantBits() & Long.MAX_VALUE;
                myHeader.setRequestId(requestId);
                return myHeader;
            }
        });
    }

    /**
     * Java-serializes {@code value} into a fresh byte array, including its own stream header.
     * Closing the stream flushes it and writes no extra bytes.
     */
    private static byte[] serialize(Object value) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (ObjectOutputStream oout = new ObjectOutputStream(out)) {
            oout.writeObject(value);
        }
        return out.toByteArray();
    }
}
/**
 * Server-side registry mapping a service name (the interface's fully-qualified class name) to
 * the object that implements it.
 *
 * <p>Backed by a {@link ConcurrentHashMap}. Registration may happen on one thread while server
 * event-loop threads concurrently look services up; a plain {@code HashMap} is unsafe for that.
 */
class Dispatcher {
    private static final Map<String, Object> DISPATCHER_MAP = new ConcurrentHashMap<>();

    /**
     * Registers (or replaces) the implementation for a service name.
     *
     * @param forName service name, typically {@code SomeInterface.class.getName()}
     * @param o       the implementation instance
     * @throws NullPointerException if either argument is null
     */
    public static void register(String forName, Object o) {
        DISPATCHER_MAP.put(Objects.requireNonNull(forName, "forName"), Objects.requireNonNull(o, "o"));
    }

    /**
     * Looks up the implementation registered for a service name.
     *
     * @param forName service name; {@code null} is tolerated
     * @return the registered implementation, or {@code null} if none is registered
     */
    public static Object get(String forName) {
        // ConcurrentHashMap rejects null keys; keep the old "absent -> null" behaviour instead.
        return forName == null ? null : DISPATCHER_MAP.get(forName);
    }
}
/**
 * Server-side implementation of {@link Car} used by the demo; it echoes the argument back with a
 * fixed prefix.
 */
class CarImpl implements Car {
    private static final String PREFIX = "current CarImpl response:";

    /**
     * @param param caller-supplied text
     * @return the prefix followed by {@code param} ({@code "null"} when param is null)
     */
    @Override
    public String carName(String param) {
        return new StringBuilder(PREFIX).append(param).toString();
    }
}
/**
 * One decoded protocol frame: the header plus the deserialized body. Produced by the frame
 * decoder and consumed by both the server and client response handlers.
 */
class ServerBody {
    /** Frame header (flag, request id, body length). Public because handlers read it directly. */
    public MyHeader myHeader;
    /** Frame body (call description on requests, result on responses). */
    public MyContent myContent;

    /**
     * @param header  decoded header
     * @param content decoded body
     */
    public ServerBody(MyHeader header, MyContent content) {
        myHeader = header;
        myContent = content;
    }

    /** @return the frame header */
    public MyHeader getMyHeader() {
        return myHeader;
    }

    /** @param header replacement header */
    public void setMyHeader(MyHeader header) {
        myHeader = header;
    }

    /** @return the frame body */
    public MyContent getMyContent() {
        return myContent;
    }

    /** @param content replacement body */
    public void setMyContent(MyContent content) {
        myContent = content;
    }
}

/**
 * Frame decoder used by both the server and the client pipelines. Despite its name, it decodes;
 * it does not encode.
 *
 * <p>A frame is a Java-serialized {@link MyHeader} of exactly {@value #HEADER_LENGTH} bytes,
 * followed by a Java-serialized {@link MyContent} of {@code header.getDataLength()} bytes. Complete
 * frames carrying a known flag become {@link ServerBody} messages. Frames with any other flag are
 * consumed and dropped. Partial frames stay in the cumulation buffer until more bytes arrive.
 *
 * <p>SECURITY NOTE(review): this performs Java native deserialization ({@link ObjectInputStream})
 * on bytes read from the network. Do not expose it to untrusted peers without an
 * {@code ObjectInputFilter} or a different wire format.
 */
class ServerEncodeHandler extends ByteToMessageDecoder {
    /**
     * Serialized size of a {@link MyHeader}. The protocol depends on this being constant, which holds
     * only while MyHeader's package, class name and fields stay unchanged.
     */
    static final int HEADER_LENGTH = 85;

    /** Flag on client requests. */
    private static final int REQUEST_FLAG = 0x14141414;
    /** Flag on server responses. */
    private static final int RESPONSE_FLAG = 0x14141424;

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf buf, List<Object> out) throws Exception {
        while (buf.readableBytes() >= HEADER_LENGTH) {
            // Peek at the header without moving readerIndex: the body may not have arrived yet.
            byte[] headerBytes = new byte[HEADER_LENGTH];
            buf.getBytes(buf.readerIndex(), headerBytes);
            MyHeader header = (MyHeader) deserialize(headerBytes);
            long dataLength = header.getDataLength();
            if (buf.readableBytes() - HEADER_LENGTH < dataLength) {
                break; // incomplete frame: wait for more bytes
            }
            // skipBytes only advances readerIndex. readBytes(int) would allocate a new pooled
            // ByteBuf that nobody released, leaking on every frame.
            buf.skipBytes(HEADER_LENGTH);
            byte[] dataBytes = new byte[(int) dataLength];
            buf.readBytes(dataBytes);
            MyContent content = (MyContent) deserialize(dataBytes);
            int flag = header.getFlag();
            if (flag == REQUEST_FLAG || flag == RESPONSE_FLAG) {
                out.add(new ServerBody(header, content));
            }
        }
    }

    /** Reads one Java-serialized object from {@code bytes}. */
    private static Object deserialize(byte[] bytes) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return in.readObject();
        }
    }
}

/**
 * Server-side handler that executes a decoded request against the implementation registered in
 * {@link Dispatcher}. It writes back a response frame carrying the same request id and the
 * response flag {@code 0x14141424}.
 *
 * <p>Threading: the business call is submitted to {@code ctx.executor().parent()}, the event loop
 * group that owns this channel, rather than running inline in this handler. Other options
 * considered by the original author were a custom {@code ThreadPoolExecutor} or running business
 * logic and I/O on the same thread.
 */
class ServerResponseHandler extends ChannelInboundHandlerAdapter {
    /** Header flag that marks a server response. */
    private static final int RESPONSE_FLAG = 0x14141424;

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        ServerBody body = (ServerBody) msg;
        ctx.executor().parent().execute(() -> {
            String result = invokeService(body.getMyContent());

            MyContent responseContent = new MyContent();
            responseContent.setRest(result);
            byte[] bodyBytes = ByteUtils.bytes(responseContent);

            MyHeader responseHeader = new MyHeader();
            responseHeader.setFlag(RESPONSE_FLAG);
            responseHeader.setRequestId(body.getMyHeader().getRequestId());
            responseHeader.setDataLength(bodyBytes.length);
            byte[] headerBytes = ByteUtils.bytes(responseHeader);
            System.err.println("server size :" + headerBytes.length);

            ByteBuf buffer = PooledByteBufAllocator.DEFAULT.buffer(headerBytes.length + bodyBytes.length);
            buffer.writeBytes(headerBytes);
            buffer.writeBytes(bodyBytes);
            ctx.writeAndFlush(buffer);
        });
    }

    /**
     * Reflectively invokes the requested method on the registered service.
     *
     * @param request call description (service name, method name, parameter types, arguments)
     * @return the method's String result, or {@code null} if the service is unknown or the call
     *     fails. Failures are reported on stderr instead of being silently discarded. The client
     *     still receives a response, so it does not block forever.
     */
    private static String invokeService(MyContent request) {
        String serviceName = request.getClassName();
        Object service = Dispatcher.get(serviceName);
        if (service == null) {
            System.err.println("no service registered for " + serviceName);
            return null;
        }
        try {
            Method method = service.getClass().getMethod(request.getMethodName(), request.getParameterTypes());
            return (String) method.invoke(service, request.getArgs());
        } catch (Exception e) {
            // Boundary catch: lookup, access, argument, invocation or cast failures all become a
            // null result, as before, but are no longer hidden.
            e.printStackTrace();
            return null;
        }
    }
}

/**
 * Client-side table of in-flight requests: request id to the future the calling thread is
 * blocked on. Backed by a {@link ConcurrentHashMap}; all methods may be called from any thread.
 */
class ClientCache {
    private static final Map<Long, CompletableFuture<String>> RUN_MAP = new ConcurrentHashMap<>();

    /** Registers the future that will receive the response for {@code requestId}. */
    public static void addMap(long requestId, CompletableFuture<String> completableFuture) {
        RUN_MAP.put(requestId, completableFuture);
    }

    /**
     * Completes the pending future matching the response's request id and forgets it.
     * Responses for unknown (or already completed) ids are ignored.
     */
    public static void runMap(ServerBody body) {
        // remove() looks up and unregisters in one atomic step, so the table cannot grow
        // without bound and an id is completed at most once.
        CompletableFuture<String> completableFuture = RUN_MAP.remove(body.getMyHeader().getRequestId());
        if (completableFuture != null) {
            completableFuture.complete(body.getMyContent().getRest());
        }
    }

    /** Drops the pending entry for {@code requestId}, e.g. when the request could not be sent. */
    public static void remove(long requestId) {
        RUN_MAP.remove(requestId);
    }
}

/**
 * Process-wide factory handing out client connections. It keeps one {@link ClientPool} per
 * provider address, because a consumer may talk to several providers.
 */
class ClientFactory {
    private ClientFactory() {
        // singleton: obtain via getInstance()
    }

    Random random = new Random();
    private static volatile ClientFactory clientFactory;

    /** Returns the lazily created singleton (double-checked locking on a volatile field). */
    public static ClientFactory getInstance() {
        if (clientFactory == null) {
            synchronized (ClientFactory.class) {
                if (clientFactory == null) {
                    clientFactory = new ClientFactory();
                }
            }
        }
        return clientFactory;
    }

    /** Number of connections per provider address; read when a pool is first created. */
    int poolSize = 1;

    /** Provider address to its connection pool. */
    Map<InetSocketAddress, ClientPool> clientPoolMap = new ConcurrentHashMap<>();

    /**
     * Event loop group shared by every client connection. It is created lazily on the first connect
     * and reused afterwards. Previously every (re)connect created a fresh group and leaked the old
     * one's thread. Only accessed while holding this factory's monitor.
     */
    NioEventLoopGroup clientGourp;

    /**
     * Returns an active connection to {@code inetSocketAddress}, picking a random pool slot and
     * connecting that slot on first use or after its channel became inactive.
     *
     * @throws IllegalStateException if the thread is interrupted while connecting
     */
    public synchronized NioSocketChannel getNioSocketChannel(InetSocketAddress inetSocketAddress) {
        ClientPool clientPool = clientPoolMap.computeIfAbsent(inetSocketAddress, addr -> new ClientPool(poolSize));
        // Use the pool's own size: poolSize may have changed after this pool was created.
        int index = random.nextInt(clientPool.client.length);
        NioSocketChannel channel = clientPool.client[index];
        if (channel != null && channel.isActive()) {
            return channel;
        }
        // The per-slot lock is redundant while this method is synchronized. It is kept so the
        // method-level lock can be relaxed later without reintroducing duplicate connects.
        synchronized (clientPool.lock[index]) {
            channel = clientPool.client[index];
            if (channel != null && channel.isActive()) {
                return channel;
            }
            if (channel != null) {
                channel.close(); // release the stale, inactive channel before replacing it
            }
            NioSocketChannel created = createNioSocketChannel(inetSocketAddress);
            clientPool.client[index] = created;
            return created;
        }
    }

    /** Connects a new client channel on the shared event loop group (caller holds the monitor). */
    private NioSocketChannel createNioSocketChannel(InetSocketAddress inetSocketAddress) {
        if (clientGourp == null) {
            clientGourp = new NioEventLoopGroup(1);
        }
        Bootstrap bootstrap = new Bootstrap();
        ChannelFuture connect = bootstrap.group(clientGourp)
                .channel(NioSocketChannel.class)
                .handler(new ChannelInitializer<NioSocketChannel>() {
                    @Override
                    protected void initChannel(NioSocketChannel ch) throws Exception {
                        ChannelPipeline pipeline = ch.pipeline();
                        pipeline.addLast(new ServerEncodeHandler());
                        pipeline.addLast(new ClientResponseHandler());
                    }
                }).connect(inetSocketAddress);
        try {
            return (NioSocketChannel) connect.sync().channel();
        } catch (InterruptedException e) {
            // Fail loudly instead of caching a null channel that later surfaces as an NPE.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while connecting to " + inetSocketAddress, e);
        }
    }
}

/**
 * Client-side handler that hands each decoded response frame to {@link ClientCache}, which
 * completes the future of the thread waiting on that request id.
 */
class ClientResponseHandler extends ChannelInboundHandlerAdapter {
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        ClientCache.runMap((ServerBody) msg);
    }
}

/**
 * Fixed number of connection slots for one provider address. Channels are created lazily by
 * {@code ClientFactory}; each slot has its own monitor object guarding that creation.
 */
class ClientPool {
    /** Connection per slot; {@code null} until the slot is first used. */
    NioSocketChannel[] client;
    /** Monitor per slot, created eagerly. */
    Object[] lock;

    ClientPool(int size) {
        client = new NioSocketChannel[size];
        lock = new Object[size];
        Arrays.setAll(lock, slot -> new Object());
    }
}

/**
 * Sample service contract used by the demo. Clients call it through an RPC proxy; the server
 * routes calls to whatever implementation is registered under this interface's class name.
 */
interface Car {
    /**
     * @param param caller-supplied text
     * @return the implementation's response for {@code param}
     */
    String carName(String param);
}

/**
 * Protocol frame header, sent Java-serialized ahead of every body.
 *
 * <p>NOTE: the decoder in this file reads a fixed 85-byte header. That length comes from this
 * class's serialized form, so do not rename the class or its fields, or add fields, without
 * updating the decoder.
 */
class MyHeader implements Serializable {
    /** Message kind: 0x14141414 on client requests, 0x14141424 on server responses. */
    private int flag;
    /** Correlates a response with its request; the server echoes the client's value. */
    private long requestId;
    /** Byte length of the serialized body that follows this header. */
    private long dataLength;

    public int getFlag() {
        return this.flag;
    }

    public long getRequestId() {
        return this.requestId;
    }

    public long getDataLength() {
        return this.dataLength;
    }

    public void setFlag(int newFlag) {
        this.flag = newFlag;
    }

    public void setRequestId(long newRequestId) {
        this.requestId = newRequestId;
    }

    public void setDataLength(long newDataLength) {
        this.dataLength = newDataLength;
    }
}

/**
 * Protocol frame body, sent Java-serialized after the header.
 *
 * <p>On a request it describes the call (service class name, method name, parameter types,
 * arguments). On a response only {@code rest}, the String result, is populated.
 */
class MyContent implements Serializable {
    private String className;
    private String methodName;
    private Class<?>[] parameterTypes;
    private Object[] args;
    /** Result of the remote call; only set on responses. */
    private String rest;

    /**
     * Builds a request body. The arrays are stored as given, without copying.
     */
    public static MyContent valueOf(String className, String methodName, Class<?>[] parameterTypes, Object[] args) {
        MyContent request = new MyContent();
        request.setClassName(className);
        request.setMethodName(methodName);
        request.setParameterTypes(parameterTypes);
        request.setArgs(args);
        return request;
    }

    public String getClassName() {
        return this.className;
    }

    public String getMethodName() {
        return this.methodName;
    }

    public Class<?>[] getParameterTypes() {
        return this.parameterTypes;
    }

    public Object[] getArgs() {
        return this.args;
    }

    public String getRest() {
        return this.rest;
    }

    public void setClassName(String value) {
        this.className = value;
    }

    public void setMethodName(String value) {
        this.methodName = value;
    }

    public void setParameterTypes(Class<?>[] value) {
        this.parameterTypes = value;
    }

    public void setArgs(Object[] value) {
        this.args = value;
    }

    public void setRest(String value) {
        this.rest = value;
    }
}

补充：上文代码中遗漏了一个工具类 ByteUtils（位于单独的源文件中），代码如下：

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;

/**
 * @author x y
 * @description TODO
 * @date 2022-03-22 13:52
 */
public class ByteUtils {
    static ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
    private final static Object lock = new Object();

    public static byte[] bytes(Object o) {
        synchronized (lock) {
            try {
                byteArrayOutputStream.reset();
                ObjectOutputStream stream = null;
                stream = new ObjectOutputStream(byteArrayOutputStream);
                stream.writeObject(o);
                return byteArrayOutputStream.toByteArray();
            } catch (IOException e) {
                e.printStackTrace();
            }
            return null;
        }
    }
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值