[JAVA] mapper层sql校验

初始方案

  • 暂不校验 <script> 包裹的动态 SQL
    - sql
    Mapper 层注解 SQL 校验:注解中的 SQL 写错时,通常要等到实际执行该 Mapper 方法才会报错;本方案在项目启动前就对 SQL 做语法校验,提前暴露问题。
package ix.account.util;

import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.ibatis.annotations.*;
import org.springframework.beans.factory.BeanDefinitionStoreException;
import org.springframework.context.ResourceLoaderAware;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternUtils;
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
import org.springframework.core.type.classreading.MetadataReader;
import org.springframework.core.type.classreading.MetadataReaderFactory;
import org.springframework.core.type.filter.AnnotationTypeFilter;
import org.springframework.core.type.filter.TypeFilter;
import org.springframework.util.StringUtils;
import org.springframework.util.SystemPropertyUtils;

import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.*;
import java.util.stream.Collectors;

/**
 * spring scaner
 */
@Slf4j
public class ClassScaner implements ResourceLoaderAware {

    private final List<TypeFilter> includeFilters = new LinkedList<>();
    private final List<TypeFilter> excludeFilters = new LinkedList<>();

    private ResourcePatternResolver resourcePatternResolver = new PathMatchingResourcePatternResolver();
    private MetadataReaderFactory metadataReaderFactory = new CachingMetadataReaderFactory(this.resourcePatternResolver);

    public static Set<Class> scan(String[] basePackages,
                                  Class<? extends Annotation>... annotations) {
        ClassScaner classScaner = new ClassScaner();

        if (ArrayUtils.isNotEmpty(annotations)) {
            for (Class annotation : annotations) {
                classScaner.addIncludeFilter(new AnnotationTypeFilter(annotation));
            }
        }

        Set<Class> classes = new HashSet<>();
        for (String s : basePackages) {
            classes.addAll(classScaner.doScan(s));
        }
        return classes;
    }

    /**
     * spring 指定包扫描
     *
     * @param basePackages 扫描包基本路径
     * @param annotations  具体扫描什么注解 例如{@link Mapper}
     * @return
     */
    public static Set<Class> scan(String basePackages, Class<? extends Annotation>... annotations) {
        return ClassScaner.scan(StringUtils.tokenizeToStringArray(basePackages, ",; \t\n"), annotations);
    }

    public final ResourceLoader getResourceLoader() {
        return this.resourcePatternResolver;
    }

    @Override
    public void setResourceLoader(ResourceLoader resourceLoader) {
        this.resourcePatternResolver = ResourcePatternUtils
                .getResourcePatternResolver(resourceLoader);
        this.metadataReaderFactory = new CachingMetadataReaderFactory(
                resourceLoader);
    }

    public void addIncludeFilter(TypeFilter includeFilter) {
        this.includeFilters.add(includeFilter);
    }

    public void addExcludeFilter(TypeFilter excludeFilter) {
        this.excludeFilters.add(0, excludeFilter);
    }

    public void resetFilters(boolean defaultFilters) {
        this.includeFilters.clear();
        this.excludeFilters.clear();
    }

    public Set<Class> doScan(String basePackage) {
        Set<Class> classes = new HashSet<>();
        try {
            String packageSearchPath = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX
                    + org.springframework.util.ClassUtils
                    .convertClassNameToResourcePath(SystemPropertyUtils
                            .resolvePlaceholders(basePackage))
                    + "/**/*.class";
            Resource[] resources = this.resourcePatternResolver
                    .getResources(packageSearchPath);

            for (int i = 0; i < resources.length; i++) {
                Resource resource = resources[i];
                if (resource.isReadable()) {
                    MetadataReader metadataReader = this.metadataReaderFactory.getMetadataReader(resource);
                    boolean b = (includeFilters.size() == 0 && excludeFilters.size() == 0)
                            || matches(metadataReader);


                    if (b) {
                        try {
                            classes.add(Class.forName(metadataReader
                                    .getClassMetadata().getClassName()));
                        } catch (ClassNotFoundException e) {
                            e.printStackTrace();
                        }
                    }
                }
            }
        } catch (IOException ex) {
            throw new BeanDefinitionStoreException(
                    "I/O failure during classpath scanning", ex);
        }
        return classes;
    }

    protected boolean matches(MetadataReader metadataReader) throws IOException {
        for (TypeFilter tf : this.excludeFilters) {
            if (tf.match(metadataReader, this.metadataReaderFactory)) {
                return false;
            }
        }
        for (TypeFilter tf : this.includeFilters) {
            if (tf.match(metadataReader, this.metadataReaderFactory)) {
                return true;
            }
        }
        return false;
    }


    public static boolean getMethodAnnotation(String basePackages,
                                              Class<? extends Annotation>... annotations) {
        Set<Class> scan = scan(basePackages, annotations);

        List<SqlErrorInfo> sqlErrorInfos = new ArrayList<>();
        for (Class mapperClass : scan) {
            Method[] methods = mapperClass.getMethods();

            for (Method method : methods) {
                Annotation[] annotations1 = method.getAnnotations();


                for (Annotation annotation : annotations1) {
                    if (annotation instanceof Insert) {
                        List<String> collect = Arrays.stream(((Insert) annotation).value()).collect(Collectors.toList());
                        String sql = sqlAnnotValue(collect);
                        boolean b = crudCheck(sql);

                        if (b == false) {
                            sqlErrorInfos.add(SqlErrorInfo.builder().clazz(mapperClass.toString()).method(method.toString()).sql(sql).build());
                        }
                    } else if (annotation instanceof Select) {
                        List<String> collect = Arrays.stream(((Select) annotation).value()).collect(Collectors.toList());
                        String sql = sqlAnnotValue(collect);
                        boolean b = crudCheck(sql);

                        if (b == false) {
                            sqlErrorInfos.add(SqlErrorInfo.builder().clazz(mapperClass.toString()).method(method.toString()).sql(sql).build());
                        }
                    } else if (annotation instanceof Update) {
                        List<String> collect = Arrays.stream(((Update) annotation).value()).collect(Collectors.toList());
                        String sql = sqlAnnotValue(collect);
                        boolean b = crudCheck(sql);

                        if (b == false) {
                            sqlErrorInfos.add(SqlErrorInfo.builder().clazz(mapperClass.toString()).method(method.toString()).sql(sql).build());
                        }
                    } else if (annotation instanceof Delete) {
                        List<String> collect = Arrays.stream(((Delete) annotation).value()).collect(Collectors.toList());
                        String sql = sqlAnnotValue(collect);
                        boolean b = crudCheck(sql);

                        if (b == false) {
                            sqlErrorInfos.add(SqlErrorInfo.builder().clazz(mapperClass.toString()).method(method.toString()).sql(sql).build());
                        }
                    }
                }
            }
        }
//        System.out.println(sqlErrorInfos.size());
        sqlErrorInfos.forEach(
                info -> {
                    log.error("不正确的sql,不校验<script>包装,错误sql : " + info);
                }
        );
        if (sqlErrorInfos.size() == 0) {
            return true;
        } else {
            return false;
        }

    }

    /**
     * 将Mapper层注解中的sql获取
     *
     * @param collect sqlCollect
     * @return sql
     */
    private static String sqlAnnotValue(List<String> collect) {
        String sql;
        if (collect.size() != 1) {
            StringBuilder sbd = new StringBuilder();
            collect.forEach(s -> {
                sbd.append(s);
                sbd.append(" ");
            });
            sql = sbd.toString();
        } else {

            sql = collect.get(0);
        }
        return sql;
    }

    /**
     * crud sql校验
     *
     * @param sql
     */
    private static boolean crudCheck(String sql) {

//        System.out.println("准备校验的sql =  " + sql);

        if (sql.startsWith("<script>")) {
            return true;
        } else {
            try {

                MySqlStatementParser parser = new MySqlStatementParser(sql);
                List<SQLStatement> stmtList = parser.parseStatementList();
                int size = stmtList.size();
                if (size != 0) {
                    return true;
                } else {
                    return false;
                }
            } catch (Exception e) {
                return false;
            }
        }
    }

    public static void main(String[] args) {
        String basePackages = "ix.account.mapper";
        Set<Class> scan = ClassScaner.scan(basePackages, Mapper.class);

        getMethodAnnotation(basePackages, Mapper.class);


    }


    @Data
    @Builder
    private static class SqlErrorInfo {
        private String method;
        private String sql;
        private String clazz;
    }
}

继续改进

查看源码部分
org.apache.ibatis.scripting.xmltags.XMLLanguageDriver#createSqlSource(org.apache.ibatis.session.Configuration, java.lang.String, java.lang.Class<?>)
在这里插入图片描述
这段代码中的 script 参数就是注解中的 SQL 脚本(MyBatis 会先把注解 value 中的多个字符串用空格拼接,并去掉首尾空白)。script 以 <script> 开头时按 XML 动态 SQL 解析,否则按普通文本处理。

  /**
   * Builds a SqlSource from the script text of a mapper annotation.
   * Text that starts with the exact, case-sensitive prefix "<script>" is parsed as XML.
   * The resulting /script node goes to the XNode overload. Any other text first has
   * configuration variables substituted by PropertyParser. It then becomes a
   * DynamicSqlSource when TextSqlNode#isDynamic() reports it as dynamic, otherwise a
   * RawSqlSource.
   */
  @Override
  public SqlSource createSqlSource(Configuration configuration, String script, Class<?> parameterType) {
    // issue #3
    // Only a literal "<script>" prefix triggers XML parsing; no trimming is done at this point.
    if (script.startsWith("<script>")) {
      XPathParser parser = new XPathParser(script, false, configuration.getVariables(), new XMLMapperEntityResolver());
      return createSqlSource(configuration, parser.evalNode("/script"), parameterType);
    } else {
      // issue #127
      script = PropertyParser.parse(script, configuration.getVariables());
      TextSqlNode textSqlNode = new TextSqlNode(script);
      if (textSqlNode.isDynamic()) {
        return new DynamicSqlSource(configuration, textSqlNode);
      } else {
        return new RawSqlSource(configuration, script, parameterType);
      }
    }
  }

org.apache.ibatis.scripting.xmltags.XMLLanguageDriver#createSqlSource(org.apache.ibatis.session.Configuration, org.apache.ibatis.parsing.XNode, java.lang.Class<?>)

  /**
   * Builds a SqlSource from an already-parsed <script> node.
   * It delegates to XMLScriptBuilder#parseScriptNode, which turns the node into the SqlSource.
   */
  @Override
  public SqlSource createSqlSource(Configuration configuration, XNode script, Class<?> parameterType) {
    XMLScriptBuilder builder = new XMLScriptBuilder(configuration, script, parameterType);
    return builder.parseScriptNode();
  }

在这里插入图片描述
此时 script 已被解析为 XNode,但它还不是最终可执行的 SQL:XMLScriptBuilder.parseScriptNode() 会据此构建 SqlSource,其中的动态标签要在执行时结合参数才能生成最终 SQL。

  • 基本改造如下(借助 MyBatis 的 XPathParser 解析 <script> 脚本):
        // Demo: parse an annotation <script> with MyBatis's XPathParser (second argument false =
        // no DTD validation). A standalone condition is written with <if>, since <when> belongs
        // inside <choose>. The test checks the same property that is bound below.
        String s = "<script>" +
                "SELECT * FROM tbl_order " +
                "WHERE 1=1 " +
                "<if test='mydate != null'>" +
                "AND mydate = #{mydate}" +
                "</if>" +
                "</script>";

        XPathParser parser = new XPathParser(s, false);
        XNode xNode = parser.evalNode("/script");
        // getStringBody() returns only the first text child of <script>, here
        // "SELECT * FROM tbl_order WHERE 1=1 ". The dynamic <if> part is dropped, so only the
        // static prefix can be syntax-checked this way.
        String stringBody = xNode.getStringBody();
        System.out.println(stringBody);

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值