**
接到一个需求：对所有会传入 userId 的接口进行拦截，校验传入的 userId 是否为当前登录用户的 userId（当前登录用户的 userId 可从 token 中解析出来）。
/**
 * First attempt: a HandlerInterceptor that checks the {@code userId} in the request body
 * against the user id carried in the {@code token} header.
 *
 * <p>The first idea for this scenario was to intercept with a HandlerInterceptor, read the
 * request parameters and the token from the HttpServletRequest, and validate them. Testing
 * showed that the parameters can be read here, but once the body has been read through the
 * request's input stream, the parameters are empty by the time the request reaches the
 * controller, so this version alone does not meet the requirement.
 *
 * @return {@code true} to continue processing; {@code false} after a rejection response has
 *         been written via {@code returnResponse}
 */
public boolean preHandle(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse,
        Object object) throws Exception {
    String token = httpServletRequest.getHeader("token"); // taken from the HTTP request header
    if (StringUtils.isBlank(token)) {
        returnResponse(httpServletRequest, httpServletResponse, "", "", 0);
        return false;
    }
    String userId = getBody(httpServletRequest);
    if (StringUtils.isBlank(userId)) {
        // The body carries no userId, so there is nothing to verify.
        return true;
    }
    // User id encoded in the token.
    Long tokenUserId = JWTutil.getUserId(token);
    // If the token yields no user id it cannot vouch for any userId: reject instead of
    // failing with a NullPointerException on tokenUserId.toString().
    if (tokenUserId == null || !userId.equals(tokenUserId.toString())) {
        returnResponse(httpServletRequest, httpServletResponse, "", "", 0);
        return false;
    }
    return true;
}
/**
 * Reads the whole request body as UTF-8 text and returns the string value of its top-level
 * {@code "userId"} member.
 *
 * <p>Note: this consumes (and closes) the request's input stream. Unless the request has been
 * wrapped so that its body can be re-read, the controller will afterwards see an empty body.
 *
 * @param request the incoming request
 * @return the {@code userId} value as a string, or {@code ""} when the body is blank, is not a
 *         JSON object, has no {@code userId} member, or that member has no string form
 * @throws IOException if reading the request body fails
 */
public static String getBody(HttpServletRequest request) throws IOException {
    StringBuilder stringBuilder = new StringBuilder();
    InputStream inputStream = request.getInputStream();
    if (inputStream != null) {
        // JSON is UTF-8 (RFC 8259); do not depend on the platform default charset.
        try (BufferedReader bufferedReader = new BufferedReader(
                new InputStreamReader(inputStream, java.nio.charset.StandardCharsets.UTF_8))) {
            char[] charBuffer = new char[128];
            int charsRead;
            while ((charsRead = bufferedReader.read(charBuffer)) != -1) {
                stringBuilder.append(charBuffer, 0, charsRead);
            }
        }
    }
    String body = stringBuilder.toString();
    if (body.trim().isEmpty()) {
        return "";
    }
    try {
        JsonObject jo = new JsonParser().parse(body).getAsJsonObject();
        if (jo.get("userId") == null) {
            return "";
        }
        return jo.get("userId").getAsString();
    } catch (RuntimeException ignored) {
        // Malformed JSON, a non-object body, or a userId value without a string form
        // (e.g. JSON null or an object): treat it as "no userId supplied".
        return "";
    }
}
进一步了解后发现，可以使用 Filter 对 HttpServletRequest 进行包装：先把请求体缓存下来，再把包装后的 request 传给后续的拦截器和 controller，这样请求体就可以被重复读取。
// Registered as a Spring bean. Registering it with @WebFilter (plus @ServletComponentScan)
// is an alternative registration mechanism.
@Component
public class BodyWrapperFilter implements Filter {

    @Override
    public void init(FilterConfig arg0) throws ServletException {
    }

    /**
     * Passes POST requests whose content type contains {@code application/json} down the
     * chain wrapped in a BodyReaderHttpServletRequestWrapper. The wrapper caches the body,
     * so a later HandlerInterceptor can read the parameters without emptying them for the
     * controller. Every other request is forwarded as-is.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response,
            FilterChain chain) throws IOException, ServletException {
        ServletRequest forwarded = isJsonPost(request)
                ? new BodyReaderHttpServletRequestWrapper((HttpServletRequest) request)
                : request;
        chain.doFilter(forwarded, response);
    }

    /** True for an HTTP POST whose Content-Type mentions application/json (case-insensitive). */
    private static boolean isJsonPost(ServletRequest request) {
        if (!(request instanceof HttpServletRequest)) {
            return false;
        }
        String method = ((HttpServletRequest) request).getMethod();
        return StringUtils.equalsIgnoreCase(HttpMethod.POST.name(), method)
                && StringUtils.containsIgnoreCase(request.getContentType(), MediaType.APPLICATION_JSON_VALUE);
    }

    @Override
    public void destroy() {
    }
}
/**
 * Request wrapper that buffers the request body in memory so that it can be read any number
 * of times (for example once by a HandlerInterceptor and again by the controller).
 *
 * <p>The {@code getParameter*} methods are served from a map built once in the constructor.
 * The map holds the query-string parameters. For POST requests whose content type is
 * {@code application/x-www-form-urlencoded}, it also holds the form parameters decoded from
 * the buffered body. Other bodies (e.g. JSON) are never parsed as a form, so characters such
 * as {@code =} or {@code %} inside JSON strings cannot produce bogus parameters or make the
 * constructor throw.
 */
public class BodyReaderHttpServletRequestWrapper extends HttpServletRequestWrapper {

    private static final java.nio.charset.Charset UTF_8 = java.nio.charset.StandardCharsets.UTF_8;

    private static final String FORM_CONTENT_TYPE = "application/x-www-form-urlencoded";

    /** The complete request body, read once from the wrapped request. */
    private final byte[] body;

    /** Parameters returned by the getParameter* overrides. */
    private final Map<String, String[]> paramsMap;

    /**
     * Reads the entire body of {@code request} into memory and builds the parameter map.
     *
     * @param request the request to wrap
     * @throws IOException if reading the request body fails
     */
    public BodyReaderHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        body = readBytes(request.getInputStream());
        paramsMap = buildParamsMap(request);
    }

    @Override
    public Map<String, String[]> getParameterMap() {
        // The Servlet contract says this map is immutable; don't expose our internal map.
        return Collections.unmodifiableMap(paramsMap);
    }

    @Override
    public String getParameter(String name) {
        String[] values = paramsMap.get(name);
        if (values == null || values.length == 0) {
            return null;
        }
        return values[0];
    }

    @Override
    public String[] getParameterValues(String name) {
        String[] values = paramsMap.get(name);
        // Copy so callers cannot modify the cached parameters.
        return values == null ? null : values.clone();
    }

    @Override
    public Enumeration<String> getParameterNames() {
        return Collections.enumeration(paramsMap.keySet());
    }

    /** Returns a reader over the cached body, using the request's encoding (UTF-8 if unset). */
    @Override
    public BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        return new BufferedReader(new InputStreamReader(getInputStream(),
                encoding != null ? encoding : UTF_8.name()));
    }

    /** Returns a fresh stream over the cached body on every call. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream bais = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return bais.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                return bais.read(b, off, len);
            }

            @Override
            public int available() throws IOException {
                return bais.available();
            }

            @Override
            public boolean isFinished() {
                return bais.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking reads are not supported by this in-memory stream; ignored.
            }
        };
    }

    /**
     * Parses an {@code application/x-www-form-urlencoded} string such as {@code a=1&b=2&a=3}
     * into a map from decoded name to all of its decoded values, in order of appearance.
     * Pairs without {@code '='} are skipped. Malformed %-escapes are kept verbatim.
     *
     * @param s the encoded string
     * @return a new mutable map of the parsed parameters
     * @throws IllegalArgumentException if {@code s} is null
     */
    public HashMap<String, String[]> parseQueryString(String s) {
        if (s == null) {
            throw new IllegalArgumentException();
        }
        HashMap<String, String[]> ht = new HashMap<String, String[]>();
        StringTokenizer st = new StringTokenizer(s, "&");
        while (st.hasMoreTokens()) {
            String pair = st.nextToken();
            int pos = pair.indexOf('=');
            if (pos == -1) {
                continue;
            }
            String key = decodeValue(pair.substring(0, pos));
            String val = decodeValue(pair.substring(pos + 1));
            String[] oldVals = ht.get(key);
            String[] valArray;
            if (oldVals == null) {
                valArray = new String[] {val};
            } else {
                valArray = new String[oldVals.length + 1];
                System.arraycopy(oldVals, 0, valArray, 0, oldVals.length);
                valArray[oldVals.length] = val;
            }
            ht.put(key, valArray);
        }
        return ht;
    }

    /** Query-string parameters, plus form-body parameters for url-encoded POSTs (query first). */
    private Map<String, String[]> buildParamsMap(HttpServletRequest request) {
        Map<String, String[]> params = new HashMap<String, String[]>();
        String query = request.getQueryString();
        if (query != null) {
            mergeParams(params, parseQueryString(query));
        }
        if ("POST".equalsIgnoreCase(request.getMethod())
                && isFormUrlEncoded(request.getContentType())
                && body.length > 0) {
            mergeParams(params, parseQueryString(new String(body, UTF_8)));
        }
        return params;
    }

    /** True when the Content-Type header starts with application/x-www-form-urlencoded. */
    private static boolean isFormUrlEncoded(String contentType) {
        return contentType != null
                && contentType.regionMatches(true, 0, FORM_CONTENT_TYPE, 0, FORM_CONTENT_TYPE.length());
    }

    /** Appends every value in {@code extra} to {@code target}, concatenating repeated names. */
    private static void mergeParams(Map<String, String[]> target, Map<String, String[]> extra) {
        for (Map.Entry<String, String[]> entry : extra.entrySet()) {
            String[] existing = target.get(entry.getKey());
            String[] added = entry.getValue();
            if (existing == null) {
                target.put(entry.getKey(), added);
            } else {
                String[] merged = new String[existing.length + added.length];
                System.arraycopy(existing, 0, merged, 0, existing.length);
                System.arraycopy(added, 0, merged, existing.length, added.length);
                target.put(entry.getKey(), merged);
            }
        }
    }

    /** URL-decodes {@code value} as UTF-8; on a malformed %-escape returns it unchanged. */
    private static String decodeValue(String value) {
        try {
            return URLDecoder.decode(value, UTF_8.name());
        } catch (UnsupportedEncodingException e) {
            return ""; // cannot happen: UTF-8 is always supported
        } catch (IllegalArgumentException e) {
            // Malformed escape sent by the client: keep the raw text rather than fail the request.
            return value;
        }
    }

    /** Reads {@code in} to the end and closes it. */
    private static byte[] readBytes(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream(1024);
        byte[] buffer = new byte[1024];
        try {
            int n;
            while ((n = in.read(buffer)) != -1) {
                out.write(buffer, 0, n);
            }
        } finally {
            in.close();
        }
        return out.toByteArray();
    }
}
随后将 HandlerInterceptor 调整为下面的代码（对于 POST + JSON 请求，此时读取的是经过包装的 request，读取请求体不会再导致 controller 中的参数丢失）：
/**
 * Verifies that the {@code userId} in the JSON request body (if any) equals the user id
 * carried in the {@code token} header.
 *
 * <p>The body is read from the request's input stream. This only leaves the body intact for
 * the controller when the request has been wrapped so the body can be re-read (the
 * BodyWrapperFilter does this for POST requests with a JSON content type).
 *
 * <p>Decision rules:
 * <ul>
 *   <li>missing or blank {@code token} header: rejected;</li>
 *   <li>blank body: allowed;</li>
 *   <li>body is a JSON object with a top-level {@code userId}: allowed if the value is JSON
 *       {@code null} or {@code ""}; rejected if it is an object or array; otherwise allowed
 *       only when it equals the token's user id (a token yielding no user id is rejected);</li>
 *   <li>any other body (no top-level {@code userId}, not a JSON object, or not valid JSON):
 *       rejected if the raw text still mentions {@code userId}, since that value cannot be
 *       verified here; otherwise allowed.</li>
 * </ul>
 * The body is always parsed rather than pre-filtered by a substring search. As a result, a
 * {@code userId} key written with JSON unicode escapes is still checked.
 *
 * @return {@code true} to continue processing; {@code false} after a rejection response has
 *         been written via {@code returnResponse}
 */
public boolean preHandle(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse,
        Object object) throws Exception {
    String token = httpServletRequest.getHeader("token"); // taken from the HTTP request header
    if (StringUtils.isBlank(token)) {
        returnResponse(httpServletRequest, httpServletResponse, "", "", 0);
        return false;
    }
    String input = IOUtils.readStreamAsString(httpServletRequest.getInputStream(), "UTF-8");
    if (!isBodyUserIdAllowed(input, token)) {
        returnResponse(httpServletRequest, httpServletResponse, "", "", 0);
        return false;
    }
    return true;
}

/** Applies the body decision rules documented on preHandle to an already-read body. */
private static boolean isBodyUserIdAllowed(String input, String token) throws Exception {
    if (StringUtils.isBlank(input)) {
        return true;
    }
    Object parsed;
    try {
        parsed = new JsonParser().parse(input);
    } catch (RuntimeException e) {
        parsed = null; // not valid JSON
    }
    if (!(parsed instanceof JsonObject) || !((JsonObject) parsed).has("userId")) {
        // No top-level userId to compare. Fail closed if the body still mentions userId
        // (nested, inside an array, or in malformed JSON), because it cannot be verified here.
        return !input.contains("userId");
    }
    JsonObject jo = (JsonObject) parsed;
    if (jo.get("userId").isJsonNull()) {
        return true;
    }
    if (!jo.get("userId").isJsonPrimitive()) {
        return false;
    }
    String userId = jo.get("userId").getAsString();
    if (userId.isEmpty()) {
        return true;
    }
    // User id encoded in the token; only looked up when there is a userId to compare.
    Long tokenUserId = JWTutil.getUserId(token);
    return tokenUserId != null && userId.equals(tokenUserId.toString());
}
**