自动化MyBatis SQL语法检测工具

在日常系统维护工作中,我们经常需要处理与数据库交互的复杂逻辑。由于MyBatis框架中包含众多条件判断语句,这可能在自测过程中导致某些条件被忽略,导致未能发现SQL语句中的潜在语法错误。特别是在维护多年的系统。有时候在生产环境也会暴露SQL语法错误。
为了解决这一问题,我们可以开发一款工具,用以自动检测MyBatis中所有SQL语句的语法正确性。该工具的实现策略如下:

  1. 利用Spring框架的功能,定位并加载指定目录下的所有Mapper接口实现类。
  2. 运用反射技术,分析Mapper接口中定义的方法及其参数,动态创建相应的参数对象。
  3. 调用这些方法,以确保MyBatis中的所有条件逻辑均得到正确执行和验证。

/**
 * 
 * 此工具可检测ibatis中的sql语法。
 */
@Slf4j
public class MyBatisSqlTool {

    /**
    * 填充对象时的值
    **/
    private static String STR_AAA = "AAA";
    private static Integer INT_111 = 101;
    private static Long LONG_222 = 333L;
    private static String START_TIME = "2024-10-10 12:12:12";
    private static String END_TIME = "2024-11-11 11:11:11";

    /**
    * 统计调用的mapper方法数量
    */
    private int totalMethod = 0;
    private int errorMethod = 0;

    /**
    * 不检查的mapper方法
    */
    private static List<String> excludeMethodList = new ArrayList<String>();

    static {
        //List中套List
        excludeMethodList.add("findUserNotInList");
    }

    private String objectString;

    /**
     * 检查指定mapperName和methodName的方法是否存在并执行。
     * @param mapperName 需要检查的mapper名称。
     * @param methodName 需要检查的方法名称,如果为空则检查所有方法。
     * @return 如果方法存在并成功执行,返回包含执行结果的ResponseDto对象;否则抛出异常。
     */
    public ResponseDto<Object> check(String mapperName, String methodName) {
        String startTime = DateUtils.getCurrentDate();
        Object bean1 = SpringContextUtil.getBean(mapperName);
        Method[] methodStream = bean1.getClass().getDeclaredMethods();
        for (int i = 0; i < methodStream.length; i++) {
            if (StringUtils.isBlank(methodName) || methodStream[i].getName().equals(methodName)) {
                log.info("方法名称:" + methodStream[i].getName());
                invokeMapperMethod(bean1, methodStream[i]);
            }
        }
        String endTime = DateUtils.getCurrentDate();
        return ResponseDto.success(new MyBatisCheckResultDto(totalMethod, errorMethod, startTime, endTime));
    }

    /**
     * 校验全部的Mapper
     *
     * @return
     */
    public ResponseDto<Object> checkAll() {
        List<String> excludeBeanList = excludeBeanList();

        String startTime = DateUtils.getCurrentDate();
        String[] beanDefinitionNames = SpringContextUtil.getApplicationContext().getBeanDefinitionNames();
        for (int i = 0; i < beanDefinitionNames.length; i++) {

            Object bean = SpringContextUtil.getApplicationContext().getBean(beanDefinitionNames[i]);
            Class<?>[] interfaces = bean.getClass().getInterfaces();
            if (interfaces.length == 0) {
                continue;
            }
            if (interfaces[0].getName().contains(".xxx.dao.")) {
                log.info(":mapper:" + beanDefinitionNames[i]);
                log.info(bean.getClass().getInterfaces()[0].getName());
                if (excludeBeanList.contains(beanDefinitionNames[i])) {
                    log.info("exclude mapper, {}", beanDefinitionNames[i]);
                    continue;
                }
                callBean(bean);
            }
        }
        String endTime = DateUtils.getCurrentDate();
        return ResponseDto.success(new MyBatisCheckResultDto(totalMethod, errorMethod, startTime, endTime));
    }

    @NotNull
    /**
    * 不检查的mapper对象
    */
    private List<String> excludeBeanList() {
        List<String> excludeBeanList = new ArrayList<String>();
        excludeBeanList.add("userMapper");
        return excludeBeanList;
    }

    private ResponseDto<Object> callBean(Object bean1) {
        Method[] methodStream = bean1.getClass().getInterfaces()[0].getDeclaredMethods();
        for (int i = 0; i < methodStream.length; i++) {
            String methodName = methodStream[i].getName();

            log.info("方法名称: " + methodName);
            if (excludeMethodList.contains(methodName)) {
                log.info("排除方法: " + methodName);
                continue;
            }
            invokeMapperMethod(bean1, methodStream[i]);
        }
        return ResponseDto.success();
    }

    /**
    * 调用mapper方法
    */
    private void invokeMapperMethod(Object bean, Method methodStream) {
        try {
            totalMethod = totalMethod + 1;
            Method declaredMethods = getMethod(bean, methodStream);
            List<Object> params2 = getMethodParams(declaredMethods);

            methodStream.invoke(bean, params2.toArray());
        } catch (Exception e) {
            errorMethod = errorMethod + 1;
            log.info("error bean. method name={}", methodStream.getName());
            log.error("error.", e);
        }
    }

    @Nullable
    /**
     * 判断当前bean1中是否有要校验的方法。
     */
    private static Method getMethod(Object bean1, Method methodStream) throws ClassNotFoundException {
        String beanName = bean1.getClass().getInterfaces()[0].getName();
        Class<?>[] parameterTypes = methodStream.getParameterTypes();
        Method[] declaredMethods = Class.forName(beanName).getDeclaredMethods();
        for (int i = 0; i < declaredMethods.length; i++) {
            boolean nameSame = declaredMethods[i].getName().equals(methodStream.getName());
            int paramLength = declaredMethods[i].getGenericParameterTypes().length;
            boolean paramEquals = (paramLength == parameterTypes.length);
            if (nameSame && paramEquals) {
                return declaredMethods[i];
            }
        }
        return null;
    }

    private String randomString() {
        return STR_AAA + RandomStringUtils.randomAlphanumeric(5);
    }
}

获取方法的入参类型,并赋值。

private List<Object> getMethodParams(Method methodStream) {
        Type[] genericParameterTypes = methodStream.getGenericParameterTypes();
        List<Object> params = new ArrayList<Object>();
        for (int j = 0; j < genericParameterTypes.length; j++) {
            try {
                Type genericParameterType = genericParameterTypes[j];
                if (genericParameterType instanceof ParameterizedType) {
                    //处理泛型
                    ParameterizedType parameterizedType = (ParameterizedType) genericParameterType;
                    String typeName = parameterizedType.getRawType().getTypeName();
                    if ("java.util.List".equals(typeName)) {
                        Type actualTypeArgument = parameterizedType.getActualTypeArguments()[0];
                        List objectList = new ArrayList<>();
                        Object obj = Class.forName(actualTypeArgument.getTypeName()).newInstance();
                        obj = fillObject(obj);
                        objectList.add(obj);
                        params.add(objectList);
                    } else if ("java.util.Map".equals(typeName) || "java.util.HashMap".equals(typeName)) {
                        Map map = new HashMap<>(1);
                        params.add(map);
                        //throw new RuntimeException("not support map");
                    } else if ("java.util.Set".equals(typeName) || "java.util.HashSet".equals(typeName)) {
                        Set set = new HashSet<>(1);
                        set.add(randomString());
                        params.add(set);
                    }
                } else {
                    String typeName = genericParameterType.getTypeName();
                    if ("int".equals(typeName) || "java.lang.Integer".equals(typeName)) {
                        params.add(INT_111);
                    } else if ("boolean".equals(typeName) || "java.lang.Boolean".equals(typeName)) {
                        params.add(Boolean.TRUE);
                    } else if ("long".equals(typeName) || "java.lang.Long".equals(typeName)) {
                        params.add(LONG_222);
                    } else if ("java.lang.String[]".equals(typeName)) {
                        String[] strings = {randomString()};
                        params.add(strings);
                    } else if ("java.util.Date".equals(typeName)) {
                        params.add(new Date());
                    } else {
                        Object obj = ((Class) genericParameterType).newInstance();
                        obj = fillObject(obj);
                        params.add(obj);
                    }
                }
            } catch (Exception e) {
                if (e.getMessage().indexOf("not support map") < 0) {
                    log.error("error method.", e);
                }
                throw new RuntimeException(e);
            }
        }
        return params;
    }

生成java对象,并对属性赋值。

@NotNull
    private Object fillObject(Object obj) throws IllegalAccessException {
        if (obj instanceof String) {
            return randomString();
        }
        if (obj instanceof Integer) {
            return INT_111;
        }
        Field[] declaredFields = obj.getClass().getDeclaredFields();
        Class<?> superclass = obj.getClass().getSuperclass();
        Field[] declaredFieldsAll;
        objectString = "java.lang.Object";
        if (!objectString.equals(superclass.getName())) {
            Field[] declaredFields1 = superclass.getDeclaredFields();
            declaredFieldsAll = ArrayUtils.addAll(declaredFields, declaredFields1);
        } else {
            declaredFieldsAll = declaredFields;
        }

        for (int k = 0; k < declaredFieldsAll.length; k++) {
            declaredFieldsAll[k].setAccessible(true);
            String simpleNameField = declaredFieldsAll[k].getType().getSimpleName();
            if ("String".equals(simpleNameField)) {
                if ("beginTime".equals(declaredFieldsAll[k].getName()) || "startTime".equals(declaredFieldsAll[k].getName())) {
                    declaredFieldsAll[k].set(obj, START_TIME);
                    continue;
                }
                if ("endTime".equals(declaredFieldsAll[k].getName()) || "endTime".equals(declaredFieldsAll[k].getName())) {
                    declaredFieldsAll[k].set(obj, END_TIME);
                    continue;
                }
                declaredFieldsAll[k].set(obj, randomString());
            }
            if ("Integer".equals(simpleNameField)) {
                declaredFieldsAll[k].set(obj, INT_111);
            }
            if ("List".equals(simpleNameField)) {
                List<String> stringList = new ArrayList<String>();
                stringList.add(randomString());
                declaredFields[k].set(obj, stringList);
            }
            if ("Date".equals(simpleNameField)) {
                declaredFieldsAll[k].set(obj, new Date());
            }
        }
        return obj;
    }
Mybatis SQL语法包括动态SQL技术和XML解析。动态SQL技术是一种根据特定条件动态拼装SQL语句的功能,它解决了拼接SQL语句字符串时的痛点问题。在Mybatis中,SQL语句通常写在mapper.xml文件中,但是XML解析时会遇到特殊字符需要进行转义处理,例如使用<代替<,>代替>,&代替&,&apos;代替',"代替"等。另外,#{}在Mybatis中用于向prepareStatement中的预处理语句中设计参数值,可以理解为一个占位符即?。所以,Mybatis SQL语法是通过动态SQL技术拼装SQL语句,并在mapper.xml中进行XML解析,并且使用#{}作为参数的占位符。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [MyBatis基础语法详解,真的全面](https://blog.csdn.net/qq_42176665/article/details/127873388)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* [Mybatis常用语法汇总](https://blog.csdn.net/qw463800202/article/details/103221651)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值