Filter - 自定义过滤器做数据拦截校验

一、 过滤器与拦截器的区别

过滤器:可以修改request,需要在servlet容器中实现,只能在方法请求前后使用。用于筛选。
拦截器:不能修改request,可以调用IOC容器中的各种依赖,可以详细到每个方法。用于终止流程。

1.配置需要过滤的接口

在resource文件夹下,创建一个存放方法名/路径的yml文件。eg:method.yml

// 路径名示例
method:
  map: {
     "[/base/data/getList]": '2,3',
     "[/base/data/getPage]": '2,3'
  }

// 方法名示例
method:
  map: {
     getList: '2,3',
     getPage: '2,3'
  }

获得配置文件中的map值

import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
import java.util.Map;

@Component
@Data
@ConfigurationProperties(prefix = "method")
public class MethodConfig {

    Map<String, String> map;
}

在启动类配置读取配置文件

//yaml属性源PropertySourcesPlaceholderConfigurer对象
@Bean
public static PropertySourcesPlaceholderConfigurer properties() {
    PropertySourcesPlaceholderConfigurer configurer = new
            PropertySourcesPlaceholderConfigurer();
    YamlPropertiesFactoryBean yaml = new YamlPropertiesFactoryBean();
    yaml.setResources(new ClassPathResource("method.yml"));
    configurer.setProperties(yaml.getObject());
    return configurer;
}

2.自定义Filter

import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.junit.platform.commons.util.StringUtils;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.annotation.Resource;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.*;

@Slf4j
@Component
@AllArgsConstructor
public class VerifyFilter extends OncePerRequestFilter {
	
	// 自定义异常抛出路径
    private static final String  errorFilterPath = "/filter/errorFilter";

    private static final String  errorCode = "errorCode";

	// 请求类型
    private static final String METHOD_GET = "GET";

    private static final String METHOD_POST = "POST";

	// 从yml文件中读取map	
    @Resource
    private final MethodConfig methodConfig;

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {

		// 获得接口路径 /base/data/getList
        String requestURI = request.getRequestURI();

        //先判断该路径是否需要数据校验
        // 从配置文件中读取接口的数据权限,没有配置说明不需要数据权限
        Map<String, String> map = methodConfig.getmap();
        String str = map.get(requestURI);

        // 查询该接口所关联的角色,判断用户是否属于其中,如果该接口未绑定角色,说明不需要进行数据校验,直接跳回
        if (StringUtils.isBlank(str)) {
            filterChain.doFilter(request, response);
            return;
        }
        List<String> orgTypeList = Arrays.asList(orgTypeStr.split(","));

        log.info("进入数据校验拦截器, 接口路径 = " + request.getRequestURI());

		// 从token中获得用户信息
       User user = getUserByToken();
       
        String orgType = user.getOrgType();
        if (!orgTypeList.contains(orgType)) {
            // 抛异常,无权限
            request.setAttribute(errorCode, "数据校验未通过");
            
            // 重定向至抛出异常的接口
            request.getRequestDispatcher(errorFilterPath).forward(request, response);
            return;
        }

        // 如果用户是 valid类型,则需要校验code参数
        if ("valid".equals(orgType)) {
            Integer code = lusEntity.getCode();
            HttpServletRequest req = (HttpServletRequest)request;

            String codeParam= null;
            String method = req.getMethod();
            if (METHOD_POST.equals(method)) {
                // post 请求,获得Json格式的参数
				// 自定义wrapper,过滤、读取request请求参数
                XssHttpServletRequestWrapper requestWrapper = new XssHttpServletRequestWrapper((HttpServletRequest) request);
                Map<String, Object> bodyMap = requestWrapper.getBodyMap(requestWrapper.getInputStream());
                if (bodyMap != null && bodyMap.get("code") != null) {
                    codeParam = bodyMap.get("code").toString();
                }
            } else {
                // get 请求,直接取出参数
                codeParam = request.getParameter("code");
            }

            if (!code.toString().equals(codeParam)) {
                // 抛异常
                request.setAttribute(errorCode, "数据校验未通过");
                
                  // 重定向至抛出异常的接口
                request.getRequestDispatcher(errorFilterPath).forward(request, response);
                return;
            }
        }

        filterChain.doFilter(request, response);
    }
}

3.自定义wrapper

HttpServletRequestWrapper
当使用filter时,会出现需要获取或者改变HttpServletRequest对象的参数的情况。但是java.util.Map包装的HttpServletRequest对象的参数是不可改变的。我们不能改变对象本身,但是可以通过装饰模式来改变其状态。
HttpServletRequestWrapper类是HttpServletRequest类的装饰类。想要改变在httpServletRequest中的参数,可以通过httpServletRequest的装饰类HttpServletRequestWrapper来实现,只需要在装饰类中按照需要重写其getParameter(getParameterValues)方法即可。

@RequestBody
以流的形式读取request中的Json数据。读取流时,getReader()和getInputStream()只能调用一次。因为读取一次,标记一次当前的位置。第二次读取就从标记位置继续读取,所以会读不到数据。

获得参数
重写HttpServletRequestWrapper,可以把request保存下来。再通过过滤器,把保存下来的request填充进去,可以实现多次读取request。

XssHttpServletRequestWrapper

import cn.hutool.core.io.IoUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSONObject;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
    /**
     * 没被包装过的HttpServletRequest(特殊场景,需要自己过滤)
     */
    HttpServletRequest orgRequest;

    /**
     * 方便重复读取requestBody,由于每次都会在filter创建,因此整个生命周期中是单例的,注意线程安全问题
     */
    private byte[] requestBody;

    /**
     * html过滤
     */
    private final static HTMLFilter htmlFilter = new HTMLFilter();

    public XssHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        orgRequest = request;

        // 获取Json参数体
        InputStream is = request.getInputStream();
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        byte buff[] = new byte[1024];
        int read;
        while ((read = is.read(buff)) > 0) {
            baos.write(buff, 0, read);
        }
        requestBody = baos.toByteArray();
    }

    @Override
    public ServletInputStream getInputStream() throws IOException {
        String json = null;
        //保存内容。方便多次获取requestBody
        if (null == this.requestBody) {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            IoUtil.copy(super.getInputStream(), baos);
            this.requestBody = baos.toByteArray();
        }
        //为空,直接返回
        json = new String(this.requestBody, StandardCharsets.UTF_8);
        json = xssEncode(json);
        if (StrUtil.isBlank(json)) {
            return super.getInputStream();
        }

        final ByteArrayInputStream bis = new ByteArrayInputStream(json.getBytes("utf-8"));
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return true;
            }

            @Override
            public boolean isReady() {
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
            }

            @Override
            public int read() throws IOException {
                return bis.read();
            }
        };
    }

    @Override
    public String getParameter(String name) {
        String value = super.getParameter(xssEncode(name));
        if (StrUtil.isNotBlank(value)) {
            value = xssEncode(value);
        }
        return value;
    }

    @Override
    public String[] getParameterValues(String name) {
        String[] parameters = super.getParameterValues(name);
        if (parameters == null || parameters.length == 0) {
            return null;
        }

        for (int i = 0; i < parameters.length; i++) {
            parameters[i] = xssEncode(parameters[i]);
        }
        return parameters;
    }

    @Override
    public Map<String,String[]> getParameterMap() {
        Map<String,String[]> map = new LinkedHashMap<>();
        Map<String,String[]> parameters = super.getParameterMap();
        for (String key : parameters.keySet()) {
            String[] values = parameters.get(key);
            for (int i = 0; i < values.length; i++) {
                values[i] = xssEncode(values[i]);
            }
            map.put(key, values);
        }
        return map;
    }

    @Override
    public String getHeader(String name) {
        String value = super.getHeader(xssEncode(name));
        if (StrUtil.isNotBlank(value)) {
            value = xssEncode(value);
        }
        return value;
    }

    private String xssEncode(String input) {
        return htmlFilter.filter(input);
    }

    /**
     * 获取最原始的request
     */
    public HttpServletRequest getOrgRequest() {
        return orgRequest;
    }

    /**
     * 获取最原始的request
     */
    public static HttpServletRequest getOrgRequest(HttpServletRequest request) {
        if (request instanceof XssHttpServletRequestWrapper) {
            return ((XssHttpServletRequestWrapper) request).getOrgRequest();
        }

        return request;
    }

    //获取request请求body中参数
    public Map<String,Object> getBodyMap(InputStream in) {
        String param= null;
        BufferedReader streamReader=null;
        try {
            streamReader = new BufferedReader( new InputStreamReader(in, "UTF-8"));
            StringBuilder responseStrBuilder = new StringBuilder();
            String inputStr;
            while ((inputStr = streamReader.readLine()) != null)
                responseStrBuilder.append(inputStr);
            if(!JsonUtil.getInstance().validate(responseStrBuilder.toString())){
                return new HashMap<String, Object>();
            }
            JSONObject jsonObject = JSONObject.parseObject(responseStrBuilder.toString());
            if(jsonObject==null){
                return new HashMap<String, Object>();
            }
            param = jsonObject.toJSONString();

        } catch (Exception e) {
            e.printStackTrace();
        }finally{
            if(streamReader!=null){
                try {
                    streamReader.close();
                } catch (IOException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }
        }
        return JSONObject.parseObject(param,Map.class);
    }

}

XssFilter
把保存下来的request填充并传递下去

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;

/**
 * XSS过滤
 */
public class XssFilter implements Filter {

	@Override
	public void init(FilterConfig config) throws ServletException {
	}

	@Override
	public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
		ServletRequest requestWrapper = null;
		if (request instanceof HttpServletRequest) {
			requestWrapper = new XssHttpServletRequestWrapper((HttpServletRequest) request);
		}
		if (requestWrapper == null) {
			chain.doFilter(request, response);
		} else {
			chain.doFilter(requestWrapper, response);
		}
	}

	@Override
	public void destroy() {
	}
}

4.校验字符串是否是合法的JSON格式

方法一:工具类

import java.text.CharacterIterator;
import java.text.StringCharacterIterator;

import org.apache.commons.lang3.StringUtils;

/**
 * 用于校验字符串是否是合法的JSON格式
 */
public class JsonUtil {
    private CharacterIterator it;
    private char c;
    private int col;
    private static JsonUtil instance;

    /**
     * 获取类的实例
     *
     * @return 类的实例
     */
    public static JsonUtil getInstance() {
        if (instance == null) {
            instance = new JsonUtil();
        }
        return instance;
    }

    /**
     * 验证一个字符串是否是合法的JSON串
     *
     * @param input 要验证的字符串
     * @return true-合法 ,false-非法
     */
    public boolean validate(String input) {
        if(StringUtils.isBlank(input)){
            return false;
        }
        input = input.trim();
        boolean ret = valid(input);
        return ret;
    }

    private boolean valid(String input) {
        if ("".equals(input)) {
            return false;
        }

        boolean ret = true;
        it = new StringCharacterIterator(input);
        c = it.first();
        col = 1;
        if (!value()) {
            ret = error("value", 1);
        } else {
            skipWhiteSpace();
            if (c != CharacterIterator.DONE) {
                ret = error("end", col);
            }
        }

        return ret;
    }

    private boolean value() {
        return literal("true") || literal("false") || literal("null") || string() || number() || object() || array();
    }

    private boolean literal(String text) {
        CharacterIterator ci = new StringCharacterIterator(text);
        char t = ci.first();
        if (c != t) return false;

        int start = col;
        boolean ret = true;
        for (t = ci.next(); t != CharacterIterator.DONE; t = ci.next()) {
            if (t != nextCharacter()) {
                ret = false;
                break;
            }
        }
        nextCharacter();
        if (!ret) error("literal " + text, start);
        return ret;
    }

    private boolean array() {
        return aggregate('[', ']', false);
    }

    private boolean object() {
        return aggregate('{', '}', true);
    }

    private boolean aggregate(char entryCharacter, char exitCharacter, boolean prefix) {
        if (c != entryCharacter) return false;
        nextCharacter();
        skipWhiteSpace();
        if (c == exitCharacter) {
            nextCharacter();
            return true;
        }

        for (; ; ) {
            if (prefix) {
                int start = col;
                if (!string()) return error("string", start);
                skipWhiteSpace();
                if (c != ':') return error("colon", col);
                nextCharacter();
                skipWhiteSpace();
            }
            if (value()) {
                skipWhiteSpace();
                if (c == ',') {
                    nextCharacter();
                } else if (c == exitCharacter) {
                    break;
                } else {
                    return error("comma or " + exitCharacter, col);
                }
            } else {
                return error("value", col);
            }
            skipWhiteSpace();
        }

        nextCharacter();
        return true;
    }

    private boolean number() {
        if (!Character.isDigit(c) && c != '-') return false;
        int start = col;
        if (c == '-') nextCharacter();
        if (c == '0') {
            nextCharacter();
        } else if (Character.isDigit(c)) {
            while (Character.isDigit(c))
                nextCharacter();
        } else {
            return error("number", start);
        }
        if (c == '.') {
            nextCharacter();
            if (Character.isDigit(c)) {
                while (Character.isDigit(c))
                    nextCharacter();
            } else {
                return error("number", start);
            }
        }
        if (c == 'e' || c == 'E') {
            nextCharacter();
            if (c == '+' || c == '-') {
                nextCharacter();
            }
            if (Character.isDigit(c)) {
                while (Character.isDigit(c))
                    nextCharacter();
            } else {
                return error("number", start);
            }
        }
        return true;
    }

    private boolean string() {
        if (c != '"') return false;

        int start = col;
        boolean escaped = false;
        for (nextCharacter(); c != CharacterIterator.DONE; nextCharacter()) {
            if (!escaped && c == '\\') {
                escaped = true;
            } else if (escaped) {
                if (!escape()) {
                    return false;
                }
                escaped = false;
            } else if (c == '"') {
                nextCharacter();
                return true;
            }
        }
        return error("quoted string", start);
    }

    private boolean escape() {
        int start = col - 1;
        if (" \\\"/bfnrtu".indexOf(c) < 0) {
            return error("escape sequence  \\\",\\\\,\\/,\\b,\\f,\\n,\\r,\\t  or  \\uxxxx ", start);
        }
        if (c == 'u') {
            if (!ishex(nextCharacter()) || !ishex(nextCharacter()) || !ishex(nextCharacter())
                    || !ishex(nextCharacter())) {
                return error("unicode escape sequence  \\uxxxx ", start);
            }
        }
        return true;
    }

    private boolean ishex(char d) {
        return "0123456789abcdefABCDEF".indexOf(c) >= 0;
    }

    private char nextCharacter() {
        c = it.next();
        ++col;
        return c;
    }

    private void skipWhiteSpace() {
        while (Character.isWhitespace(c)) {
            nextCharacter();
        }
    }

    private boolean error(String type, int col) {
        return false;
    }
}

方法二:捕获异常

import cn.hutool.json.JSONUtil;

Object param;
JSONObject jsonObject = new JSONObject();
try {
	 // 将参数转为jsonObject
    jsonObject = JSONUtil.parseObj(param);
    
} catch (Exception e) {
    // 不是json格式,额外处理。
}

5.抛出filter自定义异常

在filter过滤器中使用throw直接抛出自定义异常,会返回request内部的500错误码,而无法返回自定义的错误码和内容。
此处使用在filter中重定向至返回错误码的控制层接口,用来实现返回自定义异常。

import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import javax.servlet.http.HttpServletRequest;

@RequestMapping("/filter")
@RestController
public class FilterController {

    /**
     * 捕获filter中的异常并提示
     * @param request
     * @return
     */
    @RequestMapping("errorFilter")
    public RestResponse verifyUserFilter(HttpServletRequest request) {

        String code = request.getAttribute("errorCode").toString();
         return RestResponse.exception(code);
    }
}

参考文章:

用一个实例来说明HttpServletRequestWrapper类的使用
spring boot拦截器中获取request post请求中的参数
springMVC拦截器从Request中获取Json格式并解决request的请求流只能读取一次的问题

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值