DedSec RPC
一、什么是RPC
-
两个应用分别部署在A、B两台服务器,若A想要调用B运行应用所提供的服务,如何实现?
-
要解决以下问题:
- 第一,通讯问题:客户端与服务器之间建立TCP链接,远程过程调用的所有数据都在TCP链接中传输;结束后就断开,当然也可以是长连接、TCP心跳机制等;
- 第二,解决寻址的问题:客户端如何通知RPC框架,B所在的服务器IP与端口号?
- 第三,网络信息传输:如何解决A内存中的数值序列化,传递到B内存中再反序列化呢?
二、一个简单的实现
1. 通用接口
-
个人理解:暴露在注册中心的接口?客户端不存在这个接口的具体实现,但是可以调用该接口;
-
我们在api中心定义接口如下:
HelloService.java
public interface HelloService { String hello(HelloObject obj); }
-
在api中心定义传输的数据对象类型
HelloObject.java
@Data @AllArgsConstructor public class HelloObject implements Serializable { private int id; private String message; }
-
再于服务侧定义一个接口的实现:
HelloServiceImpl.java
public class HelloServiceImpl implements HelloService { private static final Logger logger = LoggerFactory.getLogger(HelloServiceImpl.class); @Override public String hello(HelloObject obj) { logger.info("接收到:{}",obj.getMessage()); return "调用返回值:当前对象ID:" + obj.getId(); } }
2. 传输协议
-
我们需要封装来自客户端的请求,这样服务端在通过网络接收到对应的请求之后,就可以反序列化得到传递的数据;
RpcRequest.java
@Data @Builder public class RpcRequest implements Serializable { /** * 调用接口名 */ private String interfaceName; /** * 待使用的方法名 */ private String methodName; /** * 调用接口的参数 */ private Object[] parameters; /** * 调用方法的参数类型 * (用字符串数组也可以) */ private Class<?>[] paramTypes; }
RpcResponse.java
@Data public class RpcResponse<T> implements Serializable { /** * 状态码 */ private Integer statusCode; /** * 状态信息 */ private String message; /** * 响应数据 */ private T data; /** * 快速生成成功对象 * @param data * @param <T> * @return */ public static <T> RpcResponse<T> success(T data){ RpcResponse<T> response = new RpcResponse<>(); response.setStatusCode(ResponseCode.SUCCESS.getCode()); response.setData(data); return response; } /** * 快速生成响应对象 * @param code * @param <T> * @return */ public static <T> RpcResponse<T> fail(ResponseCode code){ RpcResponse<T> response = new RpcResponse<>(); response.setStatusCode(code.getCode()); response.setMessage(code.getMessage()); return response; } }
3. 客户端的实现
-
客户端侧并没有接口的具体实现类,故没有办法生成实例,但可以通过动态代理的方式生成实例;
RpcClientProxy.java,代理RpcClient去实现发送RpcRequest请求;并获取本次请求返回的对象;这里的方式是JDK动态代理,需要实现InvocationHandler接口;
public class RpcClientProxy implements InvocationHandler { private String host; private int port; public RpcClientProxy(String host, int port){ this.host = host; this.port = port; } @SuppressWarnings("unchecked") public <T> T getProxy(Class<T> clazz) { return (T) Proxy.newProxyInstance(clazz.getClassLoader(), new Class<?>[]{clazz}, this); } @Override public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { //builder模式生成RpcRequest RpcRequest rpcRequest = RpcRequest.builder() .interfaceName(method.getDeclaringClass().getName()) .methodName(method.getName()) .parameters(args) .paramTypes(method.getParameterTypes()) .build(); RpcClient rpcClient = new RpcClient(); //用一个RpcClient实现发送的逻辑 return ((RpcResponse) rpcClient.sentRequest(rpcRequest,host,port)).getData(); } }
-
注意到:RpcClient.java,它实现对象的发送逻辑;
实现方式就是用Socket获取ObjectOutputStream对象,将RpcRequest发送出去即可,然后再获取ObjectInputStream对象,获得一个返回的对象;
public class RpcClient { private static final Logger logger = LoggerFactory.getLogger(RpcClient.class); public Object sentRequest(RpcRequest rpcRequest, String host, int port){ try(Socket socket = new Socket(host,port)){ ObjectOutputStream objectOutputStream = new ObjectOutputStream(socket.getOutputStream()); ObjectInputStream objectInputStream = new ObjectInputStream(socket.getInputStream()); objectOutputStream.writeObject(rpcRequest); objectOutputStream.flush(); return objectInputStream.readObject(); }catch (IOException | ClassNotFoundException e){ logger.error("调用过程中发生错误",e); return null; } } }
4. 服务端的实现
-
服务端的是现实:通过一个ServerSocket监听某个端口,循环接收连接请求,如果发来了请求就创建一个线程去处理调用,线程调度的方式采用线程池;
public class RpcServer { private final ExecutorService threadPool; private static final Logger logger = LoggerFactory.getLogger(RpcServer.class); //初始化方式 public RpcServer() { int corePoolSize = 5; int maximumPoolSize = 20; long keepAliveTime = 60; BlockingQueue<Runnable> workingQueue = new ArrayBlockingQueue<>(100); ThreadFactory threadFactory = Executors.defaultThreadFactory(); this.threadPool = new ThreadPoolExecutor( corePoolSize, maximumPoolSize, keepAliveTime, TimeUnit.SECONDS, workingQueue, threadFactory ); } //便于观察demo,这里暂时只能注册一个接口,在注册之后,立即开始监听; public void register(Object service, int port){ try (ServerSocket serverSocket = new ServerSocket(port)){ logger.info("服务器正在启动..."); Socket socket; while ((socket = serverSocket.accept()) != null){ logger.info("客户端链接,IP为:" + socket.getInetAddress()); threadPool.execute(new WorkerThread(socket,service)); } }catch (IOException e){ logger.error("连接时产生错误:",e); } } //WorkerThread即为工作的线程,用于接收RpcRequest对象,解析并且调用,生成RpcResponse对象并传输回去 class WorkerThread implements Runnable{ private Socket socket; private Object service; public WorkerThread(Socket socket, Object service) { this.socket = socket; this.service = service; } //处理过程 @Override public void run() { try (ObjectInputStream objectInputStream = new ObjectInputStream(socket.getInputStream()); ObjectOutputStream objectOutputStream = new ObjectOutputStream(socket.getOutputStream())){ RpcRequest rpcRequest = (RpcRequest) objectInputStream.readObject(); Method method = service.getClass().getMethod(rpcRequest.getMethodName(),rpcRequest.getParamTypes()); Object returnObject = method.invoke(service,rpcRequest.getParameters()); objectOutputStream.writeObject(RpcResponse.success(returnObject)); objectOutputStream.flush(); }catch (IOException | ClassNotFoundException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e){ logger.error("调用时发生错误:",e); } } } }
5. 测试
-
将我们的HelloService注册到RpcServer当中:
public class TestServer { public static void main(String[] args) { HelloService helloService = new HelloServiceImpl(); RpcServer rpcServer = new RpcServer(); //开放在本机的9000端口 rpcServer.register(helloService,9000); } }
-
客户端通过动态代理,生成对象,并且调用,并自动帮我们向服务端发送请求:
public class TestClient { public static void main(String[] args) { RpcClientProxy proxy = new RpcClientProxy("127.0.0.1",9000); HelloService helloService = proxy.getProxy(HelloService.class); HelloObject object = new HelloObject(15,"Good Morning"); String res = helloService.hello(object); System.out.println(res); } }
-
结果:
-
服务端
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-441slDQ2-1639496107506)(DedSec%20RPC.assets/image-20211008031923044-3634369.png)]
-
客户端
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lD2s0wAl-1639496107507)(DedSec%20RPC.assets/image-20211008031949391.png)]
-
三、逐步实现
1. 注册多个服务
观察到Server端的代码是:
public class TestServer {
public static void main(String[] args) {
HelloService helloService = new HelloServiceImpl();
RpcServer rpcServer = new RpcServer();
//开放在本机的9000端口
rpcServer.register(helloService,9000);
}
}
这种形式的实现有:简单、易理解的特点,但是它造成的缺点是:一个服务器只能注册一个服务;这不是我们想要的;
那么我们需要做到另一件事就是:服务注册与客户端的分离;
-
ServiceRegistry:一个保存本地服务信息的容器,这是这个容器的通用接口
package com.dedsec.rpc.registry; public interface ServiceRegistry { //注册服务 <T> void register(T service); //获取服务 Object getService(String serviceName); }
-
DefaultServiceRegistry:给出这个容器接口的默认实现:
package com.dedsec.rpc.registry; public class DefaultServiceRegistry implements ServiceRegistry{ private static final Logger logger = LoggerFactory.getLogger(DefaultServiceRegistry.class); private final Map<String,Object> serviceMap = new ConcurrentHashMap<>(); private final Set<String> registeredService = ConcurrentHashMap.newKeySet(); /** * 注册服务 * @param service * @param <T> */ @Override public <T> void register(T service) { String serviceName = service.getClass().getCanonicalName(); //服务已经注册,则停止无需继续注册 if (registeredService.contains(serviceName)) return; //注册对应服务,加入容器当中 registeredService.add(serviceName); Class<?>[] interfaces = service.getClass().getInterfaces(); //没有注册任何接口的异常 if (interfaces.length == 0){ throw new RpcException(RpcError.SERVICE_NOT_IMPLEMENT_ANY_INTERFACE); } //挨个注册接口 for (Class<?> i : interfaces){ serviceMap.put(i.getCanonicalName(), service); } logger.info("向接口:{},注册服务:{}",interfaces,serviceName); } /** * 获取输入对应服务名的服务 * @param serviceName * @return */ @Override public Object getService(String serviceName) { Object service = serviceMap.get(serviceName); if (service == null){ throw new RpcException(RpcError.SERVICE_NOT_FOUND); } return service; } }
这个Set存储了已经注册完成的服务名,而“服务名与服务”的对应关系存在一个ConcurrentHashMap当中,这保证了并发场景获取服务的安全性;在注册服务时,默认采用这个对象实现的接口的完整类名作为服务名,例如某个对象 A 实现了接口 X 和 Y,那么将 A 注册进去后,会有两个服务名 X 和 Y 对应于 A 对象。这种处理方式也就说明了某个接口只能有一个对象提供服务。
-
RpcServer:降低耦合度
绑定Server实现和ServiceRegistry的方式会造成二者的耦合,所以采用:在创建一个Server的同时,传入一个ServiceRegistry的方式进行解耦。这里我们通过抽象对应的线程:RequestHandlerThread,封装了实际服务端执行服务的方式,原先的Run方法也就变成了Start方法;
package com.dedsec.rpc.transport.server; public class RpcServer { private static final Logger logger = LoggerFactory.getLogger(RpcServer.class); private static final int CORE_PORE_SIZE = 5; private static final int MAXIMUM_POOL_SIZE = 50; private static final int KEEP_ALIVE_TIME = 60; private static final int BLOCKING_QUEUE_CAPACITY = 100; private final ExecutorService threadPool; private RequestHandler requestHandler = new RequestHandler(); private final ServiceRegistry serviceRegistry; public RpcServer(ServiceRegistry serviceRegistry) { this.serviceRegistry = serviceRegistry; BlockingQueue<Runnable> workingQueue = new ArrayBlockingQueue<>(BLOCKING_QUEUE_CAPACITY); ThreadFactory threadFactory = Executors.defaultThreadFactory(); threadPool = new ThreadPoolExecutor(CORE_PORE_SIZE, MAXIMUM_POOL_SIZE,KEEP_ALIVE_TIME,TimeUnit.SECONDS,workingQueue,threadFactory); } public void start(int port){ try(ServerSocket serverSocket = new ServerSocket(port)){ logger.info("服务器启动..."); Socket socket; while ((socket = serverSocket.accept()) != null){ logger.info("消费者连接:{}:{}",socket.getInetAddress(),socket.getPort()); threadPool.execute(new RequestHandlerThread(socket,requestHandler,serviceRegistry)); } threadPool.shutdown(); }catch (IOException e){ logger.error("服务器启动时有错误发生"); } } }
-
RequestHandlerThread:抽象出运行在服务器中的线程
上面RpcServer的代码中:RequestHandlerThread可以理解为实际运行在服务器上的线程,所以需要传入对应的socket、requestHandler(实际反射调用的过程)、以及对应注册的服务信息serviceRegistry来支撑服务器的运行;这个线程的功能用简要的语言概括就是:通过对应的IO流读取到我们所需要的RpcRequest,再从Request当中获取我们的服务名称,最后从容器ServiceRegistry中获取对应的服务,然后用requestHandler反射调用处理器去处理任务;
package com.dedsec.rpc.transport.server; public class RequestHandlerThread implements Runnable{ private static final Logger logger = LoggerFactory.getLogger(RequestHandlerThread.class); private Socket socket; private RequestHandler requestHandler; private ServiceRegistry serviceRegistry; /** * 传入对应的要素来初始化执行"过程调用"的线程 * 这个"过程调用"就是反射调用方法的实际过程 * @param socket * @param requestHandler * @param serviceRegistry */ public RequestHandlerThread(Socket socket, RequestHandler requestHandler, ServiceRegistry serviceRegistry){ this.socket = socket; this.requestHandler = requestHandler; this.serviceRegistry = serviceRegistry; } @Override public void run() { try(ObjectInputStream objectInputStream = new ObjectInputStream(socket.getInputStream()); ObjectOutputStream objectOutputStream = new ObjectOutputStream(socket.getOutputStream())){ //从socket中读取对象(就是请求) RpcRequest rpcRequest = (RpcRequest) objectInputStream.readObject(); //获取接口名称 String interfaceName = rpcRequest.getInterfaceName(); //通过接口名称获取对应的服务对象 Object service = serviceRegistry.getService(interfaceName); //通过过程调用获取相应的结果 Object result = requestHandler.handle(rpcRequest,service); //返回对应的结果 objectOutputStream.writeObject(RpcResponse.success(result)); objectOutputStream.flush(); }catch (IOException | ClassNotFoundException e){ logger.error("调用时发生错误:", e); } } }
-
RequestHandler:反射调用处理器,封装了真正的反射调用过程;
package com.dedsec.rpc.transport.server; /** * 执行对应请求:通过反射调用 * @author piwei * @date 2021/10/17 */ public class RequestHandler { private static final Logger logger = LoggerFactory.getLogger(RequestHandler.class); /** * 返回对应方法的执行结果 * @param rpcRequest * @param service * @return */ public Object handle(RpcRequest rpcRequest, Object service){ Object result = null; try{ result = invokeTargetMethod(rpcRequest,service); logger.info("服务:{} 成功调用方法:{}",rpcRequest.getInterfaceName(),rpcRequest.getMethodName()); } catch (IllegalAccessException | InvocationTargetException e){ logger.error("调用或发送时有错误发生:",e); } return result; } /** * 反射调用过程 * @param rpcRequest * @param service * @return * @throws IllegalAccessException * @throws InvocationTargetException */ private Object invokeTargetMethod(RpcRequest rpcRequest, Object service) throws IllegalAccessException, InvocationTargetException { Method method; try{ method = service.getClass().getMethod(rpcRequest.getMethodName(),rpcRequest.getParamTypes()); }catch (NoSuchMethodException e){ return RpcResponse.fail(ResponseCode.METHOD_NOT_FOUND); } //真实的反射调用 return method.invoke(service,rpcRequest.getParameters()); } }
-
测试结果:我们成功注册了两个服务;
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-t5PNyFlT-1639496107507)(DedSec%20RPC.assets/image-20211021204044211.png)]
2. Netty传输和通用的序列化接口
将原有的传输方式从BIO改进为更为高效的NIO,不过在这里我没有借助Java的原声NIO实现的方式,而是采用了Netty框架实现;另外,通过实现一个通用的序列化接口,为多种序列化做准备,并更进一步制定我们的传输协议;
-
抽象出:RpcServer和RpcClient两个接口
之所以这样做的原因是:之前对于Server和Client的实现方式是Socket,但是接下来对于它的改进将采用Netty的方式来实现,所以抽象出接口,方便未来对于Server和Client实现的进一步扩展;
package com.dedsec.rpc.transport.server; public interface RpcServer { void start(int port); }
package com.dedsec.rpc.transport.client; public interface RpcClient { Object sendRequest(RpcRequest request, String host, int port); }
而使用Socket方式实现的Server和Client只需要实现上面两个接口,不用做其他多余改动;
值得注意的一个改动是:为了保证服务信息在服务端注册的唯一性,我们将容器实现中的Map和Set变为static;
-
NettyServer的实现很传统:
package com.dedsec.rpc.transport.server.netty; public class NettyServer implements RpcServer { private static final Logger logger = LoggerFactory.getLogger(NettyServer.class); @Override public void start(int port) { EventLoopGroup bossGroup = new NioEventLoopGroup(); EventLoopGroup workerGroup = new NioEventLoopGroup(); try{ ServerBootstrap serverBootstrap = new ServerBootstrap(); serverBootstrap.group(bossGroup,workerGroup) .channel(NioServerSocketChannel.class) .handler(new LoggingHandler(LogLevel.ERROR)) .option(ChannelOption.SO_BACKLOG,256) .option(ChannelOption.SO_KEEPALIVE,true) .childOption(ChannelOption.TCP_NODELAY,true) .childHandler(new ChannelInitializer<SocketChannel>() { @Override protected void initChannel(SocketChannel ch) throws Exception { ChannelPipeline pipeline = ch.pipeline(); pipeline.addLast(new CommonEncoder(new JsonSerializer())); pipeline.addLast(new CommonDecoder()); pipeline.addLast(new NettyServerHandler()); } }); ChannelFuture future = serverBootstrap.bind(port).sync(); future.channel().closeFuture().sync(); } catch (InterruptedException e){ logger.error("启动服务器时有错误发生:",e); } finally { bossGroup.shutdownGracefully(); workerGroup.shutdownGracefully(); } } }
Netty中有一个重要的设计模式:责任链模式,责任链上有多个处理器,每一个处理器会对数据进行加工,并将处理后的数据传给下一个处理器。代码中:
CommonEncoder 编码器 CommonDecoder 解码器 NettyServrerHandler 数据处理器
数据从外部传入时需要解码,而数据向外传出时需要编码,这非常类似计算机网络分层模型的设计思想;
-
NettyClient的实现也很类似:
package com.dedsec.rpc.transport.client.netty; public class NettyClient implements RpcClient { private static final Logger logger = LoggerFactory.getLogger(NettyClient.class); private String host; private int port; private static final Bootstrap bootstrap; public NettyClient(String host, int port) { this.host = host; this.port = port; } /** * 在静态代码块中配置好了Netty客户端,等待发送数据时启动,channel将RpcRequest对象写出, * 并且等待服务端返回结果; * 注意这里的发送是非阻塞的,所以发送后会立刻返回,而无法得到结果; */ static { EventLoopGroup group = new NioEventLoopGroup(); bootstrap = new Bootstrap(); bootstrap.group(group) .channel(NioSocketChannel.class) .option(ChannelOption.SO_KEEPALIVE,true) .handler(new ChannelInitializer<SocketChannel>() { @Override protected void initChannel(SocketChannel ch) throws Exception { ChannelPipeline pipeline = ch.pipeline(); pipeline.addLast(new CommonDecoder()) .addLast(new CommonEncoder(new JsonSerializer())) .addLast(new NettyClientHandler()); } }); } @Override public Object sendRequest(RpcRequest request) { try{ ChannelFuture future = bootstrap.connect(host,port).sync(); logger.info("客户端连接到服务器^_^{}:{}",host,port); Channel channel = future.channel(); if (channel != null){ channel.writeAndFlush(request).addListener(future1 -> { if (future1.isSuccess()){ logger.info(String.format("客户端发送消息:%s",request.toString())); } else { logger.error("发送消息时有错误发生:",future1.cause()); } }); channel.closeFuture().sync(); //这里通过AttributeKey的方式阻塞获得返回结果; //通过这种方式获得全局可见的返回结果,在获得返回结果RpcResponse后,将这个对象 //以Key为rpcResponse放入ChannelHandlerContext中,这里就可以立刻获得结果并返回; //在NettyClientHandler中看到放入的过程 AttributeKey<RpcResponse> key = AttributeKey.valueOf("rpcResponse"); RpcResponse rpcResponse = channel.attr(key).get(); return rpcResponse.getData(); } } catch (InterruptedException e){ logger.error("发送消息时有错误发生:",e); } return null; } }
要点已经写在了注释当中;
-
自定义协议与编解码器
1. 4字节魔数,代表这是一个协议包; 2. Package Type,代表这是一个调用请求 or 响应请求; 3. Serializer Type, 标明了实际数据所使用的序列化器; 4. Data Length, 实际的数据长度;主要是为了防止粘包;
-
通用编码器:CommonEncoder
package com.dedsec.rpc.codec; public class CommonEncoder extends MessageToByteEncoder { private static final int MAGIC_NUMBER = 0xCAFEBABE; private final CommonSerializer serializer; public CommonEncoder(CommonSerializer serializer){ this.serializer = serializer; } @Override protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf out) throws Exception { out.writeInt(MAGIC_NUMBER); if (msg instanceof RpcRequest){ out.writeInt(PackageType.REQUEST_PACK.getCode()); } else { out.writeInt(PackageType.RESPONSE_PACK.getCode()); } out.writeInt(serializer.getCode()); byte[] bytes = serializer.serialize(msg); out.writeInt(bytes.length); out.writeBytes(bytes); } }
CommonEncoder继承了MessageToByteEncoder类,这个类的作用就是要把实际发送的对象转化为Byte数组,CommonEncoder的工作就是将RpcRequest 或者 RpcResponse包装成协议包,根据上面的格式,将各个字段写到管道里就可以,这里serializer.getCode()获取序列化器的编号,之后使用传入的序列化器将请求或响应包序列化为字节数组写入管道即可;
-
CommonDecoder的工作与之类似:
package com.dedsec.rpc.codec; public class CommonDecoder extends ReplayingDecoder { private static final Logger logger = LoggerFactory.getLogger(CommonDecoder.class); private static final int MAGIC_NUMBER = 0xCAFEBABE; @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { //读入数据包的前4个字节,看是否是合法的数据包 int magic = in.readInt(); if (magic != MAGIC_NUMBER){ logger.error("无法识别的协议包:{}",magic); throw new RpcException(RpcError.UNKNOWN_PROTOCOL); } //读入下面4个字节,看是什么类型的数据包:请求 还是 响应? int packageCode = in.readInt(); Class<?> packageClass; //什么类型的数据包? if (packageCode == PackageType.REQUEST_PACK.getCode()){ packageClass = RpcRequest.class; } else if (packageCode == PackageType.RESPONSE_PACK.getCode()){ packageClass = RpcResponse.class; } else { logger.error("无法识别的协议包:{}",packageCode); throw new RpcException(RpcError.UNKNOWN_PACKAGE_TYPE); } //读入接下来4个字节,看使用了什么编解码器,找出对应的反序列化器 int serializerCode = in.readInt(); CommonSerializer serializer = CommonSerializer.getByCode(serializerCode); if (serializer == null){ logger.error("无法识别的反序列化器:{}",serializerCode); throw new RpcException(RpcError.UNKNOWN_SERIALIZER); } //读入接下来4个字节,看数据长度是多少,并为其分配接收空间,防止粘包 int length = in.readInt(); byte[] bytes = new byte[length]; //读入字节数组,并采用反序列化方式,得到对象 in.readBytes(bytes); Object obj = serializer.deserialize(bytes,packageClass); out.add(obj); } }
-
序列化接口:CommonSerialize;主要是通用方法、根据不同的编码器获取不同的解码器;
package com.dedsec.rpc.serializer; public interface CommonSerializer { byte[] serialize(Object obj); Object deserialize(byte[] bytes, Class<?> clazz); int getCode(); static CommonSerializer getByCode(int code){ switch (code){ case 1: return new JsonSerializer(); default: return null; } } }
JsonSerializer:
package com.dedsec.rpc.serializer; public class JsonSerializer implements CommonSerializer{ private static final Logger logger = LoggerFactory.getLogger(JsonSerializer.class); private ObjectMapper objectMapper = new ObjectMapper(); @Override public byte[] serialize(Object obj) { try{ return objectMapper.writeValueAsBytes(obj); } catch (JsonProcessingException e){ logger.error("序列化时有错误发生:{}",e.getMessage()); e.printStackTrace(); return null; } } @Override public Object deserialize(byte[] bytes, Class<?> clazz) { try{ Object obj = objectMapper.readValue(bytes,clazz); if (obj instanceof RpcRequest){ obj = handleRequest(obj); } return obj; } catch (IOException e){ logger.error("反序列化有错误发生:{}",e.getMessage()); e.printStackTrace(); return null; } } /** * 处理请求方法 * 由于使用JSON序列化和反序列化Object数组,无法保证反序列化后仍然为原实例类型 * 需要重新判断处理 * @param obj * @return * @throws IOException */ private Object handleRequest(Object obj) throws IOException{ RpcRequest rpcRequest = (RpcRequest) obj; for (int i = 0; i < rpcRequest.getParamTypes().length; i++){ Class<?> clazz = rpcRequest.getParamTypes()[i]; if (!clazz.isAssignableFrom(rpcRequest.getParameters()[i].getClass())){ byte[] bytes = objectMapper.writeValueAsBytes(rpcRequest.getParameters()[i]); rpcRequest.getParameters()[i] = objectMapper.readValue(bytes,clazz); } } return rpcRequest; } @Override public int getCode() { return SerializerCode.valueOf("JSON").getCode(); } }
-
NettyServerHandler:处理RpcRequest请求的处理器,用于接收RpcRequest,并将调用结果封装成RpcResponse发送出去;
package com.dedsec.rpc.transport.server.netty; /** * Netty中处理RpcRequest的Handler * @author piwei * @date 2021/10/24 */ public class NettyServerHandler extends SimpleChannelInboundHandler<RpcRequest> { private static final Logger logger = LoggerFactory.getLogger(NettyServerHandler.class); private static RequestHandler requestHandler; private static ServiceRegistry serviceRegistry; static { requestHandler = new RequestHandler(); serviceRegistry = new DefaultServiceRegistry(); } @Override protected void channelRead0(ChannelHandlerContext ctx, RpcRequest msg) throws Exception { try{ logger.info("服务器收到请求^_^:{}",msg); String interfaceName = msg.getInterfaceName(); Object service = serviceRegistry.getService(interfaceName); Object result = requestHandler.handle(msg,service); ChannelFuture future = ctx.writeAndFlush(RpcResponse.success(result)); future.addListener(ChannelFutureListener.CLOSE); } finally { ReferenceCountUtil.release(msg); } } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { logger.error("处理过程调用时,有错误发生"); cause.printStackTrace(); ctx.close(); } }
-
NettyClientHandler:处理接收到的RpcResponse对象;
package com.dedsec.rpc.transport.client.netty; public class NettyClientHandler extends SimpleChannelInboundHandler<RpcResponse> { private static final Logger logger = LoggerFactory.getLogger(NettyClientHandler.class); @Override protected void channelRead0(ChannelHandlerContext ctx, RpcResponse msg) throws Exception { try{ logger.info(String.format("客户端接收到消息:%s",msg)); AttributeKey<RpcResponse> key = AttributeKey.valueOf("rpcResponse"); ctx.channel().attr(key).set(msg); ctx.channel().close(); } finally { ReferenceCountUtil.release(msg); } } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { logger.error("过程调用时有错误发生"); cause.printStackTrace(); ctx.close(); } }
3. 更强劲的序列化手段–Kryo
上面我们已经搭建好了通用化的编解码框架,完成了我们的序列化框架且具有一定的扩展性,但我们使用的Json序列化器有以下几个缺点:
- 某个类的属性反序列化的时候,如果属性是Object,就会被直接反序列化为String,导致出错;
- Json序列化器是基于字符串(Json字符串)的,占较大空间且速度较慢;
所以,为了使序列化手段鲁棒性更强且性能高效、易用,可以实现一个Kryo序列化器,Kryo是:
- 快速高效的Java对象序列化器,高性能、高效、易用;
- 基于字节的序列化,对空间的利用率很高,网络传输时可以减小数据包体积;
- 序列化时会记录对象的类型信息,这样反序列化就不会出现Json将Object属性直接翻译为String的错误了;
- KryoSerializer:Kryo的序列化器
package com.dedsec.rpc.serializer;
/**
* Kryo序列化器
* 序列化时,先创建一个Output对象,接着使用writeObject方法将对象写入Output中,最后调用Output对象的toByte()方法,
* 即可以获得对象的字节数组;
* 反序列化是从Input对象中直接readObject,这里只需要传入对象的类型,而不需要具体的传入每一个属性的类型信息;
* 这便是优于JsonSerializer的地方;
* @author piwei
* @date 2021/10/25
*/
public class KryoSerialize implements CommonSerializer{
private static final Logger logger = LoggerFactory.getLogger(KryoSerialize.class);
/**
* Kryo可能存在线程问题,推荐放在ThreadLocal中,一个线程一个Kryo;
*/
private static final ThreadLocal<Kryo> kryoThreadLocal = ThreadLocal.withInitial(()->{
Kryo kryo = new Kryo();
kryo.register(RpcResponse.class);
kryo.register(RpcRequest.class);
kryo.setReferences(true);
kryo.setRegistrationRequired(false);
return kryo;
});
@Override
public byte[] serialize(Object obj) {
try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
Output output = new Output(byteArrayOutputStream)){
Kryo kryo = kryoThreadLocal.get();
kryo.writeObject(output,obj);
kryoThreadLocal.remove();
return output.toBytes();
} catch (Exception e){
logger.error("Kryo序列化时有错误发生",e);
throw new SerializeException("Kryo序列化时有错误发生");
}
}
@Override
public Object deserialize(byte[] bytes, Class<?> clazz) {
try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bytes);
Input input = new Input(byteArrayInputStream)){
Kryo kryo = kryoThreadLocal.get();
Object o = kryo.readObject(input,clazz);
kryoThreadLocal.remove();
return o;
} catch (Exception e){
logger.error("Kryo反序列化时有错误发生",e);
throw new SerializeException("Kryo反序列化时有错误发生");
}
}
@Override
public int getCode() {
return SerializerCode.valueOf("KRYO").getCode();
}
}
4. Nacos实现的服务注册与发现
对于我们的框架来说,它已经实现了最基本的服务调用功能,但是服务端的地址是固定写在客户端的,相当于我们主动告诉了客户端服务的所在,那么如何让客户端主动发现服务?如果正在使用的服务器A宕机,我们如何让客户端再去发现其他的服务?
在分布式框架中,有一个重要的组建,就是服务注册中心,它用于保存多个服务提供者的信息,每个服务提供者在启动的时候都需要向注册中心注册自己所拥有的服务。
常见的服务注册中心有:Eureka、Zookeeper、Nacos;
-
Nacos安装教程:
https://blog.csdn.net/weixin_43990171/article/details/121269101?spm=1001.2014.3001.5501
-
引入Nacos
<dependency> <groupId>com.alibaba.nacos</groupId> <artifactId>nacos-client</artifactId> <version>1.3.0</version> </dependency>
-
修改容器名,接下来服务注册中心是ServiceRegistry
这里将SeviceRegistry变更为:ServiceProvider;然后SeviceRegistry变更为服务注册表使用;
public interface ServiceRegistry { /** * 向服务中心注册服务 * @param serviceName 服务名 * @param inetSocketAddress 提供服务的地址 */ void register(String serviceName, InetSocketAddress inetSocketAddress); /** * 根据Service名,返回一个InetSocketAddress * 应该可以简单理解为对于IP和PORT的封装 * @param serviceName 服务名 * @return IP + Port */ InetSocketAddress lookupService(String serviceName); }
-
Nacos注册中心实现类
package com.dedsec.rpc.registry; public class NacosServiceRegistry implements ServiceRegistry{ private static final Logger logger = LoggerFactory.getLogger(NacosServiceRegistry.class); //这一部分应该可以写在配置文件中 private static final String SERVER_ADDR = "127.0.0.1:8848"; //用于连接Nacos private static final NamingService namingService; //连接Nacos static { try { namingService = NamingFactory.createNamingService(SERVER_ADDR); } catch(NacosException e) { logger.error("连接Nacos的时候发生了错误:",e); throw new RpcException(RpcError.FAILED_TO_CONNECT_TO_SERVICE_REGISTRY); } } /** * 注册 * @param serviceName 服务名 * @param inetSocketAddress 提供服务的地址 */ @Override public void register(String serviceName, InetSocketAddress inetSocketAddress) { try { namingService.registerInstance(serviceName, inetSocketAddress.getHostName(), inetSocketAddress.getPort()); } catch (NacosException e){ logger.error("注册服务时发生错误:",e); throw new RpcException(RpcError.REGISTER_SERVICE_FAILED); } } /** * * @param serviceName 服务名 * @return */ @Override public InetSocketAddress lookupService(String serviceName) { try { List<Instance> instances = namingService.getAllInstances(serviceName); Instance instance = instances.get(0); return new InetSocketAddress(instance.getIp(),instance.getPort()); } catch (NacosException e) { logger.error("获取服务时有错误发生:",e); } return null; } }
通过创建一个NamingService来连接Nacos;
连接的过程写在了静态代码块中,在类加载时自动连接;
通过register方法注册服务到服务中心,利用了namingService.registerInstance()方法;
通过lookupService方法,可以获得所有提供serviceName服务的提供者的信息,这里默认返回第0个提供者,但是这地方其实应该运用负载均衡算法;
到此,大概明白做了什么改动,首先就是通过实现Nacos的服务注册类,向Nacos服务器注册服务,然后Nacos便存储了对应的服务信息; 之前的容器还是原来的功能,为对应请求提供一个类去服务; 总而言之,客户端现在不需要知道谁具有某个服务了,只需要向Nacos询问就可以,Nacos可以通过一些算法、负载均衡的策略,返回一个最好的服务给客户端;
-
注册服务
为Server的实现添加一个方法:注册服务;
另外,再添加一个设置通信期间序列化器的方法,因为不同的服务之间可能采用不同的序列化器;
/** * 抽象服务端接口 * @author piwei * @date 2021/10/24 */ public interface RpcServer { void start(); //发布服务 <T> void publishService(Object service, Class<T> serviceClass); //设置序列化器 void setSerializer(CommonSerializer commonSerializer); }
然后,在NettyServer的实现中,增加这两个方法的实现:
package com.dedsec.rpc.transport.server.netty; /** * Netty实现的服务端 * @author piwei * @date 2021/10/24 */ public class NettyServer implements RpcServer { private static final Logger logger = LoggerFactory.getLogger(NettyServer.class); private String host; private int port; private ServiceRegistry serviceRegistry; private ServiceProvider serviceProvider; private CommonSerializer serializer; public NettyServer(String host, int port) { this.host = host; this.port = port; this.serviceRegistry = new NacosServiceRegistry(); this.serviceProvider = new ServiceProviderImpl(); } @Override public void setSerializer(CommonSerializer commonSerializer) { this.serializer = commonSerializer; } @Override public void start() { EventLoopGroup bossGroup = new NioEventLoopGroup(); EventLoopGroup workerGroup = new NioEventLoopGroup(); try{ ServerBootstrap serverBootstrap = new ServerBootstrap(); //关于Bootstrap的一些配置: serverBootstrap.group(bossGroup,workerGroup) .channel(NioServerSocketChannel.class) .handler(new LoggingHandler(LogLevel.ERROR)) .option(ChannelOption.SO_BACKLOG,256) .option(ChannelOption.SO_KEEPALIVE,true) .childOption(ChannelOption.TCP_NODELAY,true) .childHandler(new ChannelInitializer<SocketChannel>() { @Override protected void initChannel(SocketChannel ch) throws Exception { ChannelPipeline pipeline = ch.pipeline(); //编码器 pipeline.addLast(new CommonEncoder(new KryoSerialize())); //pipeline.addLast(new CommonEncoder(new JsonSerializer())); //解码器 pipeline.addLast(new CommonDecoder()); //数据处理器 pipeline.addLast(new NettyServerHandler()); } }); ChannelFuture future = serverBootstrap.bind(port).sync(); future.channel().closeFuture().sync(); } catch (InterruptedException e){ logger.error("启动服务器时有错误发生:",e); } finally { bossGroup.shutdownGracefully(); workerGroup.shutdownGracefully(); } } @Override public <T> void publishService(Object service, Class<T> serviceClass) { if (serializer == null){ logger.error("未设置序列化器"); throw new RpcException(RpcError.SERVICE_NOT_FOUND); } //注册在服务器本地内存 serviceProvider.addServiceProvider(service); //注册在Nacos serviceRegistry.register(serviceClass.getCanonicalName(), new InetSocketAddress(host,port)); start(); } }
观察到publishService方法一边注册服务到了Nacos,一边注册服务到本地的注册表;
此处注册完毕直接调用start方法是一个不好的处理,导致一台服务端机器只能注册一个服务;
-
发现服务
以前的Host和Port是通过写死获取的,现在的Host和Post通过Nacos获取,那么对于Client的实现修改如下:
package com.dedsec.rpc.transport.client.netty; /** * Netty的客户端实现 * @author piwei * @date 2021/10/25 */ public class NettyClient implements RpcClient { private static final Logger logger = LoggerFactory.getLogger(NettyClient.class); private static final Bootstrap bootstrap; private final ServiceRegistry serviceRegistry; private CommonSerializer serializer; public NettyClient() { this.serviceRegistry = new NacosServiceRegistry(); } /** * 在静态代码块中配置好了Netty客户端,等待发送数据时启动,channel将RpcRequest对象写出, * 并且等待服务端返回结果; * 注意这里的发送是非阻塞的,所以发送后会立刻返回,而无法得到结果; */ static { EventLoopGroup group = new NioEventLoopGroup(); bootstrap = new Bootstrap(); bootstrap.group(group) .channel(NioSocketChannel.class) .option(ChannelOption.SO_KEEPALIVE,true) .handler(new ChannelInitializer<SocketChannel>() { @Override protected void initChannel(SocketChannel ch) throws Exception { ChannelPipeline pipeline = ch.pipeline(); pipeline.addLast(new CommonDecoder()) //.addLast(new CommonEncoder(new JsonSerializer())) .addLast(new CommonEncoder(new KryoSerialize())) .addLast(new NettyClientHandler()); } }); } @Override public void setSerializer(CommonSerializer commonSerializer) { this.serializer = commonSerializer; } @Override public Object sendRequest(RpcRequest request) { if (serializer == null){ logger.error("未设置序列化器"); throw new RpcException(RpcError.SERIALIZER_NOT_FOUND); } try{ //通过Nacos 获取对应服务的地址,这里可以得到最优的结果,取决于负载均衡的处理 InetSocketAddress inetSocketAddress = serviceRegistry.lookupService(request.getInterfaceName()); //通过ChannelProvider获取对应的Channel Channel channel = ChannelProvider.get(inetSocketAddress, serializer); if (channel.isActive()){ channel.writeAndFlush(request).addListener(future1 -> { if (future1.isSuccess()){ logger.info(String.format("客户端发送消息:%s",request.toString())); } else { logger.error("发送消息时有错误发生:",future1.cause()); } }); channel.closeFuture().sync(); //这里通过AttributeKey的方式阻塞获得返回结果; //通过这种方式获得全局可见的返回结果,在获得返回结果RpcResponse后,将这个对象 //以Key为rpcResponse放入ChannelHandlerContext中,这里就可以立刻获得结果并返回; //在NettyClientHandler中看到放入的过程 AttributeKey<RpcResponse> key = AttributeKey.valueOf("rpcResponse"); RpcResponse rpcResponse = channel.attr(key).get(); return rpcResponse.getData(); } } catch (InterruptedException e){ logger.error("发送消息时有错误发生:",e); } return null; } }
观察到获取Channel的时候,采用了ChannelProvider
package com.dedsec.rpc.transport.client.netty; /** * 通道提供类 * 根据服务提供商IP + 序列化器号组成的Key在这个Map中寻求到一个用于客户端与服务端通信的Channel * @author piwei * @date 2021/11/14 */ public class ChannelProvider { private static final Logger logger = LoggerFactory.getLogger(ChannelProvider.class); private static EventLoopGroup eventLoopGroup; private static Bootstrap bootstrap = initializeBootstrap(); private static Map<String, Channel> channels = new ConcurrentHashMap<>(); public static Channel get(InetSocketAddress inetSocketAddress, CommonSerializer serializer) throws InterruptedException { String key = inetSocketAddress.toString() + serializer.getCode(); if (channels.containsKey(key)) { Channel channel = channels.get(key); if(channels != null && channel.isActive()) { return channel; } else { channels.remove(key); } } bootstrap.handler(new ChannelInitializer<SocketChannel>() { @Override protected void initChannel(SocketChannel ch) { /*自定义序列化编解码器*/ // RpcResponse -> ByteBuf ch.pipeline().addLast(new CommonEncoder(serializer)) .addLast(new IdleStateHandler(0, 5, 0, TimeUnit.SECONDS)) .addLast(new CommonDecoder()) .addLast(new NettyClientHandler()); } }); Channel channel = null; try { channel = connect(bootstrap, inetSocketAddress); } catch (ExecutionException e) { logger.error("连接客户端时有错误发生", e); return null; } channels.put(key, channel); return channel; } private static Channel connect(Bootstrap bootstrap, InetSocketAddress inetSocketAddress) throws ExecutionException, InterruptedException { CompletableFuture<Channel> completableFuture = new CompletableFuture<>(); bootstrap.connect(inetSocketAddress).addListener((ChannelFutureListener) future -> { if (future.isSuccess()) { logger.info("客户端连接成功!"); completableFuture.complete(future.channel()); } else { throw new IllegalStateException(); } }); return completableFuture.get(); } private static Bootstrap initializeBootstrap() { eventLoopGroup = new NioEventLoopGroup(); Bootstrap bootstrap = new Bootstrap(); bootstrap.group(eventLoopGroup) .channel(NioSocketChannel.class) //连接的超时时间,超过这个时间还是建立不上的话则代表连接失败 .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000) //是否开启 TCP 底层心跳机制 .option(ChannelOption.SO_KEEPALIVE, true) //TCP默认开启了 Nagle 算法,该算法的作用是尽可能的发送大数据快,减少网络传输。TCP_NODELAY 参数的作用就是控制是否启用 Nagle 算法。 .option(ChannelOption.TCP_NODELAY, true); return bootstrap; } }
-
测试
NettyTestClient
public class NettyTestClient { public static void main(String[] args) { //创建客户端实例,现在不用通过写定IP于Port的方式,自动会从Nacos中获取 RpcClient client = new NettyClient(); //设定序列化器 client.setSerializer(new KryoSerialize()); //代理 RpcClientProxy rpcClientProxy = new RpcClientProxy(client); //服务代理 HelloService helloService = rpcClientProxy.getProxy(HelloService.class); //服务通信实体 HelloObject object = new HelloObject(12,"This is a message"); //获取发送请求后的结果 String res = helloService.hello(object); System.out.println(res); } }
NettyTestServer
public class NettyTestServer { public static void main(String[] args) { //在服务端注册对应的服务实现 HelloService helloService = new HelloServiceImpl(); CrawlerService crawlerService = new CrawlerServiceImpl(); //启动server实例,绑定好服务端的IP于port,注册在Nacos供Client发现; NettyServer server = new NettyServer("127.0.0.1",9999); //设置序列化器 server.setSerializer(new KryoSerialize()); //注册服务,这里会自动开启,所以不会注册多个服务 server.publishService(helloService,HelloService.class); } }
5. 自动注销服务与负载均衡
-
服务自动注销
在服务器关闭之后,对应的服务应该被撤销,这很简单,但是如何知晓服务器关闭的时间呢?我们如何指派这个关闭的任务?这里需要用到钩子;
钩子:Hook翻译成中文就是勾子的意思,在Java中它表示在事件到达终点前进行拦截或监控的一种行为;
/** * 钩子函数,JVM退出时,注销Nacos中的服务 * @author piwei * @date 2021/11/18 */ public class ShutdownHook { private static final Logger logger = LoggerFactory.getLogger(ShutdownHook.class); private final ExecutorService threadPool = ThreadPoolFactory.createDefaultThreadPool("shutdown-hook"); private static final ShutdownHook shutdownHook = new ShutdownHook(); public static ShutdownHook getShutdownHook(){ return shutdownHook; } /** * 添加注销所有服务的Hook方法 */ public void addClearAllHook(){ logger.info("关闭后将自动注销所有服务"); //虚拟机运行时环境,注册钩子函数,该函数在JVM退出之前执行; Runtime.getRuntime().addShutdownHook(new Thread(()->{ NacosUtils.clearRegistry(); threadPool.shutdown(); })); } }
使用单例模式创建Hook函数对象,JVM在退出时,创建一个新的线程完成注销工作,钩子函数在JVM关闭前调用;
NettyServer.java
ChannelFuture future = serverBootstrap.bind(port).sync(); //注册Hook函数 ShutdownHook.getShutdownHook().addClearAllHook(); future.channel().closeFuture().sync();
这样在服务器关闭之后,可以看到Nacos中注册的服务信息被注销了;
-
负载均衡策略
实现负载均衡,将上一节中,挑选服务的方法抽象成一个接口:
/** * 服务发现 * @author piwei * @date 2021/11/18 */ public interface ServiceDiscovery { /** * 根据服务名称查找服务实体 * @param serviceName * @return */ InetSocketAddress lookupService(String serviceName); }
然后,我们来实现一些平衡器:
随机选择平衡器:
/** * 随机选择平衡器 * @author piwei * @date 2021/11/18 */ public class RandomLoadBalancer implements LoadBalancer{ @Override public Instance select(List<Instance> instances) { int randomIndex = new Random().nextInt(instances.size()); return instances.get(randomIndex); } }
轮询平衡器:
/** * 轮转负载均衡算法 * @author piwei * @date 2021/11/18 */ public class RoundRobinLoadBalancer implements LoadBalancer{ private int index = 0; @Override public Instance select(List<Instance> instances) { index = index >= instances.size() ? (index % instances.size()) : index; return instances.get(index++); } }
可以进一步实现一个加权 随机/轮转 平衡器,除此之外还有:一致性哈希、平滑加权轮转算法;
最后,在NacosServiceDiscovery中集成负载均衡:
/** * Nacos服务发现类 * @author piwei * @date 2021/11/18 */ public class NacosServiceDiscovery implements ServiceDiscovery{ private static final Logger logger = LoggerFactory.getLogger(NacosServiceDiscovery.class); private final LoadBalancer loadBalancer; public NacosServiceDiscovery(LoadBalancer loadBalancer) { if (loadBalancer == null){ this.loadBalancer = new RandomLoadBalancer(); } else { this.loadBalancer = loadBalancer; } } @Override public InetSocketAddress lookupService(String serviceName) { try { List<Instance> instances = NacosUtils.getAllInstance(serviceName); Instance instance = loadBalancer.select(instances); return new InetSocketAddress(instance.getIp(),instance.getPort()); } catch (NacosException e) { logger.error("获取服务的时候有错误发生:",e); } return null; } }
6. 服务端自动注册服务
客户端的实现我们暂时告一段落,我们重新着眼于服务端的实现:
public class NettyTestServer {
public static void main(String[] args) {
//在服务端注册对应的服务实现
HelloService helloService = new HelloServiceImpl();
CrawlerService crawlerService = new CrawlerServiceImpl();
//启动server实例,绑定好服务端的IP于port,注册在Nacos供Client发现;
NettyServer server = new NettyServer("127.0.0.1",9999);
//设置序列化器
server.setSerializer(new KryoSerialize());
//注册服务
server.publishService(helloService,HelloService.class);
}
}
我们注意到,关于服务的注册需要手动的一个个去完成,这显然在服务数量增多,体量更大的时候,是不合理的;我们可以尝试利用Spring框架中的思想,基于注解去注册我们的服务;
-
定义相关注解
服务注解:Service.java
/** * Service标识该类为一个服务提供类 * Service注解的值定义为该服务的名称,默认值是该类的完整类名; * @author piwei * @date 2021/11/19 */ @Target(ElementType.TYPE) @Retention(RetentionPolicy.RUNTIME) public @interface Service { public String name() default ""; }
启动类注解:ServiceScan.java
/** * 扫描范围 * 一般标识在启动类,标识服务的扫描的包范围; * ServiceScan的值定义为扫描范围的根包,默认值为入口类所在的包,扫描时会扫描该包及其子包下所有的类 * 找到标记有Service的类,并注册; * @author piwei * @date 2021/11/19 */ @Target(ElementType.TYPE) @Retention(RetentionPolicy.RUNTIME) public @interface ServiceScan { public String value() default ""; }
Service注解的值定义为服务名,而ServiceScan注解的值定义为扫描范围的根包名,默认值为入口类所在的包;
-
工具类:ReflectUtil
package com.dedsec.utils; /** * 这个类是一系列工具方法 * @author piwei * @date 2021/11/19 */ public class ReflectUtil { public static String getStackTrace(){ StackTraceElement[] stack = new Throwable().getStackTrace(); return stack[stack.length - 1].getClassName(); } public static Set<Class<?>> getClasses(String packageName) { Set<Class<?>> classes = new LinkedHashSet<>(); boolean recursive = true; String packageDirName = packageName.replace('.', '/'); Enumeration<URL> dirs; try { dirs = Thread.currentThread().getContextClassLoader().getResources( packageDirName); // 循环迭代下去 while (dirs.hasMoreElements()) { // 获取下一个元素 URL url = dirs.nextElement(); // 得到协议的名称 String protocol = url.getProtocol(); // 如果是以文件的形式保存在服务器上 if ("file".equals(protocol)) { // 获取包的物理路径 String filePath = URLDecoder.decode(url.getFile(), "UTF-8"); // 以文件的方式扫描整个包下的文件 并添加到集合中 findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes); } else if ("jar".equals(protocol)) { // 如果是jar包文件 // 定义一个JarFile JarFile jar; try { // 获取jar jar = ((JarURLConnection) url.openConnection()) .getJarFile(); // 从此jar包 得到一个枚举类 Enumeration<JarEntry> entries = jar.entries(); // 同样的进行循环迭代 while (entries.hasMoreElements()) { // 获取jar里的一个实体 可以是目录 和一些jar包里的其他文件 如META-INF等文件 JarEntry entry = entries.nextElement(); String name = entry.getName(); // 如果是以/开头的 if (name.charAt(0) == '/') { // 获取后面的字符串 name = name.substring(1); } // 如果前半部分和定义的包名相同 if (name.startsWith(packageDirName)) { int idx = name.lastIndexOf('/'); // 如果以"/"结尾 是一个包 if (idx != -1) { // 获取包名 把"/"替换成"." packageName = name.substring(0, idx) .replace('/', '.'); } // 如果可以迭代下去 并且是一个包 if ((idx != -1) || recursive) { // 如果是一个.class文件 而且不是目录 if (name.endsWith(".class") && !entry.isDirectory()) { // 去掉后面的".class" 获取真正的类名 String className = name.substring( packageName.length() + 1, name .length() - 6); try { // 添加到classes classes.add(Class .forName(packageName + '.' + className)); } catch (ClassNotFoundException e) { // log // .error("添加用户自定义视图类错误 找不到此类的.class文件"); e.printStackTrace(); } } } } } } catch (IOException e) { // log.error("在扫描用户定义视图时从jar包获取文件出错"); e.printStackTrace(); } } } } catch (IOException e) { e.printStackTrace(); } return classes; } private static void findAndAddClassesInPackageByFile(String packageName, String packagePath, final boolean recursive, Set<Class<?>> classes) { // 获取此包的目录,建立一个File File dir = new File(packagePath); // 如果不存在或者 也不是目录就直接返回 if (!dir.exists() || !dir.isDirectory()){ // logs return; } // 如果存在,就获取包下的所有文件 包括目录 File[] dirFiles = dir.listFiles(new FileFilter() { //自定义过滤过则:如果可以循环 @Override public boolean accept(File file) { return (recursive && file.isDirectory() || (file.getName().endsWith(".class"))); } }); // 循环所有文件 for (File file : dirFiles){ // 如果是目录,则继续扫描 if (file.isDirectory()){ findAndAddClassesInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive, classes); } else { // 如果是Java的类文件,去掉后面的.class,只留下类名 String className = file.getName().substring(0,file.getName().length() - 6); try { // 添加到集合中去 classes.add(Thread.currentThread().getContextClassLoader().loadClass(packageName + "." + className)); } catch (ClassNotFoundException e) { e.printStackTrace(); } } } } }