通过Filter获取Http参数后,使用拦截器HandlerInterceptor读取参数进行相关操作

**

接到一个需求:对所有会传 userId 的接口进行拦截,校验请求中所传的 userId 是否与当前登录用户的 userId 一致(登录用户的 userId 可从 token 中解析出来)。

	/**
	 * First attempt: verify the caller-supplied userId directly inside a HandlerInterceptor.
	 *
	 * <p>Reads the "token" request header and the "userId" extracted from the request body by
	 * {@code getBody}, then compares that userId with the one decoded from the token.
	 *
	 * <p>Caveat reported by the original author: although the parameters can be read here,
	 * consuming the request through its stream leaves the controller with empty parameters,
	 * so this version does not satisfy the requirement on its own.
	 *
	 * @param httpServletRequest  the incoming request
	 * @param httpServletResponse the response handed to {@code returnResponse} on rejection
	 * @param object              the chosen handler (unused)
	 * @return {@code false} when the token header is blank or the body's userId does not match
	 *         the token's userId; {@code true} otherwise (including when the body has no userId)
	 * @throws Exception propagated from reading the body or decoding the token
	 */
	public boolean preHandle(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse,
                             Object object) throws Exception {
        String token = httpServletRequest.getHeader("token"); // taken from the HTTP request header
        if (StringUtils.isBlank(token)) {
            // returnResponse is defined elsewhere; presumably it writes the rejection payload.
            returnResponse(httpServletRequest, httpServletResponse, "", "", 0);
            return false;
        }

        String userId = getBody(httpServletRequest);
        if (StringUtils.isBlank(userId)) {
            // The body carries no userId, so there is nothing to verify.
            return true;
        }

        // User id encoded in the token.
        Long tokenUserId = JWTutil.getUserId(token);
        // A token without a user id cannot match any supplied userId: reject instead of
        // failing with a NullPointerException.
        if (tokenUserId == null || !userId.equals(tokenUserId.toString())) {
            returnResponse(httpServletRequest, httpServletResponse, "", "", 0);
            return false;
        }
        return true;
    }

	/**
	 * Reads the whole request body and returns the top-level "userId" member of the JSON
	 * object it contains.
	 *
	 * <p>The body is decoded as UTF-8. Extraction is best-effort: when the body is empty, is
	 * not a JSON object, has no "userId" member, or that member is JSON null or not a simple
	 * value, an empty string is returned.
	 *
	 * <p>Note: this consumes the request's input stream.
	 *
	 * @param request the request whose body is read
	 * @return the userId as a string, or {@code ""} when it cannot be extracted
	 * @throws IOException if reading or closing the request stream fails
	 */
	public static String getBody(HttpServletRequest request) throws IOException {
        StringBuilder stringBuilder = new StringBuilder();
        InputStream inputStream = request.getInputStream();
        if (inputStream != null) {
            // try-with-resources closes the reader on every path, including read failures.
            try (BufferedReader bufferedReader = new BufferedReader(
                    new InputStreamReader(inputStream, java.nio.charset.StandardCharsets.UTF_8))) {
                char[] charBuffer = new char[128];
                int charsRead;
                while ((charsRead = bufferedReader.read(charBuffer)) > 0) {
                    stringBuilder.append(charBuffer, 0, charsRead);
                }
            }
        }

        String body = stringBuilder.toString();
        try {
            JsonObject jo = new JsonParser().parse(body).getAsJsonObject();
            if (!jo.has("userId") || jo.get("userId").isJsonNull()) {
                return "";
            }
            return jo.get("userId").getAsString();
        } catch (RuntimeException e) {
            // Deliberate best-effort: malformed JSON, a non-object root, or a non-primitive
            // userId all mean "no usable userId" for the caller.
            return "";
        }
    }

通过学习了解到:可以使用 Filter 把 HttpServletRequest 包装成一个可重复读取请求体的 Wrapper,这样拦截器读取参数之后,controller 层仍然能够拿到参数。

@Component
// @WebFilter could be used here instead; the original author notes not knowing how the two differ.
public class BodyWrapperFilter implements Filter {

    @Override
    public void init(FilterConfig arg0) throws ServletException {
        // No initialisation required.
    }

    /**
     * Picks out POST requests whose content type mentions JSON and replaces them with a
     * {@code BodyReaderHttpServletRequestWrapper} copy, so that a later HandlerInterceptor
     * can read the parameters. Every other request continues down the chain unchanged.
     *
     * @param request  the incoming request
     * @param response the outgoing response
     * @param chain    the remaining filter chain
     * @throws IOException      if building the wrapper or the downstream chain fails with I/O
     * @throws ServletException if the downstream chain fails
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response,
            FilterChain chain) throws IOException, ServletException {
        ServletRequest forwarded = request;
        if (isJsonPost(request)) {
            forwarded = new BodyReaderHttpServletRequestWrapper((HttpServletRequest) request);
        }
        chain.doFilter(forwarded, response);
    }

    /** Tells whether the request is an HTTP POST whose content type contains the JSON media type. */
    private static boolean isJsonPost(ServletRequest request) {
        if (!(request instanceof HttpServletRequest)) {
            return false;
        }
        String method = ((HttpServletRequest) request).getMethod();
        return StringUtils.equalsIgnoreCase(HttpMethod.POST.name(), method)
                && StringUtils.containsIgnoreCase(request.getContentType(), MediaType.APPLICATION_JSON_VALUE);
    }

    @Override
    public void destroy() {
        // Nothing to release.
    }

}

/**
 * Request wrapper that buffers the request body in memory at construction time so it can be
 * read any number of times through {@link #getInputStream()} and {@link #getReader()}.
 *
 * <p>The {@code getParameter*} methods are overridden to answer from a map built by this
 * class: for POST requests the buffered body is parsed as {@code key=value&key=value} pairs,
 * for every other method the query string is parsed.
 */
public class BodyReaderHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** Raw request body (the message payload), captured once in the constructor. */
    private final byte[] body;

    /** Parameters served by the getParameter* overrides. */
    private final Map<String, String[]> paramsMap;

    /**
     * Buffers the body of {@code request} and builds the parameter map.
     *
     * @param request the request to wrap; its input stream is fully consumed and closed
     * @throws IOException if the request body cannot be read
     */
    public BodyReaderHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        body = readBytes(request.getInputStream());

        // POST parameters come from the body; everything else uses the query string.
        if ("POST".equalsIgnoreCase(request.getMethod())) {
            paramsMap = getParamMapFromPost(this);
        } else {
            paramsMap = getParamMapFromGet(this);
        }
    }

    @Override
    public Map<String, String[]> getParameterMap() {
        return paramsMap;
    }

    /** Returns the first value recorded for {@code name} in this wrapper's map, or null. */
    @Override
    public String getParameter(String name) {
        String[] values = paramsMap.get(name);
        if (values == null || values.length == 0) {
            return null;
        }
        return values[0];
    }

    /** Returns all values recorded for {@code name} in this wrapper's map, or null. */
    @Override
    public String[] getParameterValues(String name) {
        return paramsMap.get(name);
    }

    @Override
    public Enumeration<String> getParameterNames() {
        return Collections.enumeration(paramsMap.keySet());
    }

    /**
     * Charset used to decode the buffered body: the request's declared character encoding
     * when it names a supported charset, UTF-8 otherwise.
     */
    private java.nio.charset.Charset bodyCharset() {
        String encoding = getCharacterEncoding();
        if (encoding != null) {
            try {
                return java.nio.charset.Charset.forName(encoding);
            } catch (IllegalArgumentException ignored) {
                // Illegal or unsupported charset name: fall back to UTF-8 below.
            }
        }
        return java.nio.charset.StandardCharsets.UTF_8;
    }

    /**
     * Reads {@code stream} as text, joining its lines with CRLF.
     *
     * @param stream the stream holding the submitted POST data
     * @return the decoded text; empty when the stream has no content
     * @throws IOException if reading fails
     */
    private String getRequestBody(InputStream stream) throws IOException {
        StringBuilder text = new StringBuilder();
        BufferedReader reader = new BufferedReader(new InputStreamReader(stream, bodyCharset()));
        String line;
        boolean first = true;
        while ((line = reader.readLine()) != null) {
            if (!first) {
                // The original literal was "rn", a garbled "\r\n" line separator.
                text.append("\r\n");
            }
            text.append(line);
            first = false;
        }
        return text.toString();
    }

    /** Parses the body of {@code request} as {@code key=value} pairs; empty body gives an empty map. */
    private HashMap<String, String[]> getParamMapFromPost(HttpServletRequest request) throws IOException {
        String text = getRequestBody(request.getInputStream());
        if (text.isEmpty()) {
            return new HashMap<String, String[]>();
        }
        return parseQueryString(text);
    }

    /** Parses the query string of {@code request}; a missing query string gives an empty map. */
    private Map<String, String[]> getParamMapFromGet(HttpServletRequest request) {
        String query = request.getQueryString();
        if (query == null) {
            // Requests without a query string are legal; do not let parseQueryString reject them.
            return new HashMap<String, String[]>();
        }
        return parseQueryString(query);
    }

    /**
     * URL-decodes {@code value} as UTF-8.
     *
     * @return the decoded value; the raw value when it contains a malformed percent escape;
     *         {@code ""} if UTF-8 were reported as unsupported
     */
    private String decodeValue(String value) {
        try {
            return URLDecoder.decode(value, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            return ""; // not UTF-8 capable
        } catch (IllegalArgumentException e) {
            // Malformed escape such as a lone '%': this is reachable because JSON bodies are
            // also run through this parser. Keep the raw text rather than failing the request.
            return value;
        }
    }

    /**
     * Splits {@code s} on '&' into {@code key=value} pairs and collects the URL-decoded values
     * per key, in order of appearance. Pairs without '=' are skipped; keys are not decoded.
     *
     * @param s the text to parse
     * @return a map from key to its values
     * @throws IllegalArgumentException if {@code s} is null
     */
    public HashMap<String, String[]> parseQueryString(String s) {
        if (s == null) {
            throw new IllegalArgumentException();
        }
        HashMap<String, String[]> ht = new HashMap<String, String[]>();
        StringTokenizer st = new StringTokenizer(s, "&");
        while (st.hasMoreTokens()) {
            String pair = st.nextToken();
            int pos = pair.indexOf('=');
            if (pos == -1) {
                continue;
            }
            String key = pair.substring(0, pos);
            String val = decodeValue(pair.substring(pos + 1));

            String[] oldVals = ht.get(key);
            String[] valArray;
            if (oldVals == null) {
                valArray = new String[] {val};
            } else {
                // Repeated key: append to the values collected so far.
                valArray = new String[oldVals.length + 1];
                System.arraycopy(oldVals, 0, valArray, 0, oldVals.length);
                valArray[oldVals.length] = val;
            }
            ht.put(key, valArray);
        }
        return ht;
    }

    /** Returns a fresh reader over the buffered body on every call. */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), bodyCharset()));
    }

    /** Returns a fresh stream over the buffered body on every call. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream bais = new ByteArrayInputStream(body);
        return new ServletInputStream() {

            @Override
            public int read() throws IOException {
                return bais.read();
            }

            @Override
            public boolean isFinished() {
                // Finished once every buffered byte has been handed out.
                return bais.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener arg0) {
                // Intentionally a no-op, as in the original: no listener callbacks are issued.
            }
        };
    }

    /**
     * Drains {@code in} into a byte array and closes it.
     *
     * @throws IOException if reading fails
     */
    private static byte[] readBytes(InputStream in) throws IOException {
        int buffSize = 1024;
        ByteArrayOutputStream out = new ByteArrayOutputStream(buffSize);
        try (BufferedInputStream bufin = new BufferedInputStream(in)) {
            byte[] temp = new byte[buffSize];
            int size;
            while ((size = bufin.read(temp)) != -1) {
                out.write(temp, 0, size);
            }
        }
        return out.toByteArray();
    }

}

HandlerInterceptor调整为下方代码

/**
 * Interceptor check used together with the body-buffering filter: verifies that the
 * top-level "userId" in the JSON request body equals the user id decoded from the token.
 *
 * <p>NOTE(review): this reads the request's input stream; that is only repeatable when the
 * request was wrapped by the filter (JSON POSTs) — confirm other request types are not
 * affected.
 *
 * @param httpServletRequest  the incoming request
 * @param httpServletResponse the response handed to {@code returnResponse} on rejection
 * @param object              the chosen handler (unused)
 * @return {@code false} when the token header is blank or the body's userId differs from the
 *         token's; {@code true} when they match or the body carries no top-level userId
 * @throws RuntimeException if the body mentions "userId" but cannot be parsed as a JSON object
 * @throws Exception        propagated from reading the body or decoding the token
 */
public boolean preHandle(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse,
                             Object object) throws Exception {
        String token = httpServletRequest.getHeader("token"); // taken from the HTTP request header
        if (StringUtils.isBlank(token)) {
            // returnResponse is defined elsewhere; presumably it writes the rejection payload.
            returnResponse(httpServletRequest, httpServletResponse, "", "", 0);
            return false;
        }

        String input = IOUtils.readStreamAsString(httpServletRequest.getInputStream(), "UTF-8");
        // Cheap pre-filter: bodies that never mention userId need no JSON parsing.
        if (input == null || !input.contains("userId")) {
            return true;
        }

        String userId = null;
        try {
            JsonObject jo = new JsonParser().parse(input).getAsJsonObject();
            // "userId" may occur only inside a value or a nested object; a missing or null
            // top-level member means there is nothing to verify.
            if (jo.has("userId") && !jo.get("userId").isJsonNull()) {
                userId = jo.get("userId").getAsString();
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        // User id encoded in the token.
        Long tokenUserId = JWTutil.getUserId(token);
        if (userId == null || "".equals(userId)) {
            return true;
        }
        // A token without a user id cannot match: reject instead of throwing an NPE.
        if (tokenUserId == null || !userId.equals(tokenUserId.toString())) {
            returnResponse(httpServletRequest, httpServletResponse, "", "", 0);
            return false;
        }
        return true;
    }

**

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值