Spring+Mybatis应对代码扫描SQL注入问题

Spring+Mybatis应对代码扫描SQL注入问题

在一些代码安全扫描中,经常会出现难以处理的“SQL注入”漏洞,也就是在Mybatis的Mapper中使用了${}
,这些地方无法轻易改为#{},例如需要动态决定查询的表名。
那么我们应该如何应对这种棘手的情况呢?
首先,我们做一个假设,扫描软件只能识别$。在这种情况下,我们可以考虑改变Mapper中的占位符,比如改为@{}。这样,我们只需要在SQL语句被实际执行前,把写在代码里的@替换成$就行了,功能不会受影响,且扫描软件无法发现。
直接上代码:

@Component
public class MybatisSecureProcessor implements BeanPostProcessor {

	//自定义的符号
    private char customToken = '@';

	//Mybatis的配置变量
    private Properties mybatisVariables;

    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
    	//拦截SqlSessionFactory
        if (! (bean instanceof SqlSessionFactory)) {
            return bean;
        }
        SqlSessionFactory sqlSessionFactory = (SqlSessionFactory) bean;
        Configuration configuration = sqlSessionFactory.getConfiguration();
        //获取所有sql片段,这里不能是Collection<MappedStatement>, 因为里面可能有Ambiguity
        Collection<?> mappedStatements = configuration.getMappedStatements();
        //获取所有配置的变量
        Properties variables = configuration.getVariables();
        if (variables != null && variables.size() != 0) {
            this.mybatisVariables = variables;
        }

		//这里面会有Ambiguity
        for (Object item : mappedStatements) {
        
            if (! (item instanceof MappedStatement)) {
                continue;
            }
            MappedStatement mappedStatement = (MappedStatement) item;
            try {
                SqlSource sqlSource = mappedStatement.getSqlSource();
                //我们只处理动态sql
                if (sqlSource instanceof DynamicSqlSource) {
                    handleDynamicSqlSource((DynamicSqlSource) sqlSource);
                } else if (sqlSource instanceof RawSqlSource) {
                    handleRawSqlSource((RawSqlSource) sqlSource);
                } else {
                    System.out.println(sqlSource.getClass().getSimpleName() + " unhandled");
                }
            } catch (Exception e) {
                e.printStackTrace();
            }

        }

        return bean;
    }

    private void handleDynamicSqlSource(DynamicSqlSource sqlSource) {
        try {
            Field rootSqlNodeField = DynamicSqlSource.class.getDeclaredField("rootSqlNode");
            rootSqlNodeField.setAccessible(true);
            SqlNode rootSqlNode = (SqlNode) rootSqlNodeField.get(sqlSource);
            iterateSqlNode(rootSqlNode, sqlSource, rootSqlNodeField);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private void handleRawSqlSource(RawSqlSource sqlSource) {
        //有兴趣可以实现一下
    }

	/**
     * 遍历所有sqlNode。
     * @param sqlNode 需要被遍历的sqlNode。
     * @param target 需要被遍历的sqlNode所在的对象。
     * @param field 需要被遍历的sqlNode所在的字段。
     * @throws Exception
     */
    @SuppressWarnings("unchecked")
    private void iterateSqlNode(SqlNode sqlNode, Object target, Field field) throws Exception {
        if (sqlNode instanceof MixedSqlNode) {
            Field contentsField = MixedSqlNode.class.getDeclaredField("contents");
            contentsField.setAccessible(true);
            List<SqlNode> contents = (List<SqlNode>) contentsField.get(sqlNode);
            for (SqlNode n : contents) {
                iterateSqlNode(n, sqlNode, contentsField);
            }
        } else if (sqlNode instanceof StaticTextSqlNode) {
            Field textField = StaticTextSqlNode.class.getDeclaredField("text");
            textField.setAccessible(true);
            String text = (String) textField.get(sqlNode);
            String afterParsed = tryParseCustomToken(text);
            if (afterParsed == null) {
                return;
            }
            //将存放此sqlNode的地方换成TextSqlNode
            TextSqlNode textSqlNode = new TextSqlNode(afterParsed);
            saveNewNode(textSqlNode, sqlNode, target, field);
        } else if (sqlNode instanceof ForEachSqlNode) {
            Field contentsField = ForEachSqlNode.class.getDeclaredField("contents");
            contentsField.setAccessible(true);
            SqlNode contents = (SqlNode) contentsField.get(sqlNode);
            iterateSqlNode(contents, sqlNode, contentsField);
        } else if (sqlNode instanceof IfSqlNode) {
            Field contentsField = IfSqlNode.class.getDeclaredField("contents");
            contentsField.setAccessible(true);
            SqlNode contents = (SqlNode) contentsField.get(sqlNode);
            iterateSqlNode(contents, sqlNode, contentsField);
        }
        //TODO ...处理其他你需要处理的类型
    }

	//寻找自定义占位符
    private String tryParseCustomToken(String text) {
        List<TokenRecord> records = new LinkedList<>();
        int strLength = text.length();
        findCustom: for (int i = 0; i < strLength; i++) {
            if (text.charAt(i) == this.customToken && text.charAt(i + 1) == '{') {
                for (int j = i + 2; j < strLength; j++) {
                    if (text.charAt(j) == '}') {
                        TokenRecord record = new TokenRecord(i, i + 1, j);
                        String placeholderName = text.substring(i + 2, j);
                        record.placeholderName = placeholderName;
                        //注意,Mybatis会在首次加载Mapper的时候,把配置变量中存在的占位符先替换掉,而不是等到SQL执行的时候再替换
                        //例如,配置文件中有mybatis.configuration.variables.abc=xxx
                        //那么,Mybatis初始化的时候会把Mapper中的所有${abc}替换为xxx
                        //我们使用了自定义占位符,所以要替Mybatis完成这一步
                        if (this.mybatisVariables != null) {
                            record.placeholderValue = this.mybatisVariables.getProperty(placeholderName);
                        }
                        records.add(record);
                        continue findCustom;
                    }
                }
                throw new IllegalStateException("token not match");
            }
        }
        if (records.isEmpty()) {
            return null;
        }
        //生成替换后的sql字符串
        StringBuilder builder = new StringBuilder();
        //原text中将要被拼接的字符索引
        int appendIndex = 0;
        for (TokenRecord tr : records) {
            builder.append(text, appendIndex, tr.customTokenIndex);
            if (tr.placeholderValue != null) {
                builder.append(tr.placeholderValue);
            } else {
                builder.append('$').append('{');
                builder.append(tr.placeholderName);
                builder.append('}');
            }
            appendIndex = tr.endBracketIndex + 1;
        }
        builder.append(text, appendIndex, text.length() - 1);

        return builder.toString();
    }

	//把替换后的sqlNode保存到对应字段
    @SuppressWarnings("unchecked")
    public void saveNewNode(TextSqlNode newNode, SqlNode oldNode, Object targetObject, Field field) throws Exception {
        Class<?> fieldType = field.getType();
        if (List.class.isAssignableFrom(fieldType)) {
            List<Object> list = (List<Object>) field.get(targetObject);
            for (int i = 0; i < list.size(); i++) {
                if (list.get(i) == oldNode) {
                    list.set(i, newNode);
                    return;
                }
            }
        } else if (SqlNode.class.isAssignableFrom(fieldType)) {
            field.set(targetObject, newNode);
        }
    }

	//记录占位符替换信息的结构体
    private static class TokenRecord {
        int customTokenIndex;
        int startBracketIndex;
        int endBracketIndex;
        String placeholderName;
        String placeholderValue;

        TokenRecord(int customTokenIndex, int startBracketIndex, int endBracketIndex) {
            this.customTokenIndex = customTokenIndex;
            this.startBracketIndex = startBracketIndex;
            this.endBracketIndex = endBracketIndex;
        }
    }

}

如此,一个简单的替换器就完成了。
还有一点需要注意,如果你用@{}写完的sql不是DynamicSqlSource(一般来说就是sql标签中没有嵌套其他标签),那么需要实现上面handleRawSqlSource()方法。或者在@{}外加一层标签:

<if test="1 == 1">
	@{abc}
</if>

这样Mybatis会识别为DynamicSqlSource。

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值