SQL 的 Mock
通过注入 MyBatis Executor 拦截器，拦截查询（query）与更新（update）语句，按规则读取本地路径下的文件并按 JSON 解析为结果，实现 Mock 功能。
/**
 * MyBatis plugin that can answer SQL calls from local mock files.
 *
 * <p>Intercepts {@code Executor.update} and both {@code Executor.query} overloads. For every call
 * it derives a {@code Mapper_method} style name from the mapped statement id and hands it to
 * {@code TestFileHelper.getResult}, together with a function that parses a mock file as JSON and a
 * function that executes the real statement. Which of the two functions is used is decided by
 * {@code TestFileHelper} (not visible here).
 */
@Intercepts(value = {
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                RowBounds.class, ResultHandler.class})})
public class SqlResultInterceptor implements Interceptor {

    private static final Logger logger = LoggerFactory.getLogger(SqlResultInterceptor.class);

    /**
     * Resolves the statement name and expected result types, then delegates to
     * {@code TestFileHelper.getResult} with a mock-file parser and a real-execution function.
     *
     * @param invocation the intercepted executor call
     * @return whatever {@code TestFileHelper.getResult} returns (a mock result or the real one)
     * @throws Throwable if resolving the mapped statement fails
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        MappedStatement mappedStatement = getInvokeSqlStatement(invocation);
        String methodName = getStatementName(mappedStatement);
        Class<?> returnType = invocation.getMethod().getReturnType();
        Class<?> type = getResultType(mappedStatement);
        logger.info("sql method name is : {} , return type is : {} , result type is : {}",
                methodName, returnType, type);
        Function<String, Object> mockFunction = filePath -> getMockResult(filePath, type, returnType);
        // The argument is not used: the real call needs only the captured invocation.
        Function<String, Object> remoteFunction = ignored -> proceedOrFail(invocation, methodName);
        return TestFileHelper.getResult(methodName, mockFunction, remoteFunction);
    }

    /**
     * Runs the intercepted executor method for real.
     *
     * @param invocation the intercepted call
     * @param methodName statement name, used only for the error log
     * @return the real result of the executor method
     * @throws RuntimeException wrapping anything thrown by {@code invocation.proceed()}
     */
    private static Object proceedOrFail(Invocation invocation, String methodName) {
        try {
            return invocation.proceed();
        } catch (Throwable e) {
            logger.error("methodName={} sql remote error", methodName, e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads a mock file and parses its content as JSON.
     *
     * @param filePath   path of the mock file, passed to {@code TestFileHelper.readFile}
     * @param type       element type from the statement's first result map, or {@code null} if it has none
     * @param returnType return type of the intercepted executor method
     * @return a list of {@code type} when a result map type is known, otherwise a single
     *         {@code returnType} value (e.g. the row count for {@code update})
     */
    private Object getMockResult(String filePath, Class<?> type, Class<?> returnType) {
        String result = TestFileHelper.readFile(filePath);
        // returnType comes from Method.getReturnType(), which is never null.
        if (type == null) {
            return JSONObject.parseObject(result, returnType);
        }
        return JSONObject.parseArray(result, type);
    }

    /**
     * Returns the type of the statement's first result map.
     *
     * @param mappedStatement statement being executed
     * @return the first result map's type, or {@code null} when the statement has no result maps
     */
    private Class<?> getResultType(MappedStatement mappedStatement) {
        if (CollectionUtils.isEmpty(mappedStatement.getResultMaps())) {
            return null;
        }
        return mappedStatement.getResultMaps().get(0).getType();
    }

    /**
     * Turns a statement id such as {@code com.foo.UserMapper.selectById} into
     * {@code UserMapper_selectById}: the last two dot-separated segments joined by an underscore.
     * An id with at most one dot is kept whole (with its dot replaced).
     *
     * @param mappedStatement statement being executed
     * @return the short, file-name friendly statement name
     */
    private String getStatementName(MappedStatement mappedStatement) {
        String statementId = mappedStatement.getId();
        int methodDot = statementId.lastIndexOf('.');
        // A negative fromIndex makes lastIndexOf return -1, so substring(0) keeps the whole id.
        int namespaceDot = statementId.lastIndexOf('.', methodDot - 1);
        return statementId.substring(namespaceDot + 1).replace('.', '_');
    }

    /**
     * Extracts the {@link MappedStatement} for the intercepted call.
     *
     * @param invocation the intercepted call
     * @return the mapped statement being executed
     * @throws IllegalStateException if the target is neither an Executor nor a DefaultResultSetHandler
     * @throws Exception             if the reflective field lookup/read fails
     */
    private MappedStatement getInvokeSqlStatement(Invocation invocation) throws Exception {
        Object target = invocation.getTarget();
        if (target instanceof Executor) {
            // Every registered @Signature targets Executor, whose first argument is the MappedStatement.
            return (MappedStatement) invocation.getArgs()[0];
        }
        if (!(target instanceof DefaultResultSetHandler)) {
            throw new IllegalStateException("unsupported interceptor target: " + target.getClass().getName());
        }
        try {
            // Look the field up on the declaring class so subclasses of DefaultResultSetHandler also work.
            Field field = DefaultResultSetHandler.class.getDeclaredField("mappedStatement");
            field.setAccessible(true);
            return (MappedStatement) field.get(target);
        } catch (Exception e) {
            logger.error("failed to read mappedStatement from {}", target.getClass().getName(), e);
            throw e;
        }
    }

    /**
     * Wraps {@code target} in a proxy that routes the signatures declared above through
     * {@link #intercept(Invocation)}.
     *
     * @param target object MyBatis offers to this plugin
     * @return the proxied target, or {@code target} itself when no signature matches
     */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }
}