提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
前言
安全组提出了一个要求:请求体 JSON 中的字段必须都是接收实体类中声明的成员变量,不允许出现实体中不存在的多余字段。因此我们使用拦截器拦截所有请求,并在其中进行校验。
一、java拦截器是什么?
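Spring MVC 的拦截器(HandlerInterceptor)可以在请求到达 Controller 方法之前(preHandle)、之后(postHandle)以及请求完成后(afterCompletion)执行自定义逻辑,常用于鉴权、日志、参数校验等场景。本文利用 preHandle 在 Controller 执行前拦截请求并校验参数。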
二、使用步骤
1.配置拦截器
代码如下(示例):
import com.XXX.xx.config.interceptor.UserInfoInterceptor; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.web.servlet.config.annotation.CorsRegistry; import org.springframework.web.servlet.config.annotation.InterceptorRegistry; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; @Configuration public class WebConfig implements WebMvcConfigurer { @Autowired private UserInfoInterceptor userInfoInterceptor ; private static final String[] PATHS = new String[]{ } ; @Override public void addInterceptors(InterceptorRegistry registry) { registry.addInterceptor(userInfoInterceptor) .addPathPatterns("/**") .excludePathPatterns(PATHS); } @Override public void addCorsMappings(CorsRegistry registry) { registry.addMapping("/**") .allowedOrigins("*") .allowCredentials(true) .allowedMethods("GET","POST","PUT") .maxAge(3600); } }
2.校验逻辑
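代码如下(示例):拦截器实现。注意:HttpServletRequest 的请求体输入流只能读取一次。如果只在拦截器中 new VerifyRequestWrapper(request) 来读取请求体,原始输入流会被消耗,Controller 中的 @RequestBody 将无法再读取到数据。因此还需要编写一个 Filter,提前用 VerifyRequestWrapper 包装请求,再把包装后的请求交给后续的过滤器链。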
import com.alibaba.fastjson.JSONObject; import com.baomidou.mybatisplus.core.toolkit.StringUtils; import com.xx.xx.common.CommonException; import com.xx.xx.config.wrapper.VerifyRequestWrapper; import com.xx.xx.enums.ErrorCodeEnum; import java.lang.reflect.Field; import java.util.Arrays; import java.util.Set; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.springframework.core.MethodParameter; import org.springframework.stereotype.Component; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.method.HandlerMethod; import org.springframework.web.servlet.HandlerInterceptor; @Component public class VerifyInterceptor implements HandlerInterceptor { @Override public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception { if ( handler instanceof HandlerMethod) { HandlerMethod handlerMethod = (HandlerMethod) handler; final MethodParameter[] methodParameters = handlerMethod.getMethodParameters(); if (isJson(request)) { // 获取json字符串 String jsonParam = new VerifyRequestWrapper(request).getBodyString(); boolean temp = true; for (MethodParameter parameterType : methodParameters) { RequestBody annotation = parameterType.getParameterAnnotation(RequestBody.class); if(annotation==null){ continue; } final Class<?> parameterTypeClass = parameterType.getParameterType(); Field[] fields = parameterTypeClass.getDeclaredFields(); //判断body入参是否和接收对象参数一致 temp = isConsistent(jsonParam,fields); } if(!temp){ //参数不一致(多余不存在的参数),参数非法 throw new CommonException(ErrorCodeEnum.COMMON_ERROR.getCode(),ErrorCodeEnum.COMMON_ERROR.getMsg()); } } } return true; } /** * * @param jsonParam 请求参数 * @param fields 实体对象字段 * @return */ private boolean isConsistent(String jsonParam, Field[] fields) { if(StringUtils.isNotBlank(jsonParam)){ try { JSONObject jsonObject = JSONObject.parseObject(jsonParam); if(jsonObject!=null){ Set<String> keySet = jsonObject.keySet(); for 
(String s : keySet) { long count = Arrays.stream(fields).filter(f -> f.getName().equals(s)).count(); if(count==0){ return false; } } } } catch (Exception e) { } } return true; } /** * 判断本次请求的数据类型是否为json */ private boolean isJson(HttpServletRequest request) { if (request.getContentType() != null) { return request.getContentType().contains("application/json"); } return false; } }
import lombok.extern.slf4j.Slf4j; import javax.servlet.ReadListener; import javax.servlet.ServletInputStream; import javax.servlet.ServletRequest; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import java.io.*; import java.nio.charset.Charset; /** * 包装HttpServletRequest,目的是让其输入流可重复读 **/ @Slf4j public class VerifyRequestWrapper extends HttpServletRequestWrapper { /** * 存储body数据的容器 */ private final byte[] body; public VerifyRequestWrapper(HttpServletRequest request) throws IOException { super(request); // 将body数据存储起来 String bodyStr = getBodyString(request); body = bodyStr.getBytes(Charset.defaultCharset()); } /** * 获取请求Body */ public String getBodyString(final ServletRequest request) { try { return inputStream2String(request.getInputStream()); } catch (IOException e) { throw new RuntimeException(e); } } /** * 获取请求Body */ public String getBodyString() { final InputStream inputStream = new ByteArrayInputStream(body); return inputStream2String(inputStream); } /** * 将inputStream里的数据读取出来并转换成字符串 */ private String inputStream2String(InputStream inputStream) { StringBuilder sb = new StringBuilder(); BufferedReader reader = null; try { reader = new BufferedReader(new InputStreamReader(inputStream, Charset.defaultCharset())); String line; while ((line = reader.readLine()) != null) { sb.append(line); } } catch (IOException e) { throw new RuntimeException(e); } finally { if (reader != null) { try { reader.close(); } catch (IOException e) { log.error("读取数据流出错", e); } } } return sb.toString(); } @Override public BufferedReader getReader() throws IOException { return new BufferedReader(new InputStreamReader(getInputStream())); } @Override public ServletInputStream getInputStream() throws IOException { final ByteArrayInputStream inputStream = new ByteArrayInputStream(body); return new ServletInputStream() { @Override public int read() throws IOException { return inputStream.read(); } @Override public boolean isFinished() { return 
false; } @Override public boolean isReady() { return false; } @Override public void setReadListener(ReadListener readListener) { } }; } }
总结
使用以上代码,即可在拦截器中通过反射校验:请求体 JSON 中的每个字段是否都是 @RequestBody 接收实体中声明的成员变量。出现多余字段时,拦截器直接抛出异常,拒绝该请求。