在日常系统维护工作中,我们经常需要处理与数据库交互的复杂逻辑。MyBatis的Mapper映射文件中往往包含大量动态SQL条件判断(如`<if>`、`<choose>`),自测时很容易遗漏其中某些分支,从而未能发现SQL语句中潜在的语法错误。这一问题在维护多年的系统中尤为突出,有时SQL语法错误直到生产环境才暴露出来。
为了解决这一问题,我们可以开发一款工具,用以自动检测MyBatis中所有SQL语句的语法正确性。该工具的实现策略如下:
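- 利用Spring容器,获取指定包(示例代码中为`.xxx.dao.`)下所有Mapper接口对应的Bean(即MyBatis生成的Mapper代理对象)。
- 运用反射技术,分析Mapper接口中定义的方法及其参数类型,动态创建并填充相应的参数对象。
- 逐一调用这些方法,使SQL中的动态条件分支尽可能多地被真实执行,从而暴露潜在的语法错误。注意:调用会真实执行包括insert/update/delete在内的所有SQL语句,请仅在测试环境中运行该工具。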
/**
 * Checks the SQL of MyBatis mappers by invoking mapper methods with generated arguments.
 *
 * <p>Each mapper method is called reflectively with sample parameter values, so the SQL it is mapped
 * to is really executed. Any exception from a call is logged and counted as an error instead of being
 * propagated. Because the calls are real, mapped insert/update/delete statements run as well; use this
 * tool only against a test database.
 *
 * <p>The run counters are plain, unsynchronized instance fields, so a single instance must not run
 * checks concurrently.
 */
@Slf4j
public class MyBatisSqlTool {
    /**
     * Sample values used when filling method parameters and object fields.
     */
    private static final String STR_AAA = "AAA";
    private static final Integer INT_111 = 101;
    private static final Long LONG_222 = 333L;
    private static final String START_TIME = "2024-10-10 12:12:12";
    private static final String END_TIME = "2024-11-11 11:11:11";

    /** Number of mapper methods invoked by the current check run. */
    private int totalMethod = 0;
    /** Number of invoked mapper methods that threw during the current check run. */
    private int errorMethod = 0;

    /**
     * Names of mapper methods that are skipped when all methods of a mapper are checked.
     */
    private static final List<String> excludeMethodList = new ArrayList<String>();

    static {
        // Its parameter is a List nested inside a List, which the parameter generator cannot build.
        excludeMethodList.add("findUserNotInList");
    }

    /** Used only by the object-filling helper, if at all; nothing in the check flow reads it. */
    private String objectString;

    /**
     * Invokes the methods of one mapper bean and reports how many were called and how many failed.
     *
     * @param mapperName Spring bean name of the mapper to check
     * @param methodName name of the method to check. When blank, every method declared on the mapper
     *     interface is checked, except those listed in {@code excludeMethodList}. All overloads with
     *     this name are invoked.
     * @return a successful {@code ResponseDto} wrapping the counters and the start/end time of this run.
     *     Failures of individual methods are logged and counted rather than thrown.
     * @throws IllegalArgumentException if the bean implements no interface, so it cannot be a mapper
     */
    public ResponseDto<Object> check(String mapperName, String methodName) {
        resetCounters();
        String startTime = DateUtils.getCurrentDate();
        Object bean = SpringContextUtil.getBean(mapperName);
        Class<?>[] interfaces = bean.getClass().getInterfaces();
        if (interfaces.length == 0) {
            throw new IllegalArgumentException("bean does not implement a mapper interface: " + mapperName);
        }
        if (StringUtils.isBlank(methodName)) {
            callBean(bean);
        } else {
            // Look the method up on the mapper interface rather than on the proxy class. The proxy also
            // declares equals/hashCode/toString, which have no SQL and could only be reported as errors.
            for (Method method : interfaces[0].getDeclaredMethods()) {
                if (method.getName().equals(methodName) && !method.isSynthetic()) {
                    log.info("方法名称:{}", method.getName());
                    invokeMapperMethod(bean, method);
                }
            }
        }
        String endTime = DateUtils.getCurrentDate();
        return ResponseDto.success(new MyBatisCheckResultDto(totalMethod, errorMethod, startTime, endTime));
    }

    /**
     * Checks every mapper bean in the Spring context whose first interface lives in the DAO package.
     *
     * <p>Beans listed by {@link #excludeBeanList()} are skipped. Bean definitions that cannot be
     * instantiated here are logged and skipped instead of aborting the run.
     *
     * @return a successful {@code ResponseDto} wrapping the counters and the start/end time of this run
     */
    public ResponseDto<Object> checkAll() {
        resetCounters();
        List<String> excludeBeanList = excludeBeanList();
        String startTime = DateUtils.getCurrentDate();
        String[] beanDefinitionNames = SpringContextUtil.getApplicationContext().getBeanDefinitionNames();
        for (String beanName : beanDefinitionNames) {
            Object bean;
            try {
                bean = SpringContextUtil.getApplicationContext().getBean(beanName);
            } catch (RuntimeException e) {
                // E.g. abstract or request/session-scoped definitions cannot be obtained outside their
                // context. Skip them rather than failing the whole run.
                log.warn("skip bean {}, cannot get instance: {}", beanName, e.getMessage());
                continue;
            }
            Class<?>[] interfaces = bean.getClass().getInterfaces();
            // ".xxx.dao." is the package marker of the mapper interfaces.
            if (interfaces.length == 0 || !interfaces[0].getName().contains(".xxx.dao.")) {
                continue;
            }
            log.info(":mapper:{}", beanName);
            log.info(interfaces[0].getName());
            if (excludeBeanList.contains(beanName)) {
                log.info("exclude mapper, {}", beanName);
                continue;
            }
            callBean(bean);
        }
        String endTime = DateUtils.getCurrentDate();
        return ResponseDto.success(new MyBatisCheckResultDto(totalMethod, errorMethod, startTime, endTime));
    }

    /**
     * Returns the Spring bean names of mappers that {@link #checkAll()} must skip.
     *
     * @return a new mutable list of excluded bean names
     */
    @NotNull
    private List<String> excludeBeanList() {
        List<String> excludeBeanList = new ArrayList<String>();
        excludeBeanList.add("userMapper");
        return excludeBeanList;
    }

    /**
     * Invokes every method declared on the bean's first interface, except names in
     * {@code excludeMethodList} and compiler/tool-generated synthetic methods.
     *
     * @param bean1 the mapper bean to exercise
     * @return always {@code ResponseDto.success()}; per-method failures are logged and counted
     */
    private ResponseDto<Object> callBean(Object bean1) {
        for (Method method : bean1.getClass().getInterfaces()[0].getDeclaredMethods()) {
            String methodName = method.getName();
            log.info("方法名称: {}", methodName);
            if (excludeMethodList.contains(methodName)) {
                log.info("排除方法: {}", methodName);
                continue;
            }
            if (method.isSynthetic()) {
                // Synthetic methods (e.g. injected by coverage instrumentation) are not mapper statements.
                continue;
            }
            invokeMapperMethod(bean1, method);
        }
        return ResponseDto.success();
    }

    /**
     * Invokes one mapper method with generated arguments and updates the run counters.
     *
     * <p>Any exception is logged and counted in {@code errorMethod}; it is never propagated.
     *
     * @param bean the mapper bean to invoke the method on
     * @param methodStream the method to invoke
     */
    private void invokeMapperMethod(Object bean, Method methodStream) {
        totalMethod = totalMethod + 1;
        try {
            Method declaredMethod = getMethod(bean, methodStream);
            // Fall back to the given method if the interface declares no exact match.
            Method parameterSource = declaredMethod != null ? declaredMethod : methodStream;
            List<Object> params = getMethodParams(parameterSource);
            methodStream.invoke(bean, params.toArray());
        } catch (Exception e) {
            errorMethod = errorMethod + 1;
            log.info("error bean. method name={}", methodStream.getName());
            log.error("error.", e);
        }
    }

    /**
     * Finds the method on the bean's first interface that matches {@code methodStream}.
     *
     * <p>A match must have the same name and exactly the same parameter types. The interface
     * declaration is used because it carries the generic parameter types (such as {@code List<Long>})
     * that the argument generator needs; a proxy class's own methods may not.
     *
     * @param bean1 the mapper bean whose first interface is searched
     * @param methodStream the method to match by name and parameter types
     * @return the matching interface method, or {@code null} if the interface declares none
     * @throws ClassNotFoundException retained for compatibility; this implementation does not throw it
     */
    @Nullable
    private static Method getMethod(Object bean1, Method methodStream) throws ClassNotFoundException {
        Class<?> mapperInterface = bean1.getClass().getInterfaces()[0];
        Class<?>[] parameterTypes = methodStream.getParameterTypes();
        for (Method candidate : mapperInterface.getDeclaredMethods()) {
            // Compare full parameter types, not just their count, so overloads resolve correctly.
            if (candidate.getName().equals(methodStream.getName())
                    && java.util.Arrays.equals(candidate.getParameterTypes(), parameterTypes)) {
                return candidate;
            }
        }
        return null;
    }

    /** Clears the counters so each check run reports only its own results. */
    private void resetCounters() {
        totalMethod = 0;
        errorMethod = 0;
    }

    /**
     * Returns {@code STR_AAA} followed by 5 random alphanumeric characters.
     *
     * @return a fresh sample string value
     */
    private String randomString() {
        return STR_AAA + RandomStringUtils.randomAlphanumeric(5);
    }
}
获取方法的入参类型,并为每个参数生成测试值。
/**
 * Builds one sample argument per parameter of {@code methodStream}, so the mapper method can be
 * invoked reflectively and its SQL executed.
 *
 * <p>Parameters are filled as follows:
 * <ul>
 *   <li>Generic {@code List}/{@code ArrayList}/{@code Collection} and {@code Set}/{@code HashSet}
 *       parameters receive a one-element collection.</li>
 *   <li>{@code Map}/{@code HashMap} parameters receive an empty {@code HashMap}.</li>
 *   <li>Other parameters receive a fixed sample value, the first constant of an enum, or a new
 *       instance populated by {@code fillObject}.</li>
 * </ul>
 *
 * @param methodStream mapper method whose parameters are populated
 * @return the generated arguments, one per parameter, in declaration order
 * @throws RuntimeException wrapping the original failure when a parameter type is unsupported or
 *     cannot be instantiated
 */
private List<Object> getMethodParams(Method methodStream) {
    Type[] genericParameterTypes = methodStream.getGenericParameterTypes();
    List<Object> params = new ArrayList<Object>(genericParameterTypes.length);
    for (int j = 0; j < genericParameterTypes.length; j++) {
        try {
            // Exactly one value per parameter. A skipped parameter would make Method.invoke fail
            // with a misleading "wrong number of arguments" error.
            params.add(buildParam(genericParameterTypes[j]));
        } catch (Exception e) {
            // getMessage() may be null (e.g. for a NullPointerException). Guard it so the real
            // failure is not replaced by an NPE thrown from this handler.
            String message = e.getMessage();
            if (message == null || !message.contains("not support map")) {
                log.error("error method.", e);
            }
            throw new RuntimeException(e);
        }
    }
    return params;
}

/** Creates the sample argument for a single (possibly generic) parameter type. */
private Object buildParam(Type genericParameterType) throws Exception {
    if (genericParameterType instanceof ParameterizedType) {
        ParameterizedType parameterizedType = (ParameterizedType) genericParameterType;
        String typeName = parameterizedType.getRawType().getTypeName();
        if ("java.util.List".equals(typeName) || "java.util.ArrayList".equals(typeName)
                || "java.util.Collection".equals(typeName)) {
            List<Object> objectList = new ArrayList<Object>();
            objectList.add(sampleElement(parameterizedType.getActualTypeArguments()[0]));
            return objectList;
        }
        if ("java.util.Map".equals(typeName) || "java.util.HashMap".equals(typeName)) {
            return new HashMap<Object, Object>(1);
        }
        if ("java.util.Set".equals(typeName) || "java.util.HashSet".equals(typeName)) {
            Set<Object> set = new HashSet<Object>(1);
            set.add(sampleElement(parameterizedType.getActualTypeArguments()[0]));
            return set;
        }
        throw new UnsupportedOperationException(
                "unsupported generic parameter type: " + genericParameterType.getTypeName());
    }
    String typeName = genericParameterType.getTypeName();
    if ("int".equals(typeName) || "java.lang.Integer".equals(typeName)) {
        return INT_111;
    }
    if ("boolean".equals(typeName) || "java.lang.Boolean".equals(typeName)) {
        return Boolean.TRUE;
    }
    if ("long".equals(typeName) || "java.lang.Long".equals(typeName)) {
        return LONG_222;
    }
    if ("java.lang.String[]".equals(typeName)) {
        return new String[] {randomString()};
    }
    if ("java.util.Date".equals(typeName)) {
        return new Date();
    }
    if (!(genericParameterType instanceof Class)) {
        // Type variables and generic arrays have no concrete class to instantiate.
        throw new UnsupportedOperationException("unsupported parameter type: " + typeName);
    }
    Class<?> parameterClass = (Class<?>) genericParameterType;
    if (parameterClass.isEnum()) {
        Object[] constants = parameterClass.getEnumConstants();
        if (constants.length > 0) {
            return constants[0];
        }
    }
    return fillObject(parameterClass.getDeclaredConstructor().newInstance());
}

/**
 * Creates one sample element for a collection parameter whose element type is {@code elementType}.
 *
 * <p>{@code String}, {@code Integer} and {@code Long} cannot be instantiated with a no-arg
 * constructor, so they get fixed sample values.
 */
private Object sampleElement(Type elementType) throws Exception {
    if (!(elementType instanceof Class)) {
        // E.g. List<List<String>> or List<? extends X>: there is no concrete element class.
        throw new UnsupportedOperationException(
                "unsupported collection element type: " + elementType.getTypeName());
    }
    Class<?> elementClass = (Class<?>) elementType;
    if (elementClass == String.class) {
        return randomString();
    }
    if (elementClass == Integer.class) {
        return INT_111;
    }
    if (elementClass == Long.class) {
        return LONG_222;
    }
    return fillObject(elementClass.getDeclaredConstructor().newInstance());
}
为已创建的Java对象的属性赋测试值。
/**
 * Populates {@code obj} with sample values, so MyBatis conditions such as
 * {@code <if test="x != null">} are exercised.
 *
 * <p>A {@code String}, {@code Integer} or {@code Long} argument cannot be filled in place, so a
 * replacement value is returned instead. For any other object, every non-static field declared on
 * its class or any superclass (up to, but excluding, {@code Object}) is assigned by type:
 * <ul>
 *   <li>{@code String}: {@code START_TIME} for {@code beginTime}/{@code startTime},
 *       {@code END_TIME} for {@code endTime}, otherwise a random string.</li>
 *   <li>{@code Integer}: {@code INT_111}.</li>
 *   <li>{@code Long}: {@code LONG_222}.</li>
 *   <li>{@code List}: a one-element list. The element is an Integer/Long for
 *       {@code List<Integer>}/{@code List<Long>}, else a random string.</li>
 *   <li>{@code java.util.Date}: the current time.</li>
 * </ul>
 * Fields of any other type are left untouched.
 *
 * @param obj the instance to populate
 * @return the populated instance, or a replacement value for String/Integer/Long input
 * @throws IllegalAccessException if a field cannot be written
 */
@NotNull
private Object fillObject(Object obj) throws IllegalAccessException {
    if (obj instanceof String) {
        return randomString();
    }
    if (obj instanceof Integer) {
        return INT_111;
    }
    if (obj instanceof Long) {
        return LONG_222;
    }
    // Walk the whole hierarchy, not just the direct superclass, so inherited fields are filled too.
    for (Class<?> type = obj.getClass(); type != null && type != Object.class; type = type.getSuperclass()) {
        for (Field field : type.getDeclaredFields()) {
            // Static fields are class-wide state, and static final ones cannot be written at all.
            if (java.lang.reflect.Modifier.isStatic(field.getModifiers())) {
                continue;
            }
            Object value = sampleFieldValue(field);
            if (value != null) {
                field.setAccessible(true);
                field.set(obj, value);
            }
        }
    }
    return obj;
}

/** Returns the sample value for {@code field}, or {@code null} if its type is not filled. */
private Object sampleFieldValue(Field field) {
    // Compare classes, not simple names, so e.g. a java.sql.Date field is not given a java.util.Date.
    Class<?> fieldType = field.getType();
    String fieldName = field.getName();
    if (fieldType == String.class) {
        if ("beginTime".equals(fieldName) || "startTime".equals(fieldName)) {
            return START_TIME;
        }
        if ("endTime".equals(fieldName)) {
            return END_TIME;
        }
        return randomString();
    }
    if (fieldType == Integer.class) {
        return INT_111;
    }
    if (fieldType == Long.class) {
        return LONG_222;
    }
    if (fieldType == List.class) {
        List<Object> values = new ArrayList<Object>();
        values.add(sampleListElement(field.getGenericType()));
        return values;
    }
    if (fieldType == Date.class) {
        return new Date();
    }
    return null;
}

/** Returns a list element matching the declared element type: Integer/Long when declared, else a String. */
private Object sampleListElement(Type listType) {
    if (listType instanceof ParameterizedType) {
        Type elementType = ((ParameterizedType) listType).getActualTypeArguments()[0];
        if (elementType == Integer.class) {
            return INT_111;
        }
        if (elementType == Long.class) {
            return LONG_222;
        }
    }
    return randomString();
}