JAVA+Netty简单实现Nginx反向代理网关功能【设计实践】

背景

本文实现一个类似于nginx或gateway的反向代理网关,实现思路是访客通过网络请求反向代理服务,代理服务连接到真实服务,维护访客和真实服务的数据交互。

这个实现和之前的内网穿透项目思路相似,区别在于:内网穿透是由客户端主动与代理服务建立并维护连接,而本项目是由代理服务主动连接真实服务。

本文使用MySQL存储真实服务的配置,目前只实现了IP、端口和请求路径这几项配置。

实践

项目结构

1、cc-gateway项目:实现反向代理功能

原理分析

反向代理的实现过程主要分两步

1、启动服务端,这时代理服务监听8888端口(默认8888)

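2、访客通过访问代理服务8888端口(例如http://127.0.0.1:8888/sso),代理服务接收到请求后解析请求路径得到(/sso),根据这个路径匹配服务配置,如果匹配到(/sso)对应的真实服务的IP和端口,那么代理服务会发起与真实服务的连接,并建立访客和真实服务的数据传输通道。注意:服务配置并不是每次请求都查询数据库,而是由定时任务每5秒从数据库刷新到内存中;匹配时按前缀比较,如果有多个配置的路径都是请求路径的前缀,取最长的一个,例如请求/sso/login会匹配/sso而不是/。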

这两步最终形成了(访客-代理-真实服务)完整的通道。

代码实现

数据库设计

建表

-- Route table read by the gateway (RealServerServiceImpl): only rows with able = 1 are loaded,
-- ordered by sno. `code` is the request-path prefix to route (e.g. '/sso'); several rows may share
-- a code, in which case the first one by sno is the one connected to. `ip`/`port` locate the real
-- (backend) server. `weight`, `status` and `gray` are loaded but not consulted when routing yet.
-- NOTE(review): `ip_type` is declared varchar but mapped to an Integer field (RealServerEntity.ipType);
-- the sample rows store '1'.
CREATE TABLE `server` (
  `id` varchar(32) NOT NULL COMMENT '主键ID',
  `name` varchar(255) DEFAULT NULL COMMENT '服务名称',
  `code` varchar(255) DEFAULT NULL COMMENT '服务标识',
  `ip` varchar(255) DEFAULT NULL COMMENT '服务IP',
  `port` int(11) DEFAULT NULL COMMENT '服务端口',
  `ip_type` varchar(255) DEFAULT NULL COMMENT '服务IP类型(ipv4,ipv6)',
  `weight` double DEFAULT NULL COMMENT '服务权重',
  `status` int(1) DEFAULT NULL COMMENT '服务状态(运行中、掉线)',
  `able` int(1) DEFAULT NULL COMMENT '操作状态(启用、禁用)',
  `gray` varchar(255) DEFAULT NULL COMMENT '灰度信息',
  `sno` int(11) DEFAULT NULL COMMENT '排序',
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='服务信息表';

插入示例数据

-- Sample data: replace ip/port with the addresses of your own real servers.
-- code '/' is a prefix of every path, so it acts as the fallback route; longer codes such as '/sso'
-- take precedence over it because the gateway picks the longest matching prefix.
INSERT INTO `server`(`id`, `name`, `code`, `ip`, `port`, `ip_type`, `weight`, `status`, `able`, `gray`, `sno`) VALUES ('1', 'im', '/im', '10.0.0.2', 8889, '1', 1, 1, 1, '', 1);
INSERT INTO `server`(`id`, `name`, `code`, `ip`, `port`, `ip_type`, `weight`, `status`, `able`, `gray`, `sno`) VALUES ('2', '/', '/', '10.0.0.2', 8886, '1', 1, 1, 1, '', 1);
INSERT INTO `server`(`id`, `name`, `code`, `ip`, `port`, `ip_type`, `weight`, `status`, `able`, `gray`, `sno`) VALUES ('3', 'sso', '/sso', '10.0.0.3', 8885, '1', 1, 1, 1, '', 1);

cc-gateway项目

pom文件
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.cc</groupId>
    <artifactId>cc-gateway</artifactId>
    <version>1.0-SNAPSHOT</version>
    <name>cc-gateway</name>
    <url>http://maven.apache.org</url>

    <properties>
        <java.home>${env.JAVA_HOME}</java.home>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <java.version>1.8</java.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-all</artifactId>
            <version>4.1.74.Final</version>
        </dependency>
        <!-- MySQL JDBC driver. 5.1.49 is the last 5.1.x release: same driver class and URL options as
             5.1.38, plus later fixes (e.g. CVE-2017-3523 affects 5.1.40 and earlier). -->
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>5.1.49</version>
        </dependency>
        <dependency>
            <groupId>cn.hutool</groupId>
            <artifactId>hutool-all</artifactId>
            <version>5.7.8</version>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <artifactId>maven-compiler-plugin</artifactId>
                <!-- Pinned: without an explicit version Maven warns and the result depends on the Maven release used. -->
                <version>3.8.1</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
            <plugin>
                <artifactId>maven-assembly-plugin</artifactId>
                <version>3.0.0</version>
                <configuration>
                    <archive>
                        <manifest>
                            <mainClass>com.cc.gw.MainApp</mainClass>
                        </manifest>
                    </archive>
                    <!-- Produces target/cc-gateway-1.0-SNAPSHOT-jar-with-dependencies.jar on `mvn package`. -->
                    <descriptorRefs>
                        <descriptorRef>jar-with-dependencies</descriptorRef>
                    </descriptorRefs>
                </configuration>
                <executions>
                    <execution>
                        <id>make-assembly</id>
                        <phase>package</phase>
                        <goals>
                            <goal>single</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>
工具类

mysql工具,主要查询配置

package com.cc.gw.util;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SQLUtil {
	// NOTE(review): credentials should not live in source. They can now be overridden with the JVM
	// system properties gw.db.url / gw.db.username / gw.db.password; the previous values stay the defaults.
	private static String url = System.getProperty("gw.db.url", "jdbc:mysql://127.0.0.1:3306/serverdb?allowMultiQueries=true&useUnicode=true&characterEncoding=utf8&useSSL=false");
	private static String username = System.getProperty("gw.db.username", "root");
	private static String password = System.getProperty("gw.db.password", "123456");

	/**
	 * Runs a query and returns every row as a map from column label to value.
	 *
	 * <p>A new connection is opened and closed for each call.
	 *
	 * @param sqlStr complete SQL statement; it is executed verbatim, so it must never contain
	 *               untrusted input (the default URL also enables allowMultiQueries)
	 * @return one map per row in result-set order; an empty list when there are no rows or the
	 *         query fails (failures are printed, not thrown); rows read before a mid-iteration
	 *         failure are kept
	 */
	public static List<Map<String, Object>> query(String sqlStr) {
		List<Map<String, Object>> list = new ArrayList<Map<String, Object>>();
		// try-with-resources closes rs, stmt and con (in that order) on every path, including when
		// executeQuery or the iteration throws; previously they were only closed on success.
		try (Connection con = DriverManager.getConnection(url, username, password);
				Statement stmt = con.createStatement();
				ResultSet rs = stmt.executeQuery(sqlStr)) {
			ResultSetMetaData md = rs.getMetaData();
			int columnCount = md.getColumnCount();
			while (rs.next()) {
				Map<String, Object> rowData = new HashMap<String, Object>();
				for (int i = 1; i <= columnCount; i++) {
					// getColumnLabel honours "AS alias"; for unaliased columns it equals the column name.
					rowData.put(md.getColumnLabel(i), rs.getObject(i));
				}
				list.add(rowData);
			}
		} catch (Exception se) {
			System.out.println("数据库连接失败!");
			se.printStackTrace();
		}
		return list;
	}
}

HTTP协议工具,支持http和websocket协议

package com.cc.gw.util;

import java.nio.charset.StandardCharsets;

import io.netty.util.internal.StringUtil;

public class HttpProtocolUtil {
	// Shortest request line accepted, e.g. "GET / HTTP/1.1" (14 characters).
	private static final int MIN_REQUEST_LINE_LENGTH = 14;

	/**
	 * Extracts the request-target from the first request line of a raw HTTP/1.x request.
	 *
	 * <p>For a WebSocket handshake (origin-form target plus a {@code Connection} header containing
	 * the {@code upgrade} token) the target is returned with the marker prefix {@code "ws://1.1"},
	 * e.g. {@code "ws://1.1/im"}; {@link #getContextPath(String)} understands that form.
	 *
	 * @param bytes raw bytes read from the client; may be null
	 * @return the request-target (possibly prefixed as above), or null when the bytes do not start
	 *         with a supported HTTP/1.x request line
	 */
	public static String getRequestURI(byte[] bytes) {
		if (null == bytes || bytes.length < MIN_REQUEST_LINE_LENGTH) {
			return null;
		}
		String requestStr = new String(bytes, StandardCharsets.UTF_8);
		String[] lines = requestStr.split("\r\n");
		// split() drops trailing empty strings, so input made only of CRLFs yields an empty array.
		if (lines.length == 0) {
			return null;
		}
		String[] split = lines[0].split(" ");
		if (split.length != 3) {
			return null;
		}
		if (!isSupportedMethod(split[0])) {
			return null;
		}
		if (!split[2].startsWith("HTTP/1.")) {
			return null;
		}
		String uri = split[1];
		if (uri.startsWith("/") && isUpgradeRequest(lines)) {
			return "ws://1.1" + uri;
		}
		return uri;
	}

	/**
	 * Returns whether {@code method} is a standard HTTP method the gateway forwards
	 * (CONNECT and TRACE are deliberately excluded).
	 */
	private static boolean isSupportedMethod(String method) {
		switch (method) {
			case "GET":
			case "POST":
			case "PUT":
			case "DELETE":
			case "HEAD":
			case "OPTIONS":
			case "PATCH":
				return true;
			default:
				return false;
		}
	}

	/**
	 * Returns whether the header section (lines after the request line, up to the first empty line)
	 * has a {@code Connection} header whose comma-separated token list contains {@code upgrade}.
	 * Header names and tokens are compared case-insensitively. Browsers send e.g.
	 * "Connection: Upgrade" or "Connection: keep-alive, Upgrade". The body is never inspected.
	 */
	private static boolean isUpgradeRequest(String[] lines) {
		for (int i = 1; i < lines.length; i++) {
			String line = lines[i];
			if (line.isEmpty()) {
				break; // end of the header section
			}
			int colon = line.indexOf(':');
			if (colon <= 0 || !"connection".equalsIgnoreCase(line.substring(0, colon).trim())) {
				continue;
			}
			for (String token : line.substring(colon + 1).split(",")) {
				if ("upgrade".equalsIgnoreCase(token.trim())) {
					return true;
				}
			}
		}
		return false;
	}

	/**
	 * Extracts the path used for routing from a value returned by {@link #getRequestURI(byte[])}.
	 *
	 * <p>The query string is dropped, the {@code "ws://1.1"} marker prefix is stripped, and runs of
	 * slashes are collapsed, e.g. {@code "ws://1.1//im?x=1"} becomes {@code "/im"}.
	 *
	 * @param path request-target, possibly with the WebSocket marker prefix; may be null
	 * @return the path, or null when the input is null/empty or not of a recognised form
	 */
	public static String getContextPath(String path) {
		if (null == path || path.isEmpty()) {
			return null;
		}
		int end = path.indexOf('?');
		if (end < 0) {
			end = path.length();
		}
		int start;
		if (path.startsWith("ws")) {
			// First '/' after the 7-character "ws://1." part of the marker prefix.
			start = path.indexOf('/', 7);
		} else if (path.startsWith("/")) {
			start = 0;
		} else {
			return null;
		}
		// No path part, or a '?' before the path, previously threw StringIndexOutOfBoundsException.
		if (start < 0 || start > end) {
			return null;
		}
		return path.substring(start, end).replaceAll("//+", "/");
	}
}
配置类

服务实体信息

package com.cc.gw.domain;

/**
 * Configuration of one real (backend) server, i.e. one row of the {@code server} table.
 *
 * <p>Plain mutable bean with a getter/setter pair per column.
 */
public class RealServerEntity {

	/** Primary key. */
	private String id;
	/** Human-readable server name. */
	private String name;
	/** Server code: the request-path prefix routed to this server, e.g. "/sso". */
	private String code;
	/** Host name or IP address. */
	private String ip;
	/** TCP port. */
	private Integer port;
	/** IP type (ipv4 / ipv6 according to the schema comment). */
	private Integer ipType;
	/** Weight. */
	private Double weight;
	/** Gray-release information. */
	private String gray;
	/** Runtime status (running / offline). */
	private Integer status;
	/** Administrative switch (enabled / disabled). */
	private Integer able;
	/** Sort order. */
	private Integer sno;

	// ---- getters ----

	public String getId() {
		return this.id;
	}

	public String getName() {
		return this.name;
	}

	public String getCode() {
		return this.code;
	}

	public String getIp() {
		return this.ip;
	}

	public Integer getPort() {
		return this.port;
	}

	public Integer getIpType() {
		return this.ipType;
	}

	public Double getWeight() {
		return this.weight;
	}

	public String getGray() {
		return this.gray;
	}

	public Integer getStatus() {
		return this.status;
	}

	public Integer getAble() {
		return this.able;
	}

	public Integer getSno() {
		return this.sno;
	}

	// ---- setters ----

	public void setId(String newId) {
		this.id = newId;
	}

	public void setName(String newName) {
		this.name = newName;
	}

	public void setCode(String newCode) {
		this.code = newCode;
	}

	public void setIp(String newIp) {
		this.ip = newIp;
	}

	public void setPort(Integer newPort) {
		this.port = newPort;
	}

	public void setIpType(Integer newIpType) {
		this.ipType = newIpType;
	}

	public void setWeight(Double newWeight) {
		this.weight = newWeight;
	}

	public void setGray(String newGray) {
		this.gray = newGray;
	}

	public void setStatus(Integer newStatus) {
		this.status = newStatus;
	}

	public void setAble(Integer newAble) {
		this.able = newAble;
	}

	public void setSno(Integer newSno) {
		this.sno = newSno;
	}

	/** Debug representation listing every field; null fields print as "null". */
	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder("RealServerEntity [");
		sb.append("id=").append(id);
		sb.append(", name=").append(name);
		sb.append(", code=").append(code);
		sb.append(", ip=").append(ip);
		sb.append(", port=").append(port);
		sb.append(", ipType=").append(ipType);
		sb.append(", weight=").append(weight);
		sb.append(", gray=").append(gray);
		sb.append(", status=").append(status);
		sb.append(", able=").append(able);
		sb.append(", sno=").append(sno);
		return sb.append(']').toString();
	}

}

常量

package com.cc.gw.config;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import com.cc.gw.domain.RealServerEntity;

import io.netty.channel.Channel;
import io.netty.util.AttributeKey;

/**
 * Process-wide constants and shared state of the gateway. Not instantiable.
 */
public final class Constant {

	/** TCP port the reverse proxy listens on. */
	public static final int SERVER_PORT = 8888;

	/**
	 * In-memory route table: server code (request-path prefix, e.g. "/sso") to the servers configured
	 * for that code. Filled by RealServerServiceImpl.getAllServers() and read by the request handlers.
	 */
	public static final Map<String, List<RealServerEntity>> ALL_SERVERS = new ConcurrentHashMap<>();

	/** Channel attribute: on a client channel, the backend (real server) channel it is bridged to. */
	public static final AttributeKey<Channel> C = AttributeKey.newInstance("c");

	/** Channel attribute: protocol marker of a client channel; set to "ws" once a WebSocket upgrade is routed. */
	public static final AttributeKey<String> T = AttributeKey.newInstance("t");

	private Constant() {
		// constants holder; never instantiated
	}
}

服务配置接口

package com.cc.gw.service;

import java.util.List;

import com.cc.gw.domain.RealServerEntity;

/**
 * Source of real (backend) server configuration for the gateway.
 */
public interface IRealServerService {

	/**
	 * Looks up one server by its code, i.e. the routed request-path prefix such as "/sso".
	 *
	 * @param serverCode server code to match
	 * @return the selected server, or null when none matches
	 */
	RealServerEntity getServer(String serverCode);

	/**
	 * Looks up one server by its code, passing extra metadata that a selection strategy
	 * (load balancing, gray release) may use.
	 *
	 * @param serverCode server code to match
	 * @param meta       selection metadata; NOTE(review): the bundled implementation ignores it so far
	 * @return the selected server, or null when none matches
	 */
	RealServerEntity getServer(String serverCode, String meta);

	/**
	 * Loads the available servers. The bundled implementation returns the enabled rows (able = 1)
	 * and also refreshes the in-memory route table.
	 *
	 * @return the servers found; callers should treat both null and an empty list as "none"
	 */
	List<RealServerEntity> getAllServers();
}

 服务配置接口实现

package com.cc.gw.service.impl;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.cc.gw.config.Constant;
import com.cc.gw.domain.RealServerEntity;
import com.cc.gw.service.IRealServerService;
import com.cc.gw.util.SQLUtil;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.lang.TypeReference;
import cn.hutool.json.JSONUtil;

public class RealServerServiceImpl implements IRealServerService {

	/** Enabled servers only (able = 1), ordered by sno so the preferred server of each code comes first. */
	private static final String ALL_SERVERS_SQL = "SELECT s.id, s.`name`, s.`code`, s.ip, s.`port`, s.ip_type, s.weight, s.`status`, s.able, s.gray, s.sno FROM `server` AS s WHERE s.able = 1 ORDER BY s.sno";

	/**
	 * Loads all enabled servers from the database and refreshes {@link Constant#ALL_SERVERS} with them.
	 *
	 * @return the enabled servers ordered by sno; an empty list (never null) when nothing was loaded
	 */
	@Override
	public List<RealServerEntity> getAllServers() {
		List<Map<String, Object>> query = SQLUtil.query(ALL_SERVERS_SQL);
		if (CollectionUtil.isEmpty(query)) {
			// SQLUtil.query returns an empty list both for "no enabled rows" and for a failed query, so
			// the two cannot be told apart here; the cached routes are left as they are.
			return new java.util.ArrayList<>();
		}
		List<RealServerEntity> servers = JSONUtil.toBean(JSONUtil.toJsonStr(query), new TypeReference<List<RealServerEntity>>() {
		}, true);
		if (CollectionUtil.isEmpty(servers)) {
			return new java.util.ArrayList<>();
		}
		refreshRouteTable(servers);
		return servers;
	}

	/**
	 * Replaces the contents of the route table with {@code servers} grouped by code
	 * (encounter order, i.e. sno order, is kept inside each group).
	 */
	private static void refreshRouteTable(List<RealServerEntity> servers) {
		// Rows with a NULL code are skipped: groupingBy rejects null keys with an NPE, which would abort
		// every refresh for as long as such a row exists.
		Map<String, List<RealServerEntity>> byCode = servers.stream().filter(o -> null != o.getCode())
				.collect(Collectors.groupingBy(RealServerEntity::getCode));
		synchronized (Constant.ALL_SERVERS) {
			// Overwrite first, then drop codes that disappeared. Unlike clear() followed by put(), this
			// never exposes an empty table to the request handlers, which read the map without this lock.
			Constant.ALL_SERVERS.putAll(byCode);
			Constant.ALL_SERVERS.keySet().retainAll(byCode.keySet());
		}
	}

	/**
	 * Reloads the server list and returns the first (lowest sno) server whose code equals {@code serverCode}.
	 *
	 * @param serverCode server code to match; null matches nothing
	 * @return the matching server, or null
	 */
	@Override
	public RealServerEntity getServer(String serverCode) {
		List<RealServerEntity> servers = getAllServers();
		if (null == serverCode || CollectionUtil.isEmpty(servers)) {
			return null;
		}
		// TODO 负载均衡、灰度等策略 (load balancing / gray release)
		return servers.stream().filter(o -> serverCode.equals(o.getCode())).findFirst().orElse(null);
	}

	/**
	 * Same as {@link #getServer(String)}; {@code meta} is reserved for future selection strategies
	 * and is currently ignored.
	 */
	@Override
	public RealServerEntity getServer(String serverCode, String meta) {
		// TODO use meta for load balancing / gray release
		return getServer(serverCode);
	}

}

定时任务:启动3秒后首次执行,之后每5秒查询一次数据库,刷新内存中的服务配置

package com.cc.gw.config;

import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

import com.cc.gw.service.IRealServerService;
import com.cc.gw.service.impl.RealServerServiceImpl;

import io.netty.channel.Channel;

public class ScheduledTasks {

	/**
	 * Single worker for the periodic refresh. Its thread is a named daemon, so it never keeps the JVM
	 * alive on its own: if the proxy server fails to start (e.g. the port is taken), the process can
	 * exit instead of silently polling the database forever.
	 */
	private static final ScheduledExecutorService refreshExecutor = Executors.newSingleThreadScheduledExecutor(r -> {
		Thread t = new Thread(r, "gw-server-list-refresh");
		t.setDaemon(true);
		return t;
	});

	/**
	 * Schedules the route-table refresh: first run after 3 seconds, then every 5 seconds.
	 * Failures of a run are printed and do not cancel later runs.
	 *
	 * @return always null (kept for signature compatibility)
	 * @throws Exception never thrown in practice; kept for signature compatibility
	 */
	public static Channel refreshServerList() throws Exception {
		// The service keeps no per-call state, so one instance serves every run.
		IRealServerService realServerService = new RealServerServiceImpl();
		refreshExecutor.scheduleAtFixedRate(() -> {
			try {
				realServerService.getAllServers();
			} catch (Exception e) {
				// An exception escaping a scheduled task would suppress all subsequent runs.
				e.printStackTrace();
			}
		}, 3, 5, TimeUnit.SECONDS);
		return null;
	}
}
代理服务类

真实服务处理类

package com.cc.gw.socket;

import com.cc.gw.config.Constant;

import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;

public class RealServerHandler extends SimpleChannelInboundHandler<ByteBuf> {

	public Channel proxyChannel;

	public RealServerHandler(Channel proxyChannel) {
		this.proxyChannel = proxyChannel;
	}

	@Override
	public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) {
		byte[] bytes = new byte[buf.readableBytes()];
		buf.readBytes(bytes);
		ByteBuf byteBuf = ctx.alloc().buffer(bytes.length);
		byteBuf.writeBytes(bytes);
		proxyChannel.writeAndFlush(byteBuf);
	}

	@Override
	public void channelInactive(ChannelHandlerContext ctx) throws Exception {
		super.channelInactive(ctx);
		proxyChannel.attr(Constant.C).set(null);
	}
}

反向代理处理类,实现整个反向代理的主要功能

package com.cc.gw.socket;

import java.util.List;
import java.util.stream.Collectors;

import com.cc.gw.config.Constant;
import com.cc.gw.domain.RealServerEntity;
import com.cc.gw.service.IRealServerService;
import com.cc.gw.service.impl.RealServerServiceImpl;
import com.cc.gw.util.HttpProtocolUtil;

import cn.hutool.core.collection.CollectionUtil;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.internal.StringUtil;

/**
 * Handler on a client connection: routes HTTP requests (and WebSocket upgrades) to the real server
 * whose code is the longest prefix of the request path, and relays client data to that server.
 */
public class ReverseProxyHandler extends SimpleChannelInboundHandler<ByteBuf> {
	/** Event loop shared by every outbound connection to a real (backend) server. */
	static EventLoopGroup eventLoopGroup = new NioEventLoopGroup();

	/**
	 * Routes one chunk of data read from the client.
	 *
	 * <p>On an upgraded (WebSocket) connection the chunk is relayed as-is. Otherwise it is expected
	 * to start an HTTP request, which is matched against the route table and sent over a new backend
	 * connection. A chunk that does not start a request is relayed to the current backend if one is
	 * attached; otherwise the client is closed.
	 */
	@Override
	public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception {
		byte[] bytes = new byte[buf.readableBytes()];
		buf.readBytes(bytes);
		Channel realChannel = ctx.channel().attr(Constant.C).get();
		boolean backendActive = null != realChannel && realChannel.isActive();
		String protocol = ctx.channel().attr(Constant.T).get();
		if (backendActive && "ws".equals(protocol)) {
			sendMsg(ctx, bytes, realChannel);
			return;
		}
		String uri = HttpProtocolUtil.getRequestURI(bytes);
		if (StringUtil.isNullOrEmpty(uri)) {
			if (backendActive) {
				// Not a request line: most likely the rest of a request whose headers/body were split
				// across several TCP reads (e.g. a large POST body). Closing here would cut it off.
				sendMsg(ctx, bytes, realChannel);
			} else {
				ctx.close();
			}
			return;
		}
		String contextPath = HttpProtocolUtil.getContextPath(uri);
		if (StringUtil.isNullOrEmpty(contextPath)) {
			ctx.close();
			return;
		}
		String serverCode = matchServerCode(contextPath);
		if (StringUtil.isNullOrEmpty(serverCode)) {
			ctx.close();
			return;
		}
		if (uri.startsWith("ws")) {
			// 长连接 (long-lived): connect once, later frames take the first branch above.
			if (backendActive) {
				sendMsg(ctx, bytes, realChannel);
			} else {
				createSocket(ctx, bytes, serverCode);
				ctx.channel().attr(Constant.T).set("ws");
			}
		} else {
			// 短连接 (short): one backend connection per request.
			// TODO 短时间内可复用链接 (reuse recent connections); the previous backend connection is
			//  currently left open until the backend closes it.
			createSocket(ctx, bytes, serverCode);
		}
	}

	/**
	 * Returns the longest configured server code that is a prefix of {@code contextPath}, or null.
	 */
	private String matchServerCode(String contextPath) {
		checkServers();
		// All candidates are prefixes of the same path, so lexicographic order equals length order and
		// the last sorted element is the longest (most specific) match.
		List<String> collect = Constant.ALL_SERVERS.keySet().stream().filter(o -> contextPath.startsWith(o)).sorted().collect(Collectors.toList());
		if (CollectionUtil.isNotEmpty(collect)) {
			return collect.get(collect.size() - 1);
		}
		return null;
	}

	/**
	 * Opens a connection to the first server configured for {@code serverCode}, bridges it to the
	 * client and sends {@code bytes}. The client is closed whenever no connection can be made.
	 *
	 * @param ctx        client channel context
	 * @param bytes      data to send once connected
	 * @param serverCode matched server code
	 */
	private void createSocket(ChannelHandlerContext ctx, byte[] bytes, String serverCode) {
		Channel inbound = ctx.channel();
		try {
			List<RealServerEntity> list = Constant.ALL_SERVERS.get(serverCode);
			RealServerEntity server = CollectionUtil.isEmpty(list) ? null : list.get(0);
			if (null == server || StringUtil.isNullOrEmpty(server.getIp()) || null == server.getPort()) {
				// Route removed by a refresh meanwhile, or misconfigured: fail fast, don't leave the client hanging.
				inbound.close();
				return;
			}
			// Stop reading from the client until the backend is connected, so no data arrives while no
			// backend is attached (the approach of Netty's proxy example). Re-enabled once connected.
			inbound.config().setAutoRead(false);
			Bootstrap bootstrap = new Bootstrap();
			bootstrap.group(eventLoopGroup).channel(NioSocketChannel.class).handler(new ChannelInitializer<SocketChannel>() {
				@Override
				public void initChannel(SocketChannel ch) throws Exception {
					ChannelPipeline pipeline = ch.pipeline();
					pipeline.addLast(new RealServerHandler(inbound));
				}

			});
			bootstrap.connect(server.getIp(), server.getPort()).addListener(new ChannelFutureListener() {
				@Override
				public void operationComplete(ChannelFuture future) throws Exception {
					if (!future.isSuccess()) {
						// Backend unreachable: close the client instead of leaving it waiting forever.
						future.cause().printStackTrace();
						inbound.close();
						return;
					}
					Channel channel = future.channel();
					if (!inbound.isActive()) {
						// The client went away while connecting; don't leak the new backend connection.
						channel.close();
						return;
					}
					inbound.attr(Constant.C).set(channel);
					sendMsg(ctx, bytes, channel);
					inbound.config().setAutoRead(true);
				}
			});
		} catch (Exception e) {
			e.printStackTrace();
			inbound.close();
		}
	}

	/**
	 * 如果还未初始化服务列表,进行初始化 — loads the route table if it is still empty.
	 * NOTE(review): this runs a blocking JDBC query on the Netty event loop.
	 */
	private void checkServers() {
		if (Constant.ALL_SERVERS.isEmpty()) {
			IRealServerService realServerService = new RealServerServiceImpl();
			realServerService.getAllServers();
		}
	}

	/** Copies {@code bytes} into a new buffer and writes it to {@code channel}. */
	private void sendMsg(ChannelHandlerContext ctx, byte[] bytes, Channel channel) {
		ByteBuf byteBuf = ctx.alloc().buffer(bytes.length);
		byteBuf.writeBytes(bytes);
		channel.writeAndFlush(byteBuf);
	}

	/** Closes the bridged backend connection when the client disconnects. */
	@Override
	public void channelInactive(ChannelHandlerContext ctx) throws Exception {
		super.channelInactive(ctx);
		try {
			Channel channel = ctx.channel().attr(Constant.C).get();
			if (null != channel && channel.isActive()) {
				channel.close();
			}
			ctx.channel().attr(Constant.C).set(null);
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	@Override
	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
		cause.printStackTrace();
		// Closing fires channelInactive above, which also closes the backend connection.
		ctx.close();
	}
}

反向代理服务类

package com.cc.gw.socket;

import com.cc.gw.config.Constant;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;

public class ReverseProxySocket {

	/**
	 * Binds the reverse proxy to {@link Constant#SERVER_PORT}, with one ReverseProxyHandler per
	 * accepted client connection, and blocks until the server channel is closed. Both event loop
	 * groups are shut down gracefully on the way out, including when binding fails.
	 *
	 * @throws Exception if binding fails or the wait is interrupted
	 */
	public static void startServer() throws Exception {
		EventLoopGroup acceptors = new NioEventLoopGroup();
		EventLoopGroup workers = new NioEventLoopGroup();
		try {
			ServerBootstrap bootstrap = new ServerBootstrap()
					.group(acceptors, workers)
					.channel(NioServerSocketChannel.class)
					.childHandler(new ChannelInitializer<SocketChannel>() {
						@Override
						public void initChannel(SocketChannel ch) throws Exception {
							ch.pipeline().addLast(new ReverseProxyHandler());
						}
					});

			ChannelFuture bound = bootstrap.bind(Constant.SERVER_PORT).sync();
			bound.channel().closeFuture().sync();
		} finally {
			workers.shutdownGracefully();
			acceptors.shutdownGracefully();
		}
	}

}
启动类
package com.cc.gw;

import com.cc.gw.config.ScheduledTasks;
import com.cc.gw.socket.ReverseProxySocket;

/**
 * Gateway entry point (configured as mainClass of the assembly plugin in the pom).
 */
public class MainApp {

	/**
	 * Schedules the periodic route-table refresh, then runs the reverse proxy.
	 *
	 * @param args unused
	 * @throws Exception if the proxy server fails to start
	 */
	public static void main(String[] args) throws Exception {
		// Must come first: startServer() blocks until the server channel closes.
		ScheduledTasks.refreshServerList();
		ReverseProxySocket.startServer();
	}
}

使用

启动服务端
# `mvn package` builds target/cc-gateway-1.0-SNAPSHOT-jar-with-dependencies.jar (maven-assembly-plugin, jar-with-dependencies)
java -jar target/cc-gateway-1.0-SNAPSHOT-jar-with-dependencies.jar
  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值