Spring Boot 中使用 Filter 过滤 POST 和 GET 请求参数

//定义一个filter过滤器



import java.io.IOException;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import org.apache.commons.lang.StringEscapeUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.stereotype.Component;

/**
 * Servlet filter that escapes a fixed set of "illegal" characters ({@code " % '}) in request
 * input before it reaches the application.
 *
 * <ul>
 *   <li>{@code POST} with a JSON content type: the body is rewritten by
 *       {@link PostParameterRequestWrapper}.</li>
 *   <li>{@code GET}: every value of every query parameter is escaped and exposed through
 *       {@link GetParameterRequestWrapper}.</li>
 *   <li>Everything else (other methods, form / multipart / text POST bodies) is passed through
 *       unchanged.</li>
 * </ul>
 *
 * <p>Each illegal character is replaced by {@code StringEscapeUtils.escapeXml} of that character,
 * so {@code "} becomes {@code &quot;} and {@code '} becomes {@code &apos;}; {@code %} has no XML
 * entity and is therefore left unchanged by that call.
 *
 * <p>NOTE(review): the class carries both {@code @Component} and {@code @WebFilter}; if
 * {@code @ServletComponentScan} is also enabled the filter may be registered twice — confirm.
 */
@Component
@WebFilter(filterName = "ValidatorFilter" , urlPatterns = "/*")
public class ValidatorFilter implements Filter {

    /** Characters treated as illegal in request input; each is replaced by its XML escape. */
    String[] strArr = {"\"","%","'"};

    /**
     * Escapes illegal characters in GET parameters and JSON POST bodies, then always continues the
     * filter chain.
     */
    @Override
    public void doFilter(ServletRequest request,
                         ServletResponse response,
                         FilterChain chain) throws IOException, ServletException {
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;
        String method = httpServletRequest.getMethod();

        if ("POST".equals(method) && isJsonContentType(httpServletRequest.getContentType())) {
            // Only JSON bodies are rewritten: the wrapper parses the body as JSON text, which would
            // corrupt form-encoded, multipart (file upload) or plain-text bodies.
            Map<String, String[]> map = httpServletRequest.getParameterMap();
            chain.doFilter(new PostParameterRequestWrapper(httpServletRequest, method, map), response);
        } else if ("GET".equals(method)) {
            chain.doFilter(escapeQueryParameters(httpServletRequest), response);
        } else {
            // Nothing to escape, but the chain must still be invoked; otherwise the request would
            // never reach the servlet.
            chain.doFilter(request, response);
        }
    }

    /**
     * Wraps {@code request} so that every query-parameter value has its illegal characters escaped.
     * Multi-valued parameters keep all of their values (each one escaped individually).
     */
    private GetParameterRequestWrapper escapeQueryParameters(HttpServletRequest request) {
        GetParameterRequestWrapper wrapper = new GetParameterRequestWrapper(request);
        for (Map.Entry<String, String[]> entry : request.getParameterMap().entrySet()) {
            String[] values = entry.getValue();
            if (values == null) {
                continue;
            }
            String[] escaped = new String[values.length];
            for (int i = 0; i < values.length; i++) {
                escaped[i] = escapeIllegalCharacters(values[i]);
            }
            wrapper.addParameter(entry.getKey(), escaped);
        }
        return wrapper;
    }

    /**
     * Replaces every occurrence of each entry of {@link #strArr} in {@code value} with its XML
     * escape. Replacements are applied cumulatively to the same string, so no earlier escape is
     * lost. (If {@code &} were ever added to {@code strArr}, it would have to come first to avoid
     * double-escaping.)
     *
     * @param value a raw parameter value, may be {@code null}
     * @return the escaped value, or {@code null} when {@code value} is {@code null}
     */
    private String escapeIllegalCharacters(String value) {
        if (value == null) {
            return null;
        }
        String result = value;
        for (String illegal : strArr) {
            result = result.replace(illegal, StringEscapeUtils.escapeXml(illegal));
        }
        return result;
    }

    /** True for JSON media types such as {@code application/json} or {@code application/*+json}. */
    private static boolean isJsonContentType(String contentType) {
        return contentType != null && contentType.toLowerCase(Locale.ROOT).contains("json");
    }

    /** No resources to release. */
    @Override
    public void destroy() { }

    /** No initialization required. */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
    }
}
//GET 方式:可修改请求参数值的请求包装类,ValidatorFilter 借助它对不合法的参数值进行转义



import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.util.*;

/**
 * Request wrapper that holds a mutable copy of the request parameters, so a filter can replace or
 * add values before the request reaches the application.
 *
 * <p>All parameter accessors ({@link #getParameter}, {@link #getParameterValues},
 * {@link #getParameterMap}, {@link #getParameterNames}) read from the same copy, so replaced or
 * added values are visible whichever accessor downstream code uses.
 */
class GetParameterRequestWrapper extends HttpServletRequestWrapper {

    /** Parameter copy; LinkedHashMap keeps the iteration order of the original parameter map. */
    private final Map<String, String[]> params = new LinkedHashMap<String, String[]>();

    /**
     * Creates a wrapper initialised with all parameters of {@code request}.
     *
     * @param request the request to wrap
     */
    public GetParameterRequestWrapper(HttpServletRequest request) {
        super(request);
        this.params.putAll(request.getParameterMap());
    }

    /**
     * Creates a wrapper initialised with all parameters of {@code request}, then adds or replaces
     * the entries of {@code extendParams} (see {@link #addParameter}).
     *
     * @param request      the request to wrap
     * @param extendParams extra parameters; may be {@code null}
     */
    public GetParameterRequestWrapper(HttpServletRequest request , Map<String , Object> extendParams) {
        this(request);
        addAllParameters(extendParams);
    }

    /** Returns the first value of {@code name}, or {@code null} if absent or empty. */
    @Override
    public String getParameter(String name) {
        String[] values = params.get(name);
        if (values == null || values.length == 0) {
            return null;
        }
        return values[0];
    }

    /** Returns all values of {@code name}, or {@code null} if absent. */
    @Override
    public String[] getParameterValues(String name) {
        return params.get(name);
    }

    /** Returns an unmodifiable view of the (possibly modified) parameters. */
    @Override
    public Map<String, String[]> getParameterMap() {
        return Collections.unmodifiableMap(params);
    }

    /** Returns the names of the (possibly modified) parameters. */
    @Override
    public Enumeration<String> getParameterNames() {
        return Collections.enumeration(params.keySet());
    }

    /**
     * Adds or replaces every entry of {@code otherParams}; {@code null} is ignored.
     *
     * @param otherParams parameters to merge in
     */
    public void addAllParameters(Map<String , Object> otherParams) {
        if (otherParams == null) {
            return;
        }
        for (Map.Entry<String , Object> entry : otherParams.entrySet()) {
            addParameter(entry.getKey() , entry.getValue());
        }
    }

    /**
     * Adds or replaces parameter {@code name}. A {@code String[]} is stored as-is, a
     * {@code String} as a single-element array, any other object via {@code String.valueOf};
     * a {@code null} value is ignored.
     *
     * @param name  parameter name
     * @param value new value(s)
     */
    public void addParameter(String name , Object value) {
        if (value == null) {
            return;
        }
        if (value instanceof String[]) {
            params.put(name , (String[]) value);
        } else if (value instanceof String) {
            params.put(name , new String[] {(String) value});
        } else {
            params.put(name , new String[] {String.valueOf(value)});
        }
    }
}
//POST 方式:读取请求体,对参数值中不合法的字符进行转义后,再提供给后续处理





import org.apache.commons.lang.StringEscapeUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Request wrapper that buffers a JSON request body and escapes "illegal" characters
 * ({@code " % '}) inside its string values.
 *
 * <p>The body is scanned as JSON text. Object keys, numbers, {@code true}/{@code false}/{@code null},
 * nesting and whitespace are copied verbatim. Only string values are rewritten. Each rewritten
 * value is decoded, has every illegal character replaced by
 * {@code StringEscapeUtils.escapeXml} of that character, and is re-encoded as a valid JSON string
 * literal. Malformed string literals are passed through unchanged.
 *
 * <p>NOTE(review): the body is always decoded and re-encoded as UTF-8 (the JSON default), even if
 * the request declares another charset.
 */
public class PostParameterRequestWrapper  extends HttpServletRequestWrapper {

    private static final Charset UTF_8 = Charset.forName("UTF-8");

    /** Rewritten body served by {@link #getInputStream()} and {@link #getReader()}; never null. */
    private final byte[] body;

    /** Characters treated as illegal inside JSON string values. */
    String[] strArr = {"\"","%","'"};

    /**
     * Reads the whole body of {@code request} and prepares the escaped copy.
     *
     * @param request   the wrapped request; its input stream is consumed here
     * @param method    the HTTP method (not used by this class)
     * @param newParams the request parameter map (not used by this class)
     * @throws IOException declared for API compatibility
     */
    public PostParameterRequestWrapper(HttpServletRequest request, String method, Map<String, String[]> newParams) throws IOException {
        super(request);
        // An empty body yields an empty (non-null) buffer, so getInputStream() always works.
        body = escapeJsonStringValues(getBodyString(request)).getBytes(UTF_8);
    }

    /**
     * Returns {@code json} with illegal characters escaped inside every string value; object keys
     * and all non-string tokens are copied unchanged.
     */
    private String escapeJsonStringValues(String json) {
        StringBuilder out = new StringBuilder(json.length() + 16);
        int i = 0;
        int n = json.length();
        while (i < n) {
            char c = json.charAt(i);
            if (c != '"') {
                out.append(c);
                i++;
                continue;
            }
            int end = findStringEnd(json, i);
            if (end < 0) {
                // Unterminated string literal: malformed JSON, pass the remainder through untouched.
                out.append(json, i, n);
                break;
            }
            String literal = json.substring(i, end + 1);
            // A string followed by ':' is an object key; keys are never rewritten.
            out.append(isObjectKey(json, end + 1) ? literal : escapeLiteral(literal));
            i = end + 1;
        }
        return out.toString();
    }

    /** Returns the index of the quote closing the literal opened at {@code start}, or -1. */
    private static int findStringEnd(String json, int start) {
        for (int j = start + 1; j < json.length(); j++) {
            char c = json.charAt(j);
            if (c == '\\') {
                j++; // skip the escaped character (e.g. \" does not close the literal)
            } else if (c == '"') {
                return j;
            }
        }
        return -1;
    }

    /** True when the first non-whitespace character at or after {@code from} is ':'. */
    private static boolean isObjectKey(String json, int from) {
        for (int j = from; j < json.length(); j++) {
            char c = json.charAt(j);
            if (!Character.isWhitespace(c)) {
                return c == ':';
            }
        }
        return false;
    }

    /**
     * Escapes illegal characters in one quoted JSON string literal. Returns the literal unchanged
     * when it contains none, or when its escape sequences are malformed.
     */
    private String escapeLiteral(String literal) {
        String decoded = decodeJsonString(literal.substring(1, literal.length() - 1));
        if (decoded == null) {
            return literal;
        }
        String escaped = decoded;
        for (String illegal : strArr) {
            escaped = escaped.replace(illegal, StringEscapeUtils.escapeXml(illegal));
        }
        return escaped.equals(decoded) ? literal : encodeJsonString(escaped);
    }

    /** Decodes the content of a JSON string literal (without quotes); null if malformed. */
    private static String decodeJsonString(String raw) {
        StringBuilder sb = new StringBuilder(raw.length());
        for (int j = 0; j < raw.length(); j++) {
            char c = raw.charAt(j);
            if (c != '\\') {
                sb.append(c);
                continue;
            }
            if (++j >= raw.length()) {
                return null;
            }
            char e = raw.charAt(j);
            switch (e) {
                case '"':
                case '\\':
                case '/':
                    sb.append(e);
                    break;
                case 'b':
                    sb.append('\b');
                    break;
                case 'f':
                    sb.append('\f');
                    break;
                case 'n':
                    sb.append('\n');
                    break;
                case 'r':
                    sb.append('\r');
                    break;
                case 't':
                    sb.append('\t');
                    break;
                case 'u':
                    if (j + 4 >= raw.length()) {
                        return null;
                    }
                    int code = 0;
                    for (int k = 1; k <= 4; k++) {
                        int digit = Character.digit(raw.charAt(j + k), 16);
                        if (digit < 0) {
                            return null;
                        }
                        code = code * 16 + digit;
                    }
                    sb.append((char) code);
                    j += 4;
                    break;
                default:
                    return null;
            }
        }
        return sb.toString();
    }

    /** Encodes {@code value} as a quoted JSON string literal. */
    private static String encodeJsonString(String value) {
        StringBuilder sb = new StringBuilder(value.length() + 2);
        sb.append('"');
        for (int j = 0; j < value.length(); j++) {
            char c = value.charAt(j);
            switch (c) {
                case '"':
                    sb.append("\\\"");
                    break;
                case '\\':
                    sb.append("\\\\");
                    break;
                case '\b':
                    sb.append("\\b");
                    break;
                case '\f':
                    sb.append("\\f");
                    break;
                case '\n':
                    sb.append("\\n");
                    break;
                case '\r':
                    sb.append("\\r");
                    break;
                case '\t':
                    sb.append("\\t");
                    break;
                default:
                    if (c < 0x20) {
                        sb.append(String.format("\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
            }
        }
        return sb.append('"').toString();
    }

    /**
     * Reads the entire body of {@code request} as UTF-8 text, preserving line breaks.
     * On an I/O error the stack trace is printed and whatever was read so far is returned.
     *
     * @param request the request whose input stream is consumed
     * @return the body text, never {@code null}
     */
    public String getBodyString(final ServletRequest request) {
        StringBuilder sb = new StringBuilder();
        try (Reader reader = new InputStreamReader(cloneInputStream(request.getInputStream()), UTF_8)) {
            char[] buffer = new char[1024];
            int read;
            while ((read = reader.read(buffer)) != -1) {
                sb.append(buffer, 0, read);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return sb.toString();
    }

    /**
     * Copies the remaining bytes of {@code inputStream} into an in-memory stream.
     * On an I/O error the stack trace is printed and the bytes read so far are returned.
     *
     * @param inputStream the stream to drain
     * @return an in-memory stream over the copied bytes
     */
    public InputStream cloneInputStream(ServletInputStream inputStream) {
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        byte[] buffer = new byte[1024];
        int len;
        try {
            while ((len = inputStream.read(buffer)) != -1) {
                byteArrayOutputStream.write(buffer, 0, len);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return new ByteArrayInputStream(byteArrayOutputStream.toByteArray());
    }

    /** Returns a reader over the rewritten body, decoded as UTF-8 (the encoding it was written in). */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), UTF_8));
    }

    /** Returns a fresh stream over the rewritten body. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream source = new ByteArrayInputStream(body);

        return new ServletInputStream() {

            @Override
            public int read() {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return source.read(b, off, len);
            }

            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // Data is already in memory, so reads never block.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
            }
        };
    }

    /** Length of the rewritten body, which may differ from the client's Content-Length. */
    @Override
    public int getContentLength() {
        return body.length;
    }

    /** Length of the rewritten body, which may differ from the client's Content-Length. */
    @Override
    public long getContentLengthLong() {
        return body.length;
    }
}

 

好的,下面是一个完整的Spring Boot防止SQL注入的过滤器示例: ```java import org.springframework.core.annotation.Order; import org.springframework.stereotype.Component; import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.util.regex.Matcher; import java.util.regex.Pattern; @Component @Order(1) public class SqlInjectionFilter extends OncePerRequestFilter { private static final String SQL_REGEX = "(?i)(\\b(select|update|delete|insert|create|drop|alter|truncate|grant|revoke|backup|restore)\\b)|(\\b(and|or)\\b.+?(=|>|<|>=|<=|<>|!=|!<|!>)|\\b(in|like|regexp|sounds)\\b.+?\\()"; @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { String method = request.getMethod(); if ("GET".equals(method)) { String queryString = request.getQueryString(); if (!StringUtils.isEmpty(queryString)) { String filteredQuery = filter(queryString); request = new FilteredGetRequest(request, filteredQuery); } } else if ("POST".equals(method)) { String contentType = request.getContentType(); if (contentType != null && contentType.contains("application/x-www-form-urlencoded")) { String body = HttpHelper.getBodyString(request); if (!StringUtils.isEmpty(body)) { String filteredBody = filter(body); request = new FilteredPostRequest(request, filteredBody); } } } filterChain.doFilter(request, response); } private static String filter(String input) { Pattern pattern = Pattern.compile(SQL_REGEX); Matcher matcher = pattern.matcher(input); String filteredInput = matcher.replaceAll(""); return filteredInput; } private static class FilteredGetRequest extends HttpServletRequestWrapper { private String filteredQuery; public 
FilteredGetRequest(HttpServletRequest request, String filteredQuery) { super(request); this.filteredQuery = filteredQuery; } @Override public String getQueryString() { return filteredQuery; } } private static class FilteredPostRequest extends HttpServletRequestWrapper { private String filteredBody; public FilteredPostRequest(HttpServletRequest request, String filteredBody) { super(request); this.filteredBody = filteredBody; } @Override public String getParameter(String name) { String value = super.getParameter(name); if (value != null) { return filter(value); } return null; } @Override public String getHeader(String name) { String value = super.getHeader(name); if (value != null) { return filter(value); } return null; } @Override public ServletInputStream getInputStream() throws IOException { return new FilteredServletInputStream(super.getInputStream(), filteredBody); } @Override public BufferedReader getReader() throws IOException { return new BufferedReader(new FilteredStringReader(super.getReader(), filteredBody)); } private static class FilteredServletInputStream extends ServletInputStream { private InputStream inputStream; private String filteredBody; public FilteredServletInputStream(InputStream inputStream, String filteredBody) { this.inputStream = inputStream; this.filteredBody = filteredBody; } @Override public int read() throws IOException { return inputStream.read(); } @Override public boolean isFinished() { return false; } @Override public boolean isReady() { return false; } @Override public void setReadListener(ReadListener readListener) { } } private static class FilteredStringReader extends StringReader { private String filteredBody; public FilteredStringReader(Reader reader, String filteredBody) { super(filteredBody); this.filteredBody = filteredBody; } @Override public int read(char[] cbuf, int off, int len) throws IOException { return super.read(cbuf, off, len); } } } } ``` 在该过滤器中,首先判断请求的方法是GET还是POST,然后对请求参数进行过滤。 
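对于GET请求,可以通过`getQueryString`方法获取查询字符串并进行过滤,然后将过滤后的查询字符串封装到`FilteredGetRequest`类中,并将该封装后的请求对象传递给过滤器链中的下一个过滤器。注意`FilteredGetRequest`只重写了`getQueryString`,`getParameter`/`getParameterValues`仍然返回未过滤的原始值,而Spring MVC绑定`@RequestParam`时读取的正是`getParameterValues`,因此还需要一并重写这些方法,过滤才会生效。 对于POST请求,需要判断请求的Content-Type是否为"application/x-www-form-urlencoded",如果是,则需要获取请求的Body并进行过滤,然后将过滤后的Body封装到`FilteredPostRequest`类中,并将该封装后的请求对象传递给过滤器链中的下一个过滤器。注意示例中的`FilteredServletInputStream`实际读取的仍是原始输入流(且该流此前已被读取过),并没有使用`filteredBody`;另外示例没有给出`HttpHelper`的定义,`ServletInputStream`、`ReadListener`、`BufferedReader`、`InputStream`、`Reader`、`StringReader`等类也缺少import,代码无法直接编译。 在`filter`方法中,使用正则表达式对查询字符串或请求Body进行过滤,以防止SQL注入攻击。 需要注意的是,上述示例仅仅是一个简单的防止SQL注入的过滤器示例,实际情况下可能需要更加复杂的过滤逻辑。基于正则黑名单的过滤既可能被绕过,也会误删正常输入中的单词(如"select"、"update");防止SQL注入的根本做法是在数据访问层使用参数化查询(`PreparedStatement`,或MyBatis中的`#{}`而不是`${}`)。另外,在使用任何过滤器时,都需要仔细测试和验证,确保不会对系统产生不必要的影响。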
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值