前言
基于 Servlet Filter 的简易 API 网关实现：跨域处理、请求体大小限制、Token 鉴权、路由转发以及轮询负载均衡。
Filters
CorsFilter
/**
 * Adds CORS headers to every response, then continues the filter chain.
 *
 * <p>Browsers reject a credentialed response that combines
 * {@code Access-Control-Allow-Origin: *} with {@code Access-Control-Allow-Credentials: true}.
 * To keep the "any origin, with credentials" intent, the request's {@code Origin} is echoed
 * back instead of {@code *}, and {@code Vary: Origin} is added so shared caches do not serve
 * one origin's response to another. Requests without an {@code Origin} header (non-browser
 * clients) still get {@code *}.
 */
@Order(0)
@Component
@WebFilter(filterName = "CorsFilter", urlPatterns = {"/*"})
public class CorsFilter implements Filter {
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        String origin = request.getHeader("Origin");
        boolean hasOrigin = origin != null && !origin.isEmpty();
        // NOTE(review): echoing every origin trusts all sites, as the original "*" intended;
        // consider restricting this to an allow-list of known front-end origins.
        response.setHeader("Access-Control-Allow-Origin", hasOrigin ? origin : "*");
        if (hasOrigin) {
            // The header value now depends on the request, so caches must key on Origin.
            response.addHeader("Vary", "Origin");
        }
        response.setHeader("Access-Control-Allow-Credentials", "true");
        // With credentials, "*" is a literal method name rather than a wildcard, so list them.
        response.setHeader("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,HEAD,OPTIONS");
        response.setHeader("Access-Control-Allow-Headers", "Content-Type,Authorization");
        // NOTE(review): "*" here is likewise literal for credentialed requests; list concrete
        // header names if the front end must read any non-safelisted response header.
        response.setHeader("Access-Control-Expose-Headers", "*");

        filterChain.doFilter(request, response);
    }
}
ResponseFilter
/**
 * Writes the final response body.
 *
 * <p>OPTIONS (preflight) requests are answered immediately with 200 and no body. Requests
 * whose declared body exceeds {@link #MAX_CONTENT_LENGTH} are rejected with a 400 JSON body.
 * For all other requests the rest of the chain runs first. Afterwards the byte array that
 * downstream filters stored in the {@code "bytes"} request attribute (if any) is written as
 * the response body.
 */
@Order(1)
@Component
@WebFilter(filterName = "ResponseFilter", urlPatterns = {"/*"})
public class ResponseFilter implements Filter {
    /** Largest accepted request body in bytes (8 MB). */
    private static final long MAX_CONTENT_LENGTH = 8L * 1024 * 1024;

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        if (Objects.equals(request.getMethod(), HttpMethod.OPTIONS.name())) {
            response.setStatus(200);
            return;
        }

        // NOTE(review): getContentLengthLong() returns -1 when the length is unknown
        // (e.g. chunked uploads), so such bodies are not caught by this limit.
        if (request.getContentLengthLong() > MAX_CONTENT_LENGTH) {
            ResponseData responseData = new ResponseData(400, "数据过大");
            request.setAttribute("bytes", JSON.toJSONBytes(responseData));
            response.setContentType(MediaType.APPLICATION_JSON_VALUE);
            // Send the same status the body reports instead of the default 200.
            response.setStatus(400);
        } else {
            filterChain.doFilter(request, response);
        }

        byte[] body = (byte[]) request.getAttribute("bytes");
        if (body == null) {
            // No downstream filter produced a body; writing null would throw a NullPointerException.
            return;
        }
        response.setContentLength(body.length);
        try (OutputStream out = response.getOutputStream()) {
            out.write(body);
            out.flush();
        }
    }
}
AuthorizationFilter
/**
 * Verifies that requests carry a valid token in the {@code Authorization} header.
 *
 * <p>URIs containing one of {@link #PUBLIC_PATH_FRAGMENTS} are exempt. When the token is
 * missing or rejected by {@code JwtTokenUtil.validateToken}, the chain is not continued:
 * a 401 {@code ResponseData} JSON body is stored in the {@code "bytes"} request attribute
 * and the response status is set to 401.
 */
@Order(2)
@Component
@WebFilter(filterName = "AuthorizationFilter", urlPatterns = {"/*"})
public class AuthorizationFilter implements Filter {
    /** URI fragments that mark a request as public (matched anywhere in the request URI). */
    private static final String[] PUBLIC_PATH_FRAGMENTS = {
            "/login_and_register/login",
            "/login_and_register/register",
            "/login_and_register/code",
            "/file/pic"
    };

    @Autowired
    private JwtTokenUtil jwtTokenUtil;

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;
        if (!isPublic(request.getRequestURI())) {
            // Validate the token.
            String authorization = request.getHeader("Authorization");
            if (authorization == null || !jwtTokenUtil.validateToken(authorization)) {
                ResponseData responseData = new ResponseData(401, "未登录");
                byte[] bytes = JSON.toJSONBytes(responseData);
                request.setAttribute("bytes", bytes);
                response.setContentType("application/json");
                response.setStatus(401);
                return;
            }
        }
        filterChain.doFilter(request, response);
    }

    /**
     * Returns whether {@code uri} targets a public endpoint.
     *
     * <p>{@code getRequestURI()} is neither decoded nor normalised, and the path is forwarded
     * to the backend as-is. Servlet containers such as Tomcat strip {@code ;} path parameters
     * and resolve {@code ..} segments before dispatching. So
     * {@code /svc/admin;/login_and_register/login} or
     * {@code /svc/login_and_register/login/../../admin} would match a public fragment here
     * but reach a protected endpoint upstream. Such ambiguous URIs are therefore never treated
     * as public; they simply require a token.
     *
     * @param uri raw request URI
     * @return {@code true} if no token is required
     */
    private static boolean isPublic(String uri) {
        if (uri == null || isAmbiguous(uri)) {
            return false;
        }
        for (String fragment : PUBLIC_PATH_FRAGMENTS) {
            if (uri.contains(fragment)) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether {@code uri} contains path parameters or (encoded) dot segments. */
    private static boolean isAmbiguous(String uri) {
        return uri.indexOf(';') >= 0
                || uri.contains("..")
                || uri.contains("%2e")
                || uri.contains("%2E");
    }
}
RoutingFilter
/**
* 路由转发
*/
@Order(3)
@Component
@WebFilter(filterName = "RoutingFilter", urlPatterns = {"/*"})
public class RoutingFilter implements Filter {
@Autowired
private Router router;
@Autowired
private RestTemplate restTemplate;
@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) servletRequest;
HttpServletResponse response = (HttpServletResponse) servletResponse;
ResponseData responseData = null;
String uri = request.getRequestURI();
Service service = getService(router.getServices(), uri);
if (service == null) {
responseData = new ResponseData(404, "不存在的服务");
response.setStatus(404);
} else {
String url = buildUrl(request, service.getNextIp(), service.getPath());
try {
RequestEntity<byte[]> requestEntity = buildRequestEntity(request, url);
ResponseEntity<byte[]> exchange = restTemplate.exchange(requestEntity, byte[].class);
byte[] bytes = exchange.getBody();
if (bytes != null) {
request.setAttribute("bytes", bytes);
response.setContentType(exchange.getHeaders().getContentType() != null
? exchange.getHeaders().getContentType().toString()
: MediaType.APPLICATION_JSON_VALUE);
return;
} else {
responseData = new ResponseData(404, "资源不存在");
}
} catch (URISyntaxException exception) {
responseData = new ResponseData(500, "服务器故障,uri创建失败");
response.setStatus(500);
} catch (ResourceAccessException exception) {
responseData = new ResponseData(500, "响应超时");
response.setStatus(500);
} catch (HttpClientErrorException exception) {
responseData = new ResponseData(exception.getRawStatusCode(), exception.getMessage());
response.setStatus(exception.getRawStatusCode());
}
}
byte[] bytes = JSON.toJSONBytes(responseData);
request.setAttribute("bytes", bytes);
response.setContentType(MediaType.APPLICATION_JSON_VALUE);
}
/**
* 根据请求中的部分参数获取对应的Service对象
* @param services 服务列表
* @param uri 原始uri
* @return Service对象
*/
private Service getService(List<Service> services, String uri) {
Service service = null;
for (Service temp : services) {
if (temp.support(uri)) {
service = temp;
break;
}
}
return service;
}
/**
* 轮询ip地址
* @param service 非空的Service对象
* @return ip地址
*/
private String getIp(Service service) {
int index = service.getCounter().getAndIncrement();
return service.getIp()[index % service.size()];
}
/**
* 构造重定向的url
* @param request 原始请求
* @param ip 服务器ip
* @param path 原始请求中的部分路径
* @return 重定向的url
*/
private String buildUrl(HttpServletRequest request, String ip, String path) {
String query = request.getQueryString();
return ip + request.getRequestURI().replace(path, "") +
(query != null ? "?" + query : "");
}
/**
* 创建RequestEntity<byte[]>
* @param request 原始请求
* @param url 重定向的url
* @return RequestEntity\<byte[]\>对象
* @throws IOException IO异常
* @throws URISyntaxException 创建URI时抛出的异常
*/
private RequestEntity<byte[]> buildRequestEntity(HttpServletRequest request, String url) throws IOException, URISyntaxException {
HttpMethod method = HttpMethod.resolve(request.getMethod());
HttpHeaders header = parseRequestHeaders(request);
byte[] body = parseRequestBody(request);
return new RequestEntity<byte[]>(body, header, method, new URI(url));
}
/**
* 构建请求头
* @param request 原始请求
* @return 请求头对象
*/
private HttpHeaders parseRequestHeaders(HttpServletRequest request) {
// 构建请求头
HttpHeaders headers = new HttpHeaders();
Enumeration<String> headerNames = request.getHeaderNames();
while (headerNames.hasMoreElements()) {
String name = headerNames.nextElement();
String value = request.getHeader(name);
headers.add(name, value);
}
return headers;
}
/**
* 构建请求体
* @param request 原始请求
* @return byte[]数组
* @throws IOException IO异常
*/
private byte[] parseRequestBody(HttpServletRequest request)
throws IOException {
InputStream input = request.getInputStream();
return StreamUtils.copyToByteArray(input);
}
private byte[] read(InputStream input) {
ByteArrayOutputStream output = null;
try {
output = new ByteArrayOutputStream();
byte[] buffer = new byte[1024];
int read = -1;
while ((read = input.read(buffer)) != -1) {
output.write(buffer, 0, read);
}
return output.toByteArray();
} catch (IOException exception) {
System.out.println(exception.getMessage());
return null;
} finally {
try {
if (input != null) {
input.close();
}
if (output != null) {
output.close();
}
} catch (IOException exception) {
System.out.println(exception.getMessage());
}
}
}
}
LoadBalancer
接口
/**
 * Strategy for choosing which of several equivalent backends receives the next request.
 */
@FunctionalInterface
public interface LoadBalancer {

    /**
     * Picks the backend for the next request.
     *
     * @param length number of available backends
     * @return index of the chosen backend, expected to be in {@code [0, length)}
     */
    int getNextService(int length);
}
默认实现
/**
 * Default load balancer: round robin over {@code length} backends.
 */
public class DefaultLoadBalancer implements LoadBalancer {
    private final AtomicCounter counter;

    public DefaultLoadBalancer() {
        counter = new AtomicCounter();
    }

    /**
     * Returns the next backend index in round-robin order.
     *
     * @param length number of backends; must be positive
     * @return an index in {@code [0, length)}
     * @throws IllegalArgumentException if {@code length <= 0}
     */
    @Override
    public int getNextService(int length) {
        if (length <= 0) {
            throw new IllegalArgumentException("length must be positive: " + length);
        }
        // NOTE(review): assumes AtomicCounter.getAndIncrement() is a plain int counter that may
        // wrap past Integer.MAX_VALUE; plain % would then give a negative index, while
        // floorMod always stays in [0, length).
        return Math.floorMod(counter.getAndIncrement(), length);
    }
}
Service
/**
 * A backend service the gateway can route to.
 *
 * <p>A request URI belongs to this service when it starts with {@link #getPath()} on a path
 * segment boundary. {@link #getNextIp()} picks one of the configured addresses in round-robin
 * order via a {@link DefaultLoadBalancer}.
 */
public class Service {
    private String serviceName;
    /** URI prefix that identifies this service. */
    private String path;
    /** Backend base addresses; presumably of the form "http://host:port" — TODO confirm. */
    private String[] ip;
    /** Exposed through {@link #getCounter()}; {@link #getNextIp()} uses {@link #balancer}. */
    private final AtomicCounter counter = new AtomicCounter();
    private final LoadBalancer balancer = new DefaultLoadBalancer();

    public Service() {}

    public Service(String serviceName, String path, String[] ip) {
        this.serviceName = serviceName;
        this.path = path;
        this.ip = ip;
    }

    public String getServiceName() {
        return serviceName;
    }

    public void setServiceName(String serviceName) {
        this.serviceName = serviceName;
    }

    public String getPath() {
        return path;
    }

    public void setPath(String path) {
        this.path = path;
    }

    public String[] getIp() {
        return ip;
    }

    public void setIp(String[] ip) {
        this.ip = ip;
    }

    public AtomicCounter getCounter() {
        return counter;
    }

    /** Returns the number of configured backend addresses. */
    public int size() {
        return this.ip.length;
    }

    /**
     * Returns whether {@code uri} is handled by this service.
     *
     * <p>The URI must start with {@code path} and the match must end on a segment boundary.
     * {@code /user} accepts {@code /user} and {@code /user/x} but not {@code /username}.
     * A bare {@code startsWith} let one service's prefix swallow another service's URIs.
     *
     * @param uri request URI
     * @return {@code true} if this service should handle the URI
     */
    public boolean support(String uri) {
        if (uri == null || path == null || !uri.startsWith(path)) {
            return false;
        }
        if (path.endsWith("/") || uri.length() == path.length()) {
            return true;
        }
        char next = uri.charAt(path.length());
        return next == '/' || next == ';';
    }

    /** Returns the next backend address in round-robin order. */
    public String getNextIp() {
        int index = balancer.getNextService(this.ip.length);
        return this.ip[index];
    }

    @Override
    public String toString() {
        return "Service{" +
                "serviceName='" + serviceName + '\'' +
                ", path='" + path + '\'' +
                ", ip=" + Arrays.toString(ip) +
                ", counter=" + counter +
                '}';
    }
}