// 实现接口转发Filter (implements an API-forwarding servlet Filter)
package com.quanzhi.api.operation.common.web.filter;
import com.alibaba.fastjson.JSON;
import com.google.common.collect.Lists;
import com.quanzhi.api.operation.biz.service.impl.SysConfigServiceImpl;
import com.quanzhi.api.operation.common.util.dto.MonitorConfigDto;
import com.quanzhi.api.operation.common.util.utils.DataUtil;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.http.*;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.stereotype.Component;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StreamUtils;
import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.RestTemplate;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
/**
 * Servlet filter that forwards selected requests to an upstream service and relays the
 * upstream response back to the caller.
 *
 * <p>A request is forwarded when its request URI is contained in the {@code forwardList}
 * system configuration. The target host and port are read from the {@code monitorConfig}
 * system configuration. Every other request continues down the normal filter chain.
 */
@WebFilter
@Component
@Log4j2
@Order(3)
public class MonitorForwardFilter implements Filter {

    /** Path prefix removed from the incoming URI before it is appended to the upstream base URL. */
    private static final String CONTEXT_PREFIX = "/app-operation";

    @Autowired
    private SysConfigServiceImpl sysConfigService;

    /**
     * Client shared by all forwarded calls. It is configured once, instead of being rebuilt
     * (converters included) on every request.
     */
    private final RestTemplate restTemplate = createRestTemplate();

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) request;
        HttpServletResponse resp = (HttpServletResponse) response;
        String requestUrl = req.getRequestURI();
        List<String> urls = getUrls();
        if (urls.contains(requestUrl)) {
            String url = manageUrl(req);
            log.info("MonitorForwardFilter is forwarding {}", url);
            forwardUrl(url, req, resp);
        } else {
            chain.doFilter(request, response);
        }
    }

    /**
     * Builds the RestTemplate used for forwarding.
     *
     * <p>The String converter is replaced by one constructed with a UTF-8 default charset. The
     * error handler never reports an error, so upstream 4xx/5xx responses come back as normal
     * {@link ResponseEntity} results. Their status and body can then be relayed to the caller
     * instead of surfacing here as an exception (which the container would turn into a 500).
     *
     * @return a configured, reusable RestTemplate
     */
    private static RestTemplate createRestTemplate() {
        RestTemplate template = new RestTemplate();
        // NOTE(review): assumes index 1 holds RestTemplate's default StringHttpMessageConverter.
        template.getMessageConverters().set(1, new StringHttpMessageConverter(StandardCharsets.UTF_8));
        template.setErrorHandler(new DefaultResponseErrorHandler() {
            @Override
            public boolean hasError(ClientHttpResponse response) {
                // Treat every upstream status as a result to relay, not a failure to throw.
                return false;
            }
        });
        return template;
    }

    /**
     * Reads the configured list of request URIs that should be forwarded.
     *
     * @return the URIs from the {@code forwardList} config, or an empty mutable list when unset
     */
    private List<String> getUrls() {
        Object configByName = sysConfigService.getConfigByName("forwardList");
        if (DataUtil.isNotEmpty(configByName)) {
            return JSON.parseArray(JSON.toJSONString(configByName), String.class);
        }
        return Lists.newArrayList();
    }

    /**
     * Sends the incoming request (method, headers, body) to {@code url} and writes the upstream
     * status, content type and body to {@code resp}.
     *
     * @param url  absolute upstream URL to call
     * @param req  incoming request whose method, headers and body are forwarded
     * @param resp response that receives the upstream status, content type and body
     * @throws IOException if reading the request body or writing the response fails
     */
    private void forwardUrl(String url, HttpServletRequest req, HttpServletResponse resp)
            throws IOException, ServletException {
        HttpMethod httpMethod = HttpMethod.resolve(req.getMethod());
        if (httpMethod == null) {
            // HttpMethod.resolve returns null for methods it does not know; RestTemplate
            // cannot issue such a request, so reject it explicitly instead of failing deep inside.
            resp.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
            return;
        }
        RequestEntity<byte[]> requestEntity = new RequestEntity<>(
                parseRequestBody(req), parseRequestHeader(req), httpMethod, URI.create(url));
        ResponseEntity<String> result = restTemplate.exchange(requestEntity, String.class);

        // Relay the upstream status; otherwise every forwarded call would look like a 200.
        resp.setStatus(result.getStatusCodeValue());
        MediaType contentType = result.getHeaders().getContentType();
        if (contentType != null) {
            resp.setContentType(contentType.toString());
        }
        resp.setCharacterEncoding("UTF-8");

        String resultBody = result.getBody();
        PrintWriter writer = resp.getWriter();
        // An upstream response without a body yields null; PrintWriter.write(null) would throw.
        if (resultBody != null) {
            writer.write(resultBody);
        }
        writer.flush();
    }

    /**
     * Builds the upstream URL for {@code req}. It is "http://ip:port" from the monitor config,
     * followed by the part of the request URI after {@link #CONTEXT_PREFIX}, followed by the
     * query string (if any).
     *
     * <p>When the URI does not contain the prefix, the URL points at the upstream root with no
     * path and no query string.
     *
     * @param req incoming request
     * @return absolute upstream URL
     */
    private String manageUrl(HttpServletRequest req) {
        MonitorConfigDto monitorConfigDto = getMonitorIp();
        String requestUri = req.getRequestURI();
        String path = "";
        int prefixIndex = requestUri.indexOf(CONTEXT_PREFIX);
        if (prefixIndex >= 0) {
            // Keep everything after the first occurrence of the prefix, so later path segments
            // are not truncated.
            path = requestUri.substring(prefixIndex + CONTEXT_PREFIX.length());
            String queryString = req.getQueryString();
            if (queryString != null) {
                path = path + "?" + queryString;
            }
        }
        return "http://" + monitorConfigDto.getIp() + ":" + monitorConfigDto.getPort() + path;
    }

    /**
     * Reads the upstream host/port from the {@code monitorConfig} system configuration.
     *
     * @return the parsed config, or an empty {@link MonitorConfigDto} when unset
     *         (NOTE(review): its ip/port are then presumably null — confirm)
     */
    private MonitorConfigDto getMonitorIp() {
        Object monitorConfig = sysConfigService.getConfigByName("monitorConfig");
        if (DataUtil.isNotEmpty(monitorConfig)) {
            return JSON.parseObject(JSON.toJSONString(monitorConfig), MonitorConfigDto.class);
        }
        return new MonitorConfigDto();
    }

    /**
     * Reads the full request body.
     *
     * @param request incoming request
     * @return body bytes (empty when there is no body)
     * @throws IOException if the input stream cannot be read
     */
    private byte[] parseRequestBody(HttpServletRequest request) throws IOException {
        InputStream inputStream = request.getInputStream();
        return StreamUtils.copyToByteArray(inputStream);
    }

    /**
     * Copies the request headers, including every value of multi-valued headers, into a header
     * map for the upstream call.
     *
     * <p>{@code Accept-Encoding} is deliberately not forwarded. The upstream body is relayed as
     * decoded text without its {@code Content-Encoding}, so a compressed upstream response
     * would reach the client as garbage.
     *
     * @param request incoming request
     * @return headers to send upstream
     */
    private MultiValueMap<String, String> parseRequestHeader(HttpServletRequest request) {
        HttpHeaders httpHeaders = new HttpHeaders();
        for (String headerName : Collections.list(request.getHeaderNames())) {
            if (HttpHeaders.ACCEPT_ENCODING.equalsIgnoreCase(headerName)) {
                continue;
            }
            for (String headerValue : Collections.list(request.getHeaders(headerName))) {
                httpHeaders.add(headerName, headerValue);
            }
        }
        return httpHeaders;
    }
}