springboot 防xss攻击的三种方式(SQL注入适用)

本文介绍了防止XSS攻击的三种方法:1) 使用拦截器自定义HttpServletRequestWrapper,过滤请求参数;2) 自定义消息转换器,修改JSON反序列化过程;3) 利用注解和反射,对方法参数进行转义。每种方法都有其优缺点,拦截器全局拦截但可能影响性能,消息转换器仅适用于JSON数据,注解+反射无侵入但效率较低。
摘要由CSDN通过智能技术生成


XSS攻击和SQL注入,原理就不多说了,主要记录一下3种方式来避免

拦截器

主要是继承HttpServletRequestWrapper,重写里面获取参数和请求体的方法,再实现一个Filter达到拦截效果。转码其实可以直接用HtmlUtils.htmlEscape,但它会把请求体JSON里的双引号转义成&quot;,导致json格式不对,所以只能自己写转义方法(只转义<和>)。
也可以不重写getInputStream,改为直接对json化工具下手:注入自定义的objectMapper,在里面再设置String的反序列化方法。不过这样所有的反序列化都会受到影响,不大建议。

package cn.com.tcc.ofa.common.filters;

import lombok.extern.slf4j.Slf4j;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Request wrapper that HTML-escapes {@code <} and {@code >} in request parameters and in the
 * request body, so values read through this wrapper cannot carry raw HTML tags.
 *
 * <p>Create one instance per request. This class must not be a Spring bean: the servlet API
 * rejects wrapping a {@code null} request, so an instance built through the no-arg constructor
 * can never be created.
 *
 * <p>The body is decoded and re-encoded as UTF-8 (consistently in both directions).
 *
 * @author: moshiyuan
 * @date: 2021/9/9 13:44
 */
@Slf4j
public class XssHttpServletRequestWraper extends HttpServletRequestWrapper {

    /** Escaped request body, read on first access and cached so the body can be read again. */
    private byte[] escapedBody;

    /**
     * Kept for source compatibility only.
     *
     * @deprecated always fails with {@link IllegalArgumentException}, because the servlet API does
     *     not allow wrapping a {@code null} request; use
     *     {@link #XssHttpServletRequestWraper(HttpServletRequest)}.
     */
    @Deprecated
    public XssHttpServletRequestWraper() {
        super(null);
    }

    /**
     * Wraps the given request.
     *
     * @param httpservletrequest the request to wrap; must not be {@code null}
     */
    public XssHttpServletRequestWraper(HttpServletRequest httpservletrequest) {
        super(httpservletrequest);
    }

    /**
     * Escapes every value of a request parameter (used by Spring MVC's {@code @RequestParam}).
     *
     * @param s parameter name
     * @return the escaped values, or {@code null} if the parameter is absent
     */
    @Override
    public String[] getParameterValues(String s) {
        return cleanAll(super.getParameterValues(s));
    }

    /**
     * Escapes the first value of a request parameter ({@code request.getParameter}).
     *
     * @param s parameter name
     * @return the escaped value, or {@code null} if the parameter is absent
     */
    @Override
    public String getParameter(String s) {
        return cleanXSS(super.getParameter(s));
    }

    /**
     * Escapes every value in the parameter map, so this accessor cannot bypass the filtering done
     * by {@link #getParameter(String)} and {@link #getParameterValues(String)}.
     *
     * @return an unmodifiable map of parameter names to escaped values
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, String[]> escaped = new LinkedHashMap<>();
        for (Map.Entry<String, String[]> entry : super.getParameterMap().entrySet()) {
            escaped.put(entry.getKey(), cleanAll(entry.getValue()));
        }
        return Collections.unmodifiableMap(escaped);
    }

    /**
     * Returns the escaped request body (e.g. a JSON payload) as a stream.
     *
     * <p>The underlying body is read once and cached, so this method (and {@link #getReader()})
     * may be called repeatedly.
     *
     * @return an in-memory stream over the escaped, UTF-8 encoded body
     * @throws IOException if the underlying body cannot be read
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (escapedBody == null) {
            escapedBody = cleanXSS(readBody(super.getInputStream())).getBytes(StandardCharsets.UTF_8);
        }
        final ByteArrayInputStream body = new ByteArrayInputStream(escapedBody);
        return new ServletInputStream() {

            @Override
            public int read() {
                return body.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return body.read(b, off, len);
            }

            @Override
            public boolean isFinished() {
                return body.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking (async) reads are not supported by this in-memory stream.
            }
        };
    }

    /**
     * Returns a reader over the escaped body. It is overridden so that reading the body through
     * the reader cannot bypass the escaping applied in {@link #getInputStream()}.
     *
     * @return a UTF-8 reader over the escaped body
     * @throws IOException if the underlying body cannot be read
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), StandardCharsets.UTF_8));
    }

    /**
     * Reads the whole stream as UTF-8, closes it and returns the escaped text.
     *
     * @param servletInputStream stream to consume; it is closed afterwards
     * @return the escaped content
     * @throws UncheckedIOException if the stream cannot be read (a partially read body is never
     *     returned)
     */
    public String inputHandlers(ServletInputStream servletInputStream) {
        try {
            return cleanXSS(readBody(servletInputStream));
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read request body", e);
        }
    }

    /**
     * Replaces {@code <} with {@code &lt;} and {@code >} with {@code &gt;}.
     *
     * <p>Other characters (notably {@code "}) are left alone so that JSON bodies stay valid.
     *
     * @param src text to escape, may be {@code null}
     * @return the escaped text, or {@code null} if {@code src} is {@code null}
     */
    public String cleanXSS(String src) {
        if (src == null) {
            return null;
        }
        // Literal replacement; replaceAll would needlessly compile the arguments as regexes.
        return src.replace("<", "&lt;").replace(">", "&gt;");
    }

    /** Escapes each element of {@code values}; returns {@code null} for a {@code null} array. */
    private String[] cleanAll(String[] values) {
        if (values == null) {
            return null;
        }
        String[] cleaned = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            cleaned[i] = cleanXSS(values[i]);
        }
        return cleaned;
    }

    /** Reads {@code in} fully as UTF-8, preserving line terminators, and closes it. */
    private static String readBody(InputStream in) throws IOException {
        try (InputStream source = in) {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            byte[] chunk = new byte[4096];
            int read;
            while ((read = source.read(chunk)) != -1) {
                buffer.write(chunk, 0, read);
            }
            return new String(buffer.toByteArray(), StandardCharsets.UTF_8);
        }
    }
}
package cn.com.tcc.ofa.common.filters;

import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;

/**
 * Servlet filter that wraps each HTTP request in {@link XssHttpServletRequestWraper}, so request
 * parameters and the request body are XSS-escaped before they reach the application.
 *
 * @author: moshiyuan
 * @date: 2021/9/9 13:50
 */
@Slf4j
public class XssFilter implements Filter {
    /**
     * Request-URI substrings that bypass XSS escaping; empty by default.
     *
     * <p>Matching is by substring, so entries must be specific. Never use placeholder values such
     * as {@code "null"}: every URI containing that text (e.g. {@code /api/nullable}) would skip
     * escaping.
     */
    private final String[] excludeUrls = new String[0];

    /**
     * Passes non-HTTP and excluded requests through unchanged, and wraps every other request.
     *
     * @param arg0 incoming request
     * @param arg1 outgoing response
     * @param arg2 remaining filter chain
     */
    @Override
    public void doFilter(ServletRequest arg0, ServletResponse arg1, FilterChain arg2)
            throws IOException, ServletException {
        if (!(arg0 instanceof HttpServletRequest)) {
            arg2.doFilter(arg0, arg1);
            return;
        }
        HttpServletRequest req = (HttpServletRequest) arg0;
        if (isExcluded(req.getRequestURI())) {
            arg2.doFilter(arg0, arg1);
            return;
        }
        arg2.doFilter(new XssHttpServletRequestWraper(req), arg1);
    }

    /** Returns true when {@code uri} contains any of the configured exclusion substrings. */
    private boolean isExcluded(String uri) {
        for (String excluded : excludeUrls) {
            if (uri.contains(excluded)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public void destroy() {
    }

    @Override
    public void init(FilterConfig filterconfig) {
    }
}

自定义消息转码器

主要是继承AbstractJackson2HttpMessageConverter,自定义MediaType,再增加一个String反序列化方法。缺点是前端也要改一改:把请求头的Content-Type改成自定义的类型,然后Controller指定consumes,防止前端绕过你指定的MediaType。不过这样也有问题:前端如果用了其他的Content-Type,请求会报错,而且好像还不会走全局的异常捕捉,可能需要特殊考虑一下。

package cn.com.tcc.ofa.common.filters;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import org.springframework.http.HttpInputMessage;
import org.springframework.http.HttpOutputMessage;
import org.springframework.http.MediaType;
import org.springframework.http.converter.AbstractHttpMessageConverter;
import org.springframework.http.converter.HttpMessageNotReadableException;
import org.springframework.http.converter.HttpMessageNotWritableException;
import org.springframework.http.converter.json.AbstractJackson2HttpMessageConverter;
import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder;
import org.springframework.lang.Nullable;

import java.io.IOException;
import java.io.Serializable;
import java.nio.charset.Charset;

/**
 * Jackson message converter for the custom {@code application/moshiyuan} media type whose
 * mapper uses {@code XssJacksonDeserializer} (defined elsewhere; presumably HTML-escapes text)
 * for every {@code String} value.
 *
 * <p>The deserializer is registered on a private copy of the supplied {@link ObjectMapper}, so
 * the caller's (often application-wide) mapper is left untouched. Only bodies handled by this
 * converter are affected.
 *
 * @author: moshiyuan
 * @date: 2021/9/9 14:59
 */
public class XssJsonTestConverter extends AbstractJackson2HttpMessageConverter {
    /** Raw text written before every JSON output, or {@code null} for none. */
    @Nullable
    private String jsonPrefix;

    /** Creates a converter backed by a default mapper from {@link Jackson2ObjectMapperBuilder}. */
    public XssJsonTestConverter() {
        this(Jackson2ObjectMapperBuilder.json().build());
    }

    /**
     * Creates a converter whose mapper is a copy of {@code objectMapper} with the custom
     * {@code String} deserializer registered.
     *
     * @param objectMapper template mapper; it is copied, not modified
     */
    public XssJsonTestConverter(ObjectMapper objectMapper) {
        super(withXssModule(objectMapper), new MediaType("application", "moshiyuan"));
    }

    /**
     * Returns a copy of {@code template} with the XSS string deserializer registered. Copying is
     * what keeps a shared mapper from having its String deserialization changed globally.
     */
    private static ObjectMapper withXssModule(ObjectMapper template) {
        ObjectMapper mapper = template.copy();
        SimpleModule xssModule = new SimpleModule("XssStringJsonSerializer");
        xssModule.addDeserializer(String.class, new XssJacksonDeserializer());
        mapper.registerModule(xssModule);
        return mapper;
    }

    /**
     * Sets a custom prefix to write before JSON output.
     *
     * @param jsonPrefix prefix text, or {@code null} for none
     */
    public void setJsonPrefix(String jsonPrefix) {
        this.jsonPrefix = jsonPrefix;
    }

    /**
     * Enables or disables the {@code ")]}', "} anti-JSON-hijacking prefix.
     *
     * @param prefixJson whether to write the prefix
     */
    public void setPrefixJson(boolean prefixJson) {
        this.jsonPrefix = prefixJson ? ")]}', " : null;
    }

    /** Writes the configured prefix, if any, before the JSON value. */
    @Override
    protected void writePrefix(JsonGenerator generator, Object object) throws IOException {
        if (this.jsonPrefix != null) {
            generator.writeRaw(this.jsonPrefix);
        }

    }
}

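注册自定义的消息转换器。这里有一个坑:如果项目里面已经继承了WebMvcConfigurationSupport,你再通过WebMvcConfigurer去注册消息转换器,是注册不上的,这里卡了好长时间,一直想不明白为什么注册不上。原因是:直接继承WebMvcConfigurationSupport后,Spring Boot的WebMvcAutoConfiguration不再生效,而WebMvcConfigurationSupport本身不会收集WebMvcConfigurer的回调(那是DelegatingWebMvcConfiguration做的事),所以只能在子类里重写extendMessageConverters来注册。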

package cn.com.tcc.ofa.common.config;

import cn.com.tcc.ofa.common.filters.XssJsonTestConverter;
import cn.com.tcc.ofa.common.interceptor.BaseHandlerInterceptor;
import cn.com.tcc.ofa.common.utils.AssertUtil;
import cn.com.tcc.ofa.common.utils.SpringContextUtil;
import com.alibaba.cloud.seata.web.SeataHandlerInterceptor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.databind.ser.std.ToStringSerializer;
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import org.springframework.context.annotation.Bean;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.config.annotation.InterceptorRegistration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurationSupport;

import java.text.SimpleDateFormat;
import java.util.List;
import java.util.Map;
import java.util.TimeZone;

/**
 * @author liuzhibin
 * Date: 2019/11/1
 */
@Component
public class WebConfigurer extends WebMvcConfigurationSupport {
    @Override
    public void extendMessageConverters(List<HttpMessageConverter<?>> converters) {
        converters.add(xssJsonTestConverter());
    }

    @Bean
    public XssJsonTestConverter xssJsonTestConverter() {
        return new XssJsonTestConverter();
    }
}

controller测试

package cn.com.tcc.ofa.admin.controller;

import cn.com.tcc.ofa.admin.enums.UserStatusType;
import cn.com.tcc.ofa.admin.model.dto.*;
import cn.com.tcc.ofa.admin.model.po.OfaOrgUserInfo;
import cn.com.tcc.ofa.admin.model.po.OfaOrganization;
import cn.com.tcc.ofa.admin.model.po.OfaUser;
import cn.com.tcc.ofa.admin.service.IOfaOrgUserInfoService;
import cn.com.tcc.ofa.admin.service.IOfaOrganizationService;
import cn.com.tcc.ofa.admin.service.IOfaUserService;
import cn.com.tcc.ofa.admin.utils.LoginSource;
import cn.com.tcc.ofa.admin.utils.MobilePhoneValidCodeUtil;
import cn.com.tcc.ofa.common.annotaions.XssEscape;
import cn.com.tcc.ofa.common.controller.BaseController;
import cn.com.tcc.ofa.common.exception.RestException;
import cn.com.tcc.ofa.common.model.vo.CaptchaCheckResult;
import cn.com.tcc.ofa.common.model.vo.RestData;
import cn.com.tcc.ofa.common.utils.*;
import com.alibaba.fastjson.JSON;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.google.common.collect.Maps;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import io.vertx.core.Vertx;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.validation.BindingResult;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;

import javax.annotation.Resource;
import javax.validation.Valid;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;

/**
 * Demo endpoints for trying out the XSS-protection approaches.
 *
 * <p>{@code test} and {@code test1} are annotated with {@link XssEscape}. {@code test2} only
 * accepts the {@code application/moshiyuan} content type, which is the media type of the
 * custom XSS message converter.
 *
 * @author liuzhibin
 * Date: 2019/10/25
 */
@Api(tags = "登录控制器")
@RestController
@Slf4j
public class TestController extends BaseController {

    /**
     * Logs the (escaped) body and returns {@code addRestData()}, or the validation errors if any.
     */
    @PostMapping("/open-api/test/moshiyan")
    @ApiOperation(value = "测试", notes = "接口返回code非200时,需要重新获取验证码,此方法只适用旧账号登录")
    @XssEscape
    public RestData test(@RequestBody @Valid TestDto dto, BindingResult bindingResult) {
        if (!bindingResult.hasErrors()) {
            log.info(JSON.toJSONString(dto));
            return addRestData();
        }
        return formErrorValid(bindingResult);
    }

    /**
     * Logs the (escaped) list body and returns a fixed, already-escaped sample string.
     *
     * <p>NOTE(review): {@code @Valid} on a {@code List} parameter probably does not validate the
     * individual elements in this setup — confirm.
     */
    @PostMapping("/open-api/test/moshiyan1")
    @ApiOperation(value = "测试", notes = "接口返回code非200时,需要重新获取验证码,此方法只适用旧账号登录")
    @XssEscape
    public RestData test1(@RequestBody @Valid List<TestDto> dto, BindingResult bindingResult) {
        if (!bindingResult.hasErrors()) {
            log.info(JSON.toJSONString(dto));
            String escapedSample = "&lt;alert&gt;moshiyuan&lt;/alert&gt;";
            return addRestData(escapedSample);
        }
        return formErrorValid(bindingResult);
    }

    /** Same as {@link #test}, but only reachable with Content-Type {@code application/moshiyuan}. */
    @PostMapping(value = "/open-api/test/moshiyan2", consumes = "application/moshiyuan")
    @ApiOperation(value = "测试", notes = "接口返回code非200时,需要重新获取验证码,此方法只适用旧账号登录")
    public RestData test2(@RequestBody @Valid TestDto dto, BindingResult bindingResult) {
        if (!bindingResult.hasErrors()) {
            log.info(JSON.toJSONString(dto));
            return addRestData();
        }
        return formErrorValid(bindingResult);
    }
}

注解+反射

主要是注解,通过切面,获取方法的出入参,通过反射,对里面的String字段,进行转码,最麻烦的是出入参的不固定,相比通过消息转换器,改改反序列化String的方法就好了,这里就要考虑出入参的结构问题,最后是用递归调用,来实现的

package cn.com.tcc.ofa.common.annotaions;

import java.lang.annotation.*;

/**
 * Marks a method whose arguments (and optionally return value) are XSS-escaped by an aspect.
 *
 * @author: moshiyuan
 * @date: 2021/9/7 15:09
 */
@Documented
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface XssEscape {

    /** Whether the method's input arguments are HTML-escaped; enabled unless turned off. */
    boolean escape() default true;

    /** Whether the method's return value is HTML-unescaped; disabled unless turned on. */
    boolean unescape() default false;
}

package cn.com.tcc.ofa.common.aop;

import cn.com.tcc.ofa.common.annotaions.XssEscape;
import com.alibaba.druid.sql.visitor.functions.Char;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;
import org.springframework.validation.BeanPropertyBindingResult;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.util.HtmlUtils;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.*;

/**
 * @author: moshiyuan
 * @date: 2021/9/7 15:09
 * @description: Xss转义
 */
@Component
@Aspect
@Slf4j
public class XssEscapeAspect {
    @Pointcut("@annotation(cn.com.tcc.ofa.common.annotaions.XssEscape)")
    public void controllerPointcut() {}

    @Around("controllerPointcut()")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
        Method m = methodSignature.getMethod();
        XssEscape xssEscape = m.getAnnotation(XssEscape.class);
        if(xssEscape == null){
            return joinPoint.proceed();
        }
        Object[] args = joinPoint.getArgs();
        if(xssEscape.escape()){
            if( args!= null){
                for(int i=0;i< args.length;i++){
                    Object arg=args[i];
                    if(!(arg instanceof HttpServletRequest) && !(arg instanceof MultipartFile) && !(arg instanceof BeanPropertyBindingResult)) {
                        args[i]=escapeByReflect(arg,true);
                    }
                }
            }
        }
        Object retVal = joinPoint.proceed(args);
        if(xssEscape.unescape()){
            return escapeByReflect(retVal,false);
        }else{
            return retVal;
        }
    }

    /**
     * XSS转义,只对String类型转义
     * @param obj 对象
     * @param escape true转义 false转回去
     */
    private Object escapeByReflect(Object obj,boolean escape){
        Object rtnObj=null;
        try {
            if (obj ==null) {
                rtnObj = null;
            } else if (obj instanceof String) {
                rtnObj = escape((String) obj,escape);
            }else if (isIgnored(obj.getClass())) {
                rtnObj = obj;
            } else if (obj instanceof List) {
                List list=(List) obj;
                if(!list.isEmpty()){
                    for(int i=0;i< list.size();i++){
                        list.set(i,escapeByReflect(list.get(i),escape));
                    }
                }
                rtnObj=list;
            } else if (obj instanceof Set) {
                Set set=(Set) obj;
                if(!set.isEmpty()){
                    Set rtn=new HashSet(set.size());
                    set.forEach(key->rtn.add(escapeByReflect(key,escape)));
                    rtnObj=rtn;
                }else{
                    rtnObj=Collections.EMPTY_SET;
                }
            } else if (obj instanceof Map) {
                Map map=(Map) obj;
                if(!map.isEmpty()){
                    Map rtn=new HashMap(map.size());
                    map.forEach((key, value) -> rtn.put(key,escapeByReflect(value,escape)));
                    rtnObj=rtn;
                }else{
                    rtnObj=Collections.EMPTY_MAP;
                }
            } else {
                Field[] fields = getAllFields(obj);
                for(int i=0; i<fields.length; i++){
                    Field f = fields[i];
                    f.setAccessible(true);
                    if(String.class.isAssignableFrom(f.getType())){
                        String str= (String) f.get(obj);
                        String escapeStr=escape(str,escape);
                        f.set(obj,escapeStr);
                    }else if (isIgnored(f.getType())) {
                        //忽略的属性,不用处理
                    }else{
                        f.set(obj,escapeByReflect(f.get(obj),escape));
                    }
                }
                rtnObj=obj;
            }
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        }
        return rtnObj;
    }

    /**
     * 获取所有的属性,包括父类
     * @param object 对象
     * @return 所有的属性
     */
    private Field[] getAllFields(Object object) {
        Class clazz = object.getClass();
        List<Field> fieldList = new ArrayList<>();
        while (clazz != Object.class) {
            fieldList.addAll(new ArrayList<>(Arrays.asList(clazz.getDeclaredFields())));
            clazz = clazz.getSuperclass();
        }
        Field[] fields = new Field[fieldList.size()];
        fieldList.toArray(fields);
        return fields;
    }

    private boolean isIgnored(Class<?> type) {
        return (type.isPrimitive()
                || Byte.class.isAssignableFrom(type)
                || Short.class.isAssignableFrom(type)
                || Integer.class.isAssignableFrom(type)
                || Long.class.isAssignableFrom(type)
                || Float.class.isAssignableFrom(type)
                || Double.class.isAssignableFrom(type)
                || Boolean.class.isAssignableFrom(type)
                || Char.class.isAssignableFrom(type)
                || Date.class.isAssignableFrom(type)
                || LocalDate.class.isAssignableFrom(type)
                || LocalTime.class.isAssignableFrom(type)
                || LocalDateTime.class.isAssignableFrom(type)
                || Enum.class.isAssignableFrom(type));
    }

    /**
     * 转义
     * @param str 转码字符串
     * @param escape true转义 false转回去
     * @return 转义
     */
    private String escape(String str,boolean escape){
        if(escape){
            return HtmlUtils.htmlEscape(str);
        }else{
            return HtmlUtils.htmlUnescape(str);
        }
    }
}

比较

第一种呢比较常规,很多博客都是写的第一种,然后加上改一下fastjson对String的反序列化方法,没啥优点,适合全局拦截转码。
第二种呢,比较灵活,但是只能针对json过来的数据,如果要对get请求,或者直接对request取值的,搞不了,还要结合第一种方式,那样又比较麻烦了,不过需求针对特定接口的话,第二种不错,不过需要前后端都要改一下
第三种,针对指定接口会比较好,不需要改造,只需要在方法上面加一个注解就可以,无侵入,缺点就是反射影响性能。
其实还有第四种:同样用注解,然后新建一个json工具类,里面对objectMapper增加一个String的反序列化方法,就是第3种里面写的那样,然后直接序列化再反序列化,搞定,不用考虑太多。不过一样影响性能,具体反射和序列化哪个更耗时一点,没测,后面有时间可以测测看。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值