业务中遇到了这样的场景：需要对请求做统一拦截，根据请求参数计算出新的值，并将其设置到请求头中。
以下分别用 Filter 和 Interceptor 两种方式实现。建议使用 Filter 方式：Interceptor 方式仅对 POST 和 GET 请求有效，不支持 PUT 等其他方法。主要原因是 Interceptor 方式通过反射修改请求头，依赖 HttpServletRequest 接口的具体实现类，而不同请求下的实现类可能不同（下文的 Interceptor 实现仅支持 POST 和 GET 方法）。
场景:
我们有两类用户：一类用户的请求头（header）中带有用户名参数 userName；另一类用户的请求头中没有 userName，但请求参数（requestParam）中带有 userId 参数，我们可以通过 userId 查询数据库等方式计算出 userName。
需求:
我们希望统一这两类用户，使请求到达 controller 层时，请求头中都带有 userName。
具体:
下面两种方式都实现了这一需求：当 controller 层获取 userName 时，将优先使用根据请求参数中的 userId 计算出的 userName。
1、Filter实现
package com.paomo.filter;
import org.springframework.context.annotation.Configuration;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.IOException;
import java.util.*;
@Configuration
@WebFilter(value = "/*")
public class ParamFilter implements Filter {
private static final String USER_PARAM = "userId";
private static final String USER_HEADER = "userName";
@Override
public void init(FilterConfig filterConfig) throws ServletException {
}
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
HttpServletRequest req = (HttpServletRequest) request;
HeaderMapRequestWrapper requestWrapper = new HeaderMapRequestWrapper(req);
String userId = request.getParameter(USER_PARAM);
//依据你自己的业务通过userId获取userName
String userName = getUserNameById(userId);
requestWrapper.addHeader(USER_HEADER, userName);
chain.doFilter(requestWrapper, response);
}
@Override
public void destroy() {
}
public class HeaderMapRequestWrapper extends HttpServletRequestWrapper {
public HeaderMapRequestWrapper(HttpServletRequest request) {
super(request);
}
private Map<String, String> headerMap = new HashMap<String, String>();
public void addHeader(String name, String value) {
headerMap.put(name, value);
}
@Override
public String getHeader(String name) {
String headerValue = super.getHeader(name);
if (headerMap.containsKey(name)) {
headerValue = headerMap.get(name);
}
return headerValue;
}
@Override
public Enumeration<String> getHeaderNames() {
List<String> names = Collections.list(super.getHeaderNames());
for (String name : headerMap.keySet()) {
names.add(name);
}
return Collections.enumeration(names);
}
@Override
public Enumeration<String> getHeaders(String name) {
List<String> values = Collections.list(super.getHeaders(name));
if (headerMap.containsKey(name)) {
values.add(headerMap.get(name));
}
return Collections.enumeration(values);
}
}
}
2、Interceptor实现
@Component
public class ParamInterceptor extends HandlerInterceptorAdapter {
private static final Logger logger = LoggerFactory.getLogger(ParamInterceptor.class);
private static final String USER_PARAM = "userId";
private static final String USER_HEADER = "userName";
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
if (handler instanceof HandlerMethod) {
String userId = request.getParameter(USER_PARAM);
if (StringUtils.isNotBlank(userId)) {
//依据你自己的业务通过userId获取userName
String userName = getUserNameById(userId);
reflectSetHeader(request, USER_HEADER, userName);
}
}
return true;
}
@Override
public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
super.postHandle(request, response, handler, modelAndView);
}
@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
super.afterCompletion(request, response, handler, ex);
}
private void reflectSetHeader(HttpServletRequest request, String key, String value){
Class<? extends HttpServletRequest> requestClass = request.getClass();
logger.info("request实现类={}", requestClass.getName());
try {
Field request1 = requestClass.getDeclaredField("request");
request1.setAccessible(true);
Object o = request1.get(request);
Field coyoteRequest = o.getClass().getDeclaredField("coyoteRequest");
coyoteRequest.setAccessible(true);
Object o1 = coyoteRequest.get(o);
Field headers = o1.getClass().getDeclaredField("headers");
headers.setAccessible(true);
MimeHeaders o2 = (MimeHeaders)headers.get(o1);
o2.removeHeader(key);
o2.addValue(key).setString(value);
} catch (Exception e) {
logger.info("reflect set header error {}", e);
}
}
}