背景
本文实现一个类似于nginx或gateway的反向代理网关,实现思路是访客通过网络请求反向代理服务,代理服务连接到真实服务,维护访客和真实服务的数据交互。
这个实现和之前的内网穿透项目思路相似,只不过内网穿透是由客户端主动跟代理服务维护连接,这个是代理服务主动和真实服务连接。
本文用MySQL存储真实服务的配置信息,目前只实现了IP、端口和请求路径的配置。
实践
项目结构
1、cc-gateway项目:实现反向代理功能
原理分析
反向代理的实现过程主要分两步
1、启动服务端,这时代理服务监听8888端口(默认8888)
2、访客通过访问代理服务8888端口(例如http://127.0.0.1:8888/sso),代理服务接收到请求后解析请求路径得到(/sso),根据这个路径查询数据库配置,如果匹配到(/sso)对应的真实服务的IP和端口,那么代理服务会发起与真实服务的连接,并建立访客和真实服务的数据传输通道。
这两步最终形成了(访客-代理-真实服务)完整的通道。
代码实现
数据库设计
建表
-- Backend (real) server registry read by the gateway.
-- The gateway selects rows with able = 1, orders them by sno and groups them by `code`;
-- `code` is compared against the request path to pick a backend, `ip`/`port` are the connect target.
CREATE TABLE `server` (
`id` varchar(32) NOT NULL COMMENT '主键ID',
`name` varchar(255) DEFAULT NULL COMMENT '服务名称',
`code` varchar(255) DEFAULT NULL COMMENT '服务标识',
`ip` varchar(255) DEFAULT NULL COMMENT '服务IP',
`port` int(11) DEFAULT NULL COMMENT '服务端口',
`ip_type` varchar(255) DEFAULT NULL COMMENT '服务IP类型(ipv4,ipv6)',
`weight` double DEFAULT NULL COMMENT '服务权重',
`status` int(1) DEFAULT NULL COMMENT '服务状态(运行中、掉线)',
`able` int(1) DEFAULT NULL COMMENT '操作状态(启用、禁用)',
`gray` varchar(255) DEFAULT NULL COMMENT '灰度信息',
`sno` int(11) DEFAULT NULL COMMENT '排序',
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='服务信息表';
填点数据
### 请根据自己真实服务的IP、端口和路径修改以下示例数据
-- Sample routes: /im -> 10.0.0.2:8889, / -> 10.0.0.2:8886, /sso -> 10.0.0.3:8885 (all enabled, able = 1).
INSERT INTO `server`(`id`, `name`, `code`, `ip`, `port`, `ip_type`, `weight`, `status`, `able`, `gray`, `sno`) VALUES ('1', 'im', '/im', '10.0.0.2', 8889, '1', 1, 1, 1, '', 1);
INSERT INTO `server`(`id`, `name`, `code`, `ip`, `port`, `ip_type`, `weight`, `status`, `able`, `gray`, `sno`) VALUES ('2', '/', '/', '10.0.0.2', 8886, '1', 1, 1, 1, '', 1);
INSERT INTO `server`(`id`, `name`, `code`, `ip`, `port`, `ip_type`, `weight`, `status`, `able`, `gray`, `sno`) VALUES ('3', 'sso', '/sso', '10.0.0.3', 8885, '1', 1, 1, 1, '', 1);
cc-gateway项目
pom文件
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.cc</groupId>
<artifactId>cc-gateway</artifactId>
<version>1.0-SNAPSHOT</version>
<name>cc-gateway</name>
<url>http://maven.apache.org</url>
<properties>
<java.home>${env.JAVA_HOME}</java.home>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<java.version>1.8</java.version>
</properties>
<dependencies>
<!-- Netty: networking framework used for the proxy listener and the backend connections -->
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>4.1.74.Final</version>
</dependency>
<!-- MySQL JDBC driver -->
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>5.1.38</version>
</dependency>
<!-- Hutool: collection helpers and JSON-to-bean conversion used by the service layer -->
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
<version>5.7.8</version>
</dependency>
</dependencies>
<build>
<plugins>
<!-- Compile for Java 8 -->
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
<!-- Builds a runnable jar-with-dependencies (main class com.cc.gw.MainApp) during the package phase -->
<plugin>
<artifactId>maven-assembly-plugin</artifactId>
<version>3.0.0</version>
<configuration>
<archive>
<manifest>
<mainClass>com.cc.gw.MainApp</mainClass>
</manifest>
</archive>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
</configuration>
<executions>
<execution>
<id>make-assembly</id>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
工具类
mysql工具,主要查询配置
package com.cc.gw.util;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Minimal JDBC helper used to read the gateway configuration from MySQL.
 *
 * <p>Every call opens a fresh connection and closes it before returning.
 */
public class SQLUtil {
    private static final String url = "jdbc:mysql://127.0.0.1:3306/serverdb?allowMultiQueries=true&useUnicode=true&characterEncoding=utf8&useSSL=false";
    private static final String username = "root";
    private static final String password = "123456";

    /**
     * Executes a query and returns every row as a map of column name to column value.
     *
     * <p>Failures are not propagated: the error is printed and whatever rows were read so far
     * (possibly none) are returned, so callers never receive {@code null}.
     *
     * @param sqlStr the SQL query to execute; it is sent to the database as-is, so it must not
     *               be built from untrusted input
     * @return the rows read, in result-set order; empty when the query fails or matches nothing
     */
    public static List<Map<String, Object>> query(String sqlStr) {
        List<Map<String, Object>> list = new ArrayList<Map<String, Object>>();
        // try-with-resources closes ResultSet, Statement and Connection (in that order) on every
        // path, including when executeQuery or row reading throws.
        try (Connection con = DriverManager.getConnection(url, username, password);
                Statement stmt = con.createStatement();
                ResultSet rs = stmt.executeQuery(sqlStr)) {
            ResultSetMetaData md = rs.getMetaData(); // result-set structure (metadata)
            int columnCount = md.getColumnCount();
            while (rs.next()) {
                Map<String, Object> rowData = new HashMap<String, Object>();
                // JDBC columns are 1-based
                for (int i = 1; i <= columnCount; i++) {
                    rowData.put(md.getColumnName(i), rs.getObject(i));
                }
                list.add(rowData);
            }
        } catch (Exception se) {
            System.out.println("数据库连接失败!");
            se.printStackTrace();
        }
        return list;
    }
}
HTTP协议工具,支持http和websocket协议
package com.cc.gw.util;
import java.nio.charset.StandardCharsets;
import io.netty.util.internal.StringUtil;
/**
 * Helpers that extract routing information from the first bytes of an HTTP/1.x request.
 * WebSocket handshakes are recognised and tagged with a synthetic {@code ws://1.1} prefix.
 */
public class HttpProtocolUtil {
    private static final int MIN_REQUEST_LINE_LENGTH = 14; // shortest valid request line, e.g. "GET / HTTP/1.1"
    /** Marker prepended to the request target of WebSocket upgrade requests. */
    private static final String WS_PREFIX = "ws://1.1";

    /**
     * Extracts the request target from raw request bytes so it can be matched against the
     * backend server list.
     *
     * <p>Only {@code GET} and {@code POST} requests with an {@code HTTP/1.x} version are accepted.
     * When the target starts with {@code /} and the header section carries a {@code Connection}
     * header containing the {@code upgrade} token, the target is returned prefixed with
     * {@code ws://1.1}.
     *
     * @param bytes raw bytes starting at the request line
     * @return the request target (possibly {@code ws://1.1}-prefixed), or {@code null} when the
     *         bytes do not start with an acceptable request line
     */
    public static String getRequestURI(byte[] bytes) {
        if (null == bytes || bytes.length < MIN_REQUEST_LINE_LENGTH) {
            return null;
        }
        String requestStr = new String(bytes, StandardCharsets.UTF_8);
        String[] lines = requestStr.split("\r\n");
        // Request line: METHOD SP TARGET SP VERSION
        String[] split = lines[0].split(" ");
        if (split.length != 3) {
            return null;
        }
        String method = split[0];
        if (!("GET".equals(method) || "POST".equals(method))) {
            return null;
        }
        String version = split[2];
        if (!version.startsWith("HTTP/1.")) {
            return null;
        }
        String uri = split[1];
        if (uri.startsWith("/") && hasConnectionUpgrade(lines)) {
            return WS_PREFIX + uri;
        }
        return uri;
    }

    /**
     * Tells whether the header section contains a {@code Connection} header listing the
     * {@code upgrade} token. Header names and tokens are compared case-insensitively and the
     * value may hold several comma-separated tokens (e.g. {@code keep-alive, Upgrade}).
     * Scanning stops at the blank line that ends the headers, so body text is never inspected.
     */
    private static boolean hasConnectionUpgrade(String[] lines) {
        for (int i = 1; i < lines.length; i++) {
            String line = lines[i];
            if (line.isEmpty()) {
                break; // end of header section
            }
            int colon = line.indexOf(':');
            if (colon <= 0) {
                continue;
            }
            if (!"connection".equalsIgnoreCase(line.substring(0, colon).trim())) {
                continue;
            }
            for (String token : line.substring(colon + 1).split(",")) {
                if ("upgrade".equalsIgnoreCase(token.trim())) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Extracts the path part of a request target (query string removed, repeated slashes
     * collapsed) so it can be matched against the backend server list.
     *
     * @param path the whole request target as returned by {@link #getRequestURI(byte[])}
     * @return the normalised path, or {@code null} when {@code path} is empty, is neither
     *         origin-form ({@code /...}) nor {@code ws}-prefixed, or is a malformed
     *         {@code ws} target without a path
     */
    public static String getContextPath(String path) {
        if (null == path || path.isEmpty()) {
            return null;
        }
        int queryStart = path.indexOf('?');
        int endLength = queryStart >= 0 ? queryStart : path.length();
        String uri = null;
        if (path.startsWith("ws")) {
            // Skip the scheme/authority part; the path begins at the first '/' from index 7 on.
            int pathStart = path.indexOf("/", 7);
            if (pathStart < 0 || pathStart > endLength) {
                // Malformed target such as "wsx": reject instead of throwing from substring().
                return null;
            }
            uri = path.substring(pathStart, endLength);
        } else if (path.startsWith("/")) {
            uri = path.substring(0, endLength);
        }
        if (null != uri) {
            return uri.replaceAll("//+", "/");
        }
        return null;
    }
}
配置类
服务实体信息
package com.cc.gw.domain;
/**
* 服务信息
*/
/**
 * Mutable bean describing one backend (real) server row.
 */
public class RealServerEntity {
    /** Server ID. */
    private String id;
    /** Server name. */
    private String name;
    /** Server code (identifier). */
    private String code;
    /** Server IP. */
    private String ip;
    /** Server port. */
    private Integer port;
    /** Server IP type (ipv4, ipv6). */
    private Integer ipType;
    /** Server weight. */
    private Double weight;
    /** Gray-release information. */
    private String gray;
    /** Server status (running, offline). */
    private Integer status;
    /** Operation status (enabled, disabled). */
    private Integer able;
    /** Sort order. */
    private Integer sno;

    public String getId() {
        return this.id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public String getName() {
        return this.name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public String getCode() {
        return this.code;
    }

    public void setCode(String code) {
        this.code = code;
    }

    public String getIp() {
        return this.ip;
    }

    public void setIp(String ip) {
        this.ip = ip;
    }

    public Integer getPort() {
        return this.port;
    }

    public void setPort(Integer port) {
        this.port = port;
    }

    public Integer getIpType() {
        return this.ipType;
    }

    public void setIpType(Integer ipType) {
        this.ipType = ipType;
    }

    public Double getWeight() {
        return this.weight;
    }

    public void setWeight(Double weight) {
        this.weight = weight;
    }

    public String getGray() {
        return this.gray;
    }

    public void setGray(String gray) {
        this.gray = gray;
    }

    public Integer getStatus() {
        return this.status;
    }

    public void setStatus(Integer status) {
        this.status = status;
    }

    public Integer getAble() {
        return this.able;
    }

    public void setAble(Integer able) {
        this.able = able;
    }

    public Integer getSno() {
        return this.sno;
    }

    public void setSno(Integer sno) {
        this.sno = sno;
    }

    /** Renders all fields in declaration order; unset fields print as {@code null}. */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("RealServerEntity [");
        text.append("id=").append(this.id);
        text.append(", name=").append(this.name);
        text.append(", code=").append(this.code);
        text.append(", ip=").append(this.ip);
        text.append(", port=").append(this.port);
        text.append(", ipType=").append(this.ipType);
        text.append(", weight=").append(this.weight);
        text.append(", gray=").append(this.gray);
        text.append(", status=").append(this.status);
        text.append(", able=").append(this.able);
        text.append(", sno=").append(this.sno);
        return text.append("]").toString();
    }
}
常量
package com.cc.gw.config;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import com.cc.gw.domain.RealServerEntity;
import io.netty.channel.Channel;
import io.netty.util.AttributeKey;
/**
 * Shared constants and global state of the gateway.
 */
public class Constant {
/** TCP port the reverse proxy binds to (used by ReverseProxySocket.startServer). */
public static final int SERVER_PORT = 8888;
/** Enabled backend servers grouped by server code; repopulated by RealServerServiceImpl.getAllServers(). */
public static final Map<String, List<RealServerEntity>> ALL_SERVERS = new ConcurrentHashMap<>();
/** Channel attribute that binds a visitor channel to its backend (real server) channel. */
public static final AttributeKey<Channel> C = AttributeKey.newInstance("c");
/** Channel attribute holding the protocol type of a visitor channel; the handler stores "ws" for WebSocket connections. */
public static final AttributeKey<String> T = AttributeKey.newInstance("t");
}
服务配置接口
package com.cc.gw.service;
import java.util.List;
import com.cc.gw.domain.RealServerEntity;
/**
 * Lookup service for the backend (real) servers the gateway can forward to.
 */
public interface IRealServerService {
/**
 * Returns all available backend servers.
 *
 * @return the server list; the implementation in this project returns {@code null} when
 *         nothing could be loaded, so callers should null-check
 */
List<RealServerEntity> getAllServers();
/**
 * Looks up a backend server by its server code.
 *
 * @param serverCode the server code (the {@code code} field, not the display name)
 * @return the matching server, or {@code null} when none matches
 */
RealServerEntity getServer(String serverCode);
/**
 * Looks up a backend server by its server code and request metadata.
 *
 * @param serverCode the server code (the {@code code} field, not the display name)
 * @param meta       metadata intended for selection strategies (e.g. gray release)
 * @return the matching server, or {@code null} when none matches
 */
RealServerEntity getServer(String serverCode, String meta);
}
服务配置接口实现
package com.cc.gw.service.impl;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import com.cc.gw.config.Constant;
import com.cc.gw.domain.RealServerEntity;
import com.cc.gw.service.IRealServerService;
import com.cc.gw.util.SQLUtil;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.lang.TypeReference;
import cn.hutool.json.JSONUtil;
/**
 * {@link IRealServerService} backed by the MySQL {@code server} table.
 *
 * <p>Every lookup re-queries the database and refreshes {@link Constant#ALL_SERVERS}.
 */
public class RealServerServiceImpl implements IRealServerService {

    /**
     * Loads all enabled servers (able = 1, ordered by sno) and publishes them, grouped by
     * server code, into {@link Constant#ALL_SERVERS}.
     *
     * <p>When the query yields no rows the cached map is left untouched.
     *
     * @return the loaded servers, or {@code null} when the query returned no rows
     */
    @Override
    public List<RealServerEntity> getAllServers() {
        String sqlStr = "SELECT s.id, s.`name`, s.`code`, s.ip, s.`port`, s.ip_type, s.weight, s.`status`, s.able, s.gray, s.sno FROM `server` AS s WHERE s.able = 1 ORDER BY s.sno";
        List<Map<String, Object>> query = SQLUtil.query(sqlStr);
        if (query == null || query.isEmpty()) {
            return null;
        }
        List<RealServerEntity> servers = JSONUtil.toBean(JSONUtil.toJsonStr(query), new TypeReference<List<RealServerEntity>>() {
        }, true);
        if (CollectionUtil.isNotEmpty(servers)) {
            // `code` is nullable in the table and groupingBy rejects null keys, so rows without
            // a code are skipped here instead of failing the whole refresh.
            Map<String, List<RealServerEntity>> collect = servers.stream()
                    .filter(o -> o.getCode() != null)
                    .collect(Collectors.groupingBy(RealServerEntity::getCode));
            synchronized (Constant.ALL_SERVERS) {
                // Add/replace first, then drop stale codes: readers that do not take this lock
                // never observe an empty map in the middle of a refresh.
                Constant.ALL_SERVERS.putAll(collect);
                Constant.ALL_SERVERS.keySet().retainAll(collect.keySet());
            }
        }
        return servers;
    }

    /**
     * Returns the first enabled server whose code equals {@code serverCode}.
     *
     * @param serverCode the server code to look up
     * @return the matching server, or {@code null} when none matches or {@code serverCode} is null
     */
    @Override
    public RealServerEntity getServer(String serverCode) {
        return getServer(serverCode, null);
    }

    /**
     * Returns the first enabled server whose code equals {@code serverCode}.
     * {@code meta} is currently not used for selection.
     *
     * @param serverCode the server code to look up
     * @param meta       request metadata, reserved for future selection strategies
     * @return the matching server, or {@code null} when none matches or {@code serverCode} is null
     */
    @Override
    public RealServerEntity getServer(String serverCode, String meta) {
        List<RealServerEntity> servers = getAllServers();
        if (serverCode == null || !CollectionUtil.isNotEmpty(servers)) {
            return null;
        }
        // TODO load balancing, gray release and similar strategies
        return servers.stream().filter(o -> serverCode.equals(o.getCode())).findFirst().orElse(null);
    }
}
定时器任务,定时查询数据库获取最新的配置信息
package com.cc.gw.config;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import com.cc.gw.service.IRealServerService;
import com.cc.gw.service.impl.RealServerServiceImpl;
import io.netty.channel.Channel;
/**
 * Background jobs of the gateway.
 */
public class ScheduledTasks {
    /** Single-threaded scheduler that runs the server-list refresh. */
    private static final ScheduledExecutorService REFRESH_EXECUTOR = Executors.newSingleThreadScheduledExecutor();

    /**
     * Schedules a reload of the server list: first run after 3 seconds, then every 5 seconds.
     * Exceptions thrown by a run are printed and swallowed so later runs still execute.
     *
     * @return always {@code null}
     */
    public static Channel refreshServerList() throws Exception {
        Runnable reload = () -> {
            try {
                IRealServerService service = new RealServerServiceImpl();
                service.getAllServers();
            } catch (Exception e) {
                e.printStackTrace();
            }
        };
        REFRESH_EXECUTOR.scheduleAtFixedRate(reload, 3, 5, TimeUnit.SECONDS);
        return null;
    }
}
代理服务类
真实服务处理类
package com.cc.gw.socket;
import com.cc.gw.config.Constant;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
/**
 * Handler installed on the connection to a backend (real) server: relays everything the
 * backend sends to the visitor channel it was created for.
 */
public class RealServerHandler extends SimpleChannelInboundHandler<ByteBuf> {
    /** Visitor-side channel that receives the backend's data. */
    public Channel proxyChannel;

    /**
     * @param proxyChannel the visitor channel to relay backend data to
     */
    public RealServerHandler(Channel proxyChannel) {
        this.proxyChannel = proxyChannel;
    }

    /**
     * Forwards the received bytes to the visitor channel.
     */
    @Override
    public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) {
        // SimpleChannelInboundHandler releases buf once this method returns; retain() hands an
        // extra reference to the write, which releases it when done. No intermediate copy needed.
        proxyChannel.writeAndFlush(buf.retain());
    }

    /**
     * Unbinds this backend channel from the visitor channel when the backend disconnects.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        super.channelInactive(ctx);
        // The visitor channel may already point at a newer backend connection (a new one is
        // opened per HTTP request); only clear the binding if it still refers to this channel.
        proxyChannel.attr(Constant.C).compareAndSet(ctx.channel(), null);
    }
}
反向代理处理类,实现整个反向代理的主要功能
package com.cc.gw.socket;
import java.util.List;
import java.util.stream.Collectors;
import com.cc.gw.config.Constant;
import com.cc.gw.domain.RealServerEntity;
import com.cc.gw.service.IRealServerService;
import com.cc.gw.service.impl.RealServerServiceImpl;
import com.cc.gw.util.HttpProtocolUtil;
import cn.hutool.core.collection.CollectionUtil;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.internal.StringUtil;
/**
 * Visitor-side handler of the reverse proxy. It parses the first bytes of each request,
 * selects a backend by matching the request path against the configured server codes, opens a
 * connection to that backend and forwards the request bytes to it.
 *
 * <p>WebSocket connections keep one backend connection for their whole lifetime; plain HTTP
 * requests open a new backend connection per request. Data that does not start with a
 * parseable GET/POST request line closes the visitor connection (unless the connection has
 * already been upgraded to WebSocket).
 */
public class ReverseProxyHandler extends SimpleChannelInboundHandler<ByteBuf> {
    /** Event loop group shared by all outbound (backend) connections. */
    static EventLoopGroup eventLoopGroup = new NioEventLoopGroup();

    @Override
    public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception {
        byte[] bytes = new byte[buf.readableBytes()];
        buf.readBytes(bytes);
        Channel realChannel = ctx.channel().attr(Constant.C).get();
        String protocol = ctx.channel().attr(Constant.T).get();
        // Established WebSocket tunnel: frames are relayed without any parsing.
        if (null != realChannel && realChannel.isActive() && "ws".equals(protocol)) {
            sendMsg(ctx, bytes, realChannel);
            return;
        }
        String uri = HttpProtocolUtil.getRequestURI(bytes);
        if (StringUtil.isNullOrEmpty(uri)) {
            ctx.close();
            return;
        }
        String contextPath = HttpProtocolUtil.getContextPath(uri);
        if (StringUtil.isNullOrEmpty(contextPath)) {
            ctx.close();
            return;
        }
        checkServers();
        String serverCode = findServerCode(contextPath);
        if (StringUtil.isNullOrEmpty(serverCode)) {
            ctx.close();
            return;
        }
        if (uri.startsWith("ws")) {
            // Long-lived connection
            if (null == realChannel || !realChannel.isActive()) {
                createSocket(ctx, bytes, serverCode);
                ctx.channel().attr(Constant.T).set("ws");
            } else {
                sendMsg(ctx, bytes, realChannel);
            }
        } else {
            // Short-lived connection
            // TODO reuse the backend connection for a short period
            createSocket(ctx, bytes, serverCode);
        }
    }

    /**
     * Returns the longest configured server code that matches {@code contextPath} on a path
     * segment boundary, or {@code null} when none matches.
     */
    private static String findServerCode(String contextPath) {
        String best = null;
        for (String code : Constant.ALL_SERVERS.keySet()) {
            if (matchesPrefix(contextPath, code) && (best == null || code.length() > best.length())) {
                best = code;
            }
        }
        return best;
    }

    /**
     * Tells whether {@code code} is a path prefix of {@code contextPath}: "/im" matches "/im"
     * and "/im/x" but not "/image"; a code ending in "/" (including "/") matches everything
     * below it.
     */
    private static boolean matchesPrefix(String contextPath, String code) {
        if (code.isEmpty() || !contextPath.startsWith(code)) {
            return false;
        }
        if (code.endsWith("/") || contextPath.length() == code.length()) {
            return true;
        }
        return contextPath.charAt(code.length()) == '/';
    }

    /**
     * Opens a connection to the first server registered under {@code serverCode}, binds it to
     * the visitor channel and sends {@code bytes} once connected. The visitor channel is closed
     * when no usable server exists or the connection cannot be established.
     *
     * @param ctx        the visitor's channel context
     * @param bytes      the data to forward to the backend
     * @param serverCode the code of the backend server
     */
    private void createSocket(ChannelHandlerContext ctx, byte[] bytes, String serverCode) {
        try {
            // The map is refreshed concurrently, so the code may have disappeared by now.
            List<RealServerEntity> list = Constant.ALL_SERVERS.get(serverCode);
            if (list == null || list.isEmpty()) {
                ctx.close();
                return;
            }
            RealServerEntity server = list.get(0);
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(eventLoopGroup).channel(NioSocketChannel.class).handler(new ChannelInitializer<SocketChannel>() {
                @Override
                public void initChannel(SocketChannel ch) throws Exception {
                    ChannelPipeline pipeline = ch.pipeline();
                    pipeline.addLast(new RealServerHandler(ctx.channel()));
                }
            });
            bootstrap.connect(server.getIp(), server.getPort()).addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) throws Exception {
                    if (!future.isSuccess()) {
                        // Backend unreachable: close the visitor instead of leaving it waiting.
                        if (null != future.cause()) {
                            future.cause().printStackTrace();
                        }
                        ctx.close();
                        return;
                    }
                    Channel channel = future.channel();
                    if (!ctx.channel().isActive()) {
                        // Visitor left while we were connecting; nobody would close this channel.
                        channel.close();
                        return;
                    }
                    ctx.channel().attr(Constant.C).set(channel);
                    sendMsg(ctx, bytes, channel);
                }
            });
        } catch (Exception e) {
            // e.g. a server row with a null ip/port
            e.printStackTrace();
            ctx.close();
        }
    }

    /**
     * Loads the server list if it has not been initialised yet.
     */
    private void checkServers() {
        if (Constant.ALL_SERVERS.isEmpty()) {
            IRealServerService realServerService = new RealServerServiceImpl();
            realServerService.getAllServers();
        }
    }

    /** Copies {@code bytes} into a new buffer and writes it to {@code channel}. */
    private void sendMsg(ChannelHandlerContext ctx, byte[] bytes, Channel channel) {
        ByteBuf byteBuf = ctx.alloc().buffer(bytes.length);
        byteBuf.writeBytes(bytes);
        channel.writeAndFlush(byteBuf);
    }

    /**
     * Closes the bound backend channel, if any, when the visitor disconnects.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        super.channelInactive(ctx);
        try {
            Channel channel = ctx.channel().attr(Constant.C).get();
            if (null != channel && channel.isActive()) {
                channel.close();
            }
            ctx.channel().attr(Constant.C).set(null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
反向代理服务类
package com.cc.gw.socket;
import com.cc.gw.config.Constant;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
/**
 * Bootstraps the listening side of the reverse proxy.
 */
public class ReverseProxySocket {

    /**
     * Binds the proxy to {@link Constant#SERVER_PORT} and blocks until the server channel is
     * closed; both event loop groups are shut down afterwards.
     *
     * @throws Exception if binding or waiting fails
     */
    public static void startServer() throws Exception {
        EventLoopGroup acceptGroup = new NioEventLoopGroup();
        EventLoopGroup ioGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap server = new ServerBootstrap();
            server.group(acceptGroup, ioGroup);
            server.channel(NioServerSocketChannel.class);
            // Every accepted visitor connection gets its own ReverseProxyHandler.
            server.childHandler(new ChannelInitializer<SocketChannel>() {
                @Override
                public void initChannel(SocketChannel ch) throws Exception {
                    ch.pipeline().addLast(new ReverseProxyHandler());
                }
            });
            ChannelFuture bound = server.bind(Constant.SERVER_PORT).sync();
            bound.channel().closeFuture().sync();
        } finally {
            ioGroup.shutdownGracefully();
            acceptGroup.shutdownGracefully();
        }
    }
}
启动类
package com.cc.gw;
import com.cc.gw.config.ScheduledTasks;
import com.cc.gw.socket.ReverseProxySocket;
/**
 * Entry point of the gateway.
 */
public class MainApp {
/**
 * Schedules the periodic server-list refresh, then starts the reverse proxy.
 * startServer() waits on the server channel's close future, so this call does not return
 * while the proxy is running.
 */
public static void main(String[] args) throws Exception {
ScheduledTasks.refreshServerList();
ReverseProxySocket.startServer();
}
}
使用
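启动服务端(执行 mvn package 后,target 目录下会生成 cc-gateway-1.0-SNAPSHOT-jar-with-dependencies.jar,下面的命令假设已将其重命名为 cc-gateway.jar)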
java -jar cc-gateway.jar