拦截器中校验请求参数(漏洞修复之防止sql注入)

为防止 SQL 注入，MyBatis 的 SQL 中我们通常使用 `#{}` 占位符（预编译参数绑定）；但有些时候难免要用 `${}` 拼接 SQL，这时就需要在后端对参数进行校验。下面综合网上查到的做法记录一下。注意：关键字黑名单只是兜底手段，对于 `${}` 拼接的部分（如排序字段名）最好使用白名单校验。
第一个：注册拦截器

import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

@Configuration
public class MyWedAppConfigurer implements WebMvcConfigurer {

    /**
     * Adds {@link MyInterceptor} to the interceptor chain for every request path.
     *
     * <p>Several interceptors can be chained here: {@code addPathPatterns} selects the paths an
     * interceptor applies to, and {@code excludePathPatterns} removes paths from that selection.
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        MyInterceptor parameterCheckInterceptor = new MyInterceptor();
        registry.addInterceptor(parameterCheckInterceptor)
                .addPathPatterns("/**");
        WebMvcConfigurer.super.addInterceptors(registry);
    }
}

第二个：实现自己的拦截器

import com.alibaba.fastjson.JSONObject;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.web.context.support.WebApplicationContextUtils;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.PrintWriter;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.*;
import java.util.regex.Pattern;


@Component
public class MyInterceptor implements HandlerInterceptor {

    private Pattern sqlPattern = Pattern.compile(
            "\\b(and|exec|insert|drop|grant|alter|delete|update|count|chr|mid|master|truncate|char|declare|or|exec|having|sleep)\\b|^\\.{2}/|\\s\\|{2}\\s|\\s\\+\\s");// |('|%)S

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {

     
        boolean fang = true;
//        String uri = request.getRequestURI();
        //获取访问者的IP
        String remoteAddr = request.getRemoteAddr();
        remoteAddr = getIpAddsr(request);
    
            if (request.getMethod().equalsIgnoreCase("post")) {
                fang = checkPost(request);
            }

            if (request.getMethod().equalsIgnoreCase("get")) {
                fang = checkGet(request);
            }

            if (!fang) {
                code = "400";
                message = "ParameterIllegal !";
            }
        }
        if (!fang) {
            PrintWriter writer = response.getWriter();
            writer.write("{\n" +
                    "    \"code\": " + code + ",\n" +
                    "    \"msg\": " + message + ",\n" +
                    "    \"data\": null\n" +
                    "}");
            writer.flush();
            writer.close();
        }
        return fang;
    }


    private boolean checkGet(HttpServletRequest request) {
        return !request.getParameterMap().values().stream().anyMatch(values -> {
            for (String v : values) {
                if (sqlPattern.matcher(v).find()) {
                    LOG.error(v + "  Illegal ");
                    return true;
                }
            }
            return false;
        });
    }

    private boolean checkPost(HttpServletRequest request) {
        try {
            Map reqMap = new HashMap<String, Object>();
            String paramString = "";
            MyRequestWrapper requestWrapper = new MyRequestWrapper(request);
            String bodyString = requestWrapper.getBodyString();
            if (StringUtil.isNotEmpty(bodyString)) {
                Map postMap = JSONObject.parseObject(bodyString, Map.class);
                paramString = JSONObject.toJSONString(reqMap);
                reqMap.putAll(postMap);
            }

            boolean containKey = sqlPattern.matcher(paramString).find();
            if (containKey) {
                LOG.error(paramString + " Illegal ");
            }
            return !containKey;
        } catch (Exception ex) {
            return true;
        }
    }

    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {

    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {

    }
    }

第三个：包装请求、缓存请求体，使请求参数能继续往下传递（请求输入流只能读取一次）

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;

public class MyRequestWrapper extends HttpServletRequestWrapper {

    private static final Charset UTF_8 = Charset.forName("UTF-8");

    /** Raw request body bytes, read once from the wrapped request and replayed on every read. */
    private final byte[] body;

    /**
     * Constructs a request object wrapping the given request and buffers its entire body in memory.
     *
     * <p>The body is kept as raw bytes rather than decoded text, so line breaks and binary
     * payloads (e.g. multipart uploads) are passed downstream unchanged.
     *
     * @param request The request to wrap
     * @throws IllegalArgumentException if the request is null
     * @throws UncheckedIOException if the body cannot be read; failing here avoids handing an
     *     empty or truncated body to the parameter check and to the controller
     */
    public MyRequestWrapper(HttpServletRequest request) {
        super(request);
        try {
            body = readAllBytes(request.getInputStream());
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read request body", e);
        }
    }

    /**
     * Returns the buffered body decoded with the request's declared character encoding
     * (UTF-8 when none is declared).
     */
    public String getBodyString() {
        return new String(body, bodyCharset());
    }

    /** Resolves the charset used to decode the body, falling back to UTF-8. */
    private Charset bodyCharset() {
        String encoding = getCharacterEncoding();
        if (encoding == null) {
            return UTF_8;
        }
        try {
            return Charset.forName(encoding);
        } catch (IllegalArgumentException e) {
            // Unknown or malformed charset name declared by the client.
            return UTF_8;
        }
    }

    /** Reads {@code in} to the end and returns all bytes read. */
    private static byte[] readAllBytes(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buffer = new byte[4096];
        int len;
        while ((len = in.read(buffer)) != -1) {
            out.write(buffer, 0, len);
        }
        return out.toByteArray();
    }

    /**
     * Description: copies the remaining content of {@code inputStream} into an in-memory stream.
     *
     * <p>Best effort: if reading fails part-way, the error is printed and the bytes read so far
     * are returned.
     *
     * @param inputStream the stream to copy
     * @return a new stream over the copied bytes
     */
    public InputStream cloneInputStream(ServletInputStream inputStream) {
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        byte[] buffer = new byte[1024];
        int len;
        try {
            while ((len = inputStream.read(buffer)) > -1) {
                byteArrayOutputStream.write(buffer, 0, len);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return new ByteArrayInputStream(byteArrayOutputStream.toByteArray());
    }

    /** Returns a reader over the buffered body, decoded with the request's charset. */
    @Override
    public BufferedReader getReader() throws IOException {
        // The platform default charset would garble non-ASCII bodies.
        return new BufferedReader(new InputStreamReader(getInputStream(), bodyCharset()));
    }

    /** Returns a fresh stream over the buffered body; may be called any number of times. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream bais = new ByteArrayInputStream(body);

        return new ServletInputStream() {

            @Override
            public int read() {
                return bais.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                // Bulk read instead of the byte-at-a-time default.
                return bais.read(b, off, len);
            }

            @Override
            public boolean isFinished() {
                return bais.available() == 0;
            }

            @Override
            public boolean isReady() {
                // All data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking (async) reading is not supported by this buffered stream.
            }
        };
    }
}

第四个：在过滤器中处理请求（用包装类替换 POST 请求，以便请求体可重复读取）

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

public class RequestFilter implements Filter {

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // No initialization required.
    }

    /**
     * Adds CORS headers, answers preflight (OPTIONS) requests directly, and wraps POST requests in
     * {@link MyRequestWrapper} so their body can be read more than once.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
            throws IOException, ServletException {
        HttpServletResponse res = (HttpServletResponse) servletResponse;
        HttpServletRequest req = (HttpServletRequest) servletRequest;

        addCorsHeaders(req, res);
        if (req.getMethod().equals("OPTIONS")) {
            res.setStatus(HttpServletResponse.SC_OK);
            return;
        }

        String method = req.getMethod();
        if ("post".equalsIgnoreCase(method)) {
            // A servlet input stream can only be read once; buffer POST bodies so both the
            // parameter-checking interceptor and the controller can read them.
            MyRequestWrapper requestWrapper = new MyRequestWrapper(req);
            filterChain.doFilter(requestWrapper, servletResponse);
        } else {
            filterChain.doFilter(servletRequest, servletResponse);
        }
    }

    /** Writes the CORS response headers for {@code req} onto {@code res}. */
    private static void addCorsHeaders(HttpServletRequest req, HttpServletResponse res) {
        String origin = req.getHeader("Origin");
        if (origin != null && !origin.isEmpty()) {
            // NOTE(review): every Origin is reflected together with Allow-Credentials=true, which
            // lets any site make credentialed cross-origin calls; consider an explicit allowlist.
            res.addHeader("Access-Control-Allow-Origin", origin);
        }
        // The headers below depend on these request headers, so caches must key on them.
        res.addHeader("Vary", "Origin, Access-Control-Request-Method, Access-Control-Request-Headers");
        // With credentials, browsers treat "*" literally (not as a wildcard), so echo what the
        // preflight asked for and only fall back to "*" when nothing was requested.
        res.addHeader("Access-Control-Allow-Methods",
                valueOrWildcard(req.getHeader("Access-Control-Request-Method")));
        res.addHeader("Access-Control-Allow-Credentials", "true");
        res.addHeader("Access-Control-Allow-Headers",
                valueOrWildcard(req.getHeader("Access-Control-Request-Headers")));
    }

    /** Returns {@code value}, or {@code "*"} when it is null or empty. */
    private static String valueOrWildcard(String value) {
        return (value == null || value.isEmpty()) ? "*" : value;
    }

    /**
     * Returns whether the request URI contains one of the listed URI fragments.
     * Currently unused; the array holds a placeholder to be replaced with real endpoints.
     *
     * @param curUri the request URI
     * @return {@code true} if any listed fragment occurs in {@code curUri}
     */
    private boolean containRegisterUri(String curUri) {
        String[] arr = new String[]{
                "这里是需要判断的接口"
        };
        List<String> urls = Arrays.asList(arr);
        for (String url : urls) {
            if (curUri.contains(url)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public void destroy() {
        // Nothing to release.
    }

}

第五个：在启动类中注册过滤器

@Bean
public FilterRegistrationBean httpServletRequestReplacedRegistration() {
    // Registers RequestFilter for all URLs under the name "efmRequestFilter" with order 1
    // (lower order values run earlier in the filter chain).
    FilterRegistrationBean bean = new FilterRegistrationBean(new RequestFilter());
    bean.addUrlPatterns("/*");
    bean.addInitParameter("paramName", "paramValue");
    bean.setName("efmRequestFilter");
    bean.setOrder(1);
    return bean;
}

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值