事情是这样的,在做用户登录这个模块的时候,前端是前端人员在做,后端是我在做。写好之后准备联调,突然发现传过来的是加密串,这很正常,但我自己写的所有后续的方法里参数走的是HttpServletRequest,获取参数是从request的参数里拿,那么此时就拿不到了。当然可以拿到加密串之后先解密,然后参数直接用解密完设定好的参数即可,但我不想改写好并测试好的代码,那么就要在走进我的方法之前在HttpServletRequest的参数中定义好我要用的字段。
HttpServletRequestWrapper为HttpServletRequest做了拓展,所以我们继承它来操作
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.Vector;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
public class ParameterRequestWrapper extends HttpServletRequestWrapper {
private Map<String, String[]> params = new HashMap<>();
/**
* 必须要实现的构造方法
* @param request
*/
public ParameterRequestWrapper(HttpServletRequest request) {
super(request);
//将参数表,赋予给当前的Map以便于持有request中的参数
this.params.putAll(request.getParameterMap());
}
/**
* 重载构造方法
* @param request
* @param extendParams
*/
public ParameterRequestWrapper(HttpServletRequest request, Map<String, Object> extendParams) {
this(request);
//这里将扩展参数写入参数表
addAllParameters(extendParams);
}
/**
* 在获取所有的参数名,必须重写此方法,否则对象中参数值映射不上
* @return
*/
@SuppressWarnings({ "unchecked", "rawtypes" })
@Override
public Enumeration<String> getParameterNames() {
return new Vector(params.keySet()).elements();
}
/**
* 重写getParameter方法
* @param name 参数名
* @return 返回参数值
*/
@Override
public String getParameter(String name) {
String[] values = params.get(name);
if (values == null || values.length == 0) {
return null;
}
return values[0];
}
@Override
public String[] getParameterValues(String name) {
String[] values = params.get(name);
if (values == null || values.length == 0) {
return null;
}
return values;
}
/**
* 增加多个参数
* @param otherParams 增加的多个参数
*/
public void addAllParameters(Map<String, Object> otherParams) {
for (Map.Entry<String, Object> entry : otherParams.entrySet()) {
addParameter(entry.getKey(), entry.getValue());
}
}
/**
* 增加参数
* getParameterMap()中的类型是<String,String[]>类型的,所以这里要将其value转为String[]类型
* @param name 参数名
* @param value 参数值
*/
public void addParameter(String name, Object value) {
if (value != null) {
if (value instanceof String[]) {
params.put(name, (String[]) value);
} else if (value instanceof String) {
params.put(name, new String[]{(String) value});
} else {
params.put(name, new String[]{String.valueOf(value)});
}
}
}
}
拓展实现好之后,接下来就是对登录请求进行拦截,用过滤器拦截请求
@Component
public class RequestParameterFilter extends OncePerRequestFilter {
private static final String LOGIN_PATH = "/login/login";
private static final int DEFAULT_BUFFER_SIZE = 1024 * 4;
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
if(request.getRequestURI().indexOf(LOGIN_PATH) != -1) {
Map<String, Object> paramter = new HashMap<String, Object>();
//前端传过来的值在request的body中 json串{"loginInfoStr":"xxxxxxxx"}
/**
* 获取HttpServletRequst中body的值第一种方式
* 使用alibaba提供的工具类com.alibaba.dubbo.common.utils.IOUtils
* BufferedReader reader = new BufferedReader(new InputStreamReader(request.getInputStream()));
* String body = IOUtils.read(reader);
*/
/**
* 获取HttpServletRequst中body的值第二种方式
* 使用IO字符流处理
*/
InputStream inputStream = request.getInputStream();
String jsonStr = getRequestBodyJson(inputStream);
JSONObject jsonObject = JSONObject.parseObject(jsonStr);
String loginInfoStr = jsonObject.get("loginInfoStr").toString();
/**
* 短信验证码SMS,其他验证码加相应判断,解密之后添加进paramter即可,解密方式看前端是如何加密的
*/
String smsCode = getValidateCode(Base64.decode(loginInfoStr), ValidateCodeType.SMS.getParamNameOnValidate());
if(!"".equals(smsCode)) {
paramter.put(ValidateCodeType.SMS.getParamNameOnValidate(), smsCode);
}
ParameterRequestWrapper wrapper = new ParameterRequestWrapper(request, paramter);
filterChain.doFilter(wrapper, response);
return;
}
filterChain.doFilter(request, response);
}
/**
* 得到HTTP请求body中的JSON字符串
* @param inputStream
* @return
* @throws IOException
*/
private String getRequestBodyJson(InputStream inputStream) throws IOException {
Reader input = new InputStreamReader(inputStream);
Writer output = new StringWriter();
char[] buffer = new char[DEFAULT_BUFFER_SIZE];
int n = 0;
while(-1 != (n = input.read(buffer))) {
output.write(buffer, 0, n);
}
return output.toString();
}
private String getValidateCode(String str, String codeName) {
String code = "";
Pattern pattern = Pattern.compile("\""+ codeName +"\":\"(\\S+?)\"");
Matcher matcher = pattern.matcher(str);
while (matcher.find()) {
code = matcher.group(1);
}
return code;
}
}
枚举类
public enum ValidateCodeType {
/**
* 短信验证码
*/
SMS {
@Override
public String getParamNameOnValidate() {
return SmsCodeConstant.PARAMETER_NAME_CODE_SMS;//"smsCode"
}
};
/**
* 校验时从请求中获取的参数的名字
* @return
*/
public abstract String getParamNameOnValidate();
}