部分代码来自mybatis分页和spring-boot整合项目,项目地址是:http://git.oschina.net/free/Mybatis_PageHelper
mybatis提供了拦截器Interceptor接口
public interface Interceptor {
Object intercept(Invocation invocation) throws Throwable;
Object plugin(Object target);
void setProperties(Properties properties);
}
分页相关的拦截器实现类是PageInterceptor,@Intercepts注解标识了需要拦截的类(接口)和方法信息
Configuration类中提供了若干Handler的实例方法,包括:Executor,StatementHandler,ResultSetHandler,ParameterHandler
executor = (Executor) interceptorChain.pluginAll(executor);
statementHandler = (StatementHandler) interceptorChain.pluginAll(statementHandler);
resultSetHandler = (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler);
parameterHandler = (ParameterHandler) interceptorChain.pluginAll(parameterHandler);
生成这些实例前要走一遍拦截器链,依次执行plugin方法,target则是最终返回的实例对象
public Object pluginAll(Object target) {
for (Interceptor interceptor : interceptors) {
target = interceptor.plugin(target);
}
return target;
}
下面看一下分页拦截器PageInterceptor的plugin方法
Plugin实现了动态代理相关的InvocationHandler接口。
首先获取拦截器定义的拦截类和方法信息,然后判断拦截对象是否实现了该接口并且拥有该方法,如果满足条件,则通过Proxy动态代理生成代理对象返回
public static Object wrap(Object target, Interceptor interceptor) {
Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
Class<?> type = target.getClass();
Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
if (interfaces.length > 0) {
return Proxy.newProxyInstance(
type.getClassLoader(),
interfaces,
new Plugin(target, interceptor, signatureMap));
}
return target;
}
private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
Intercepts interceptsAnnotation = interceptor.getClass().getAnnotation(Intercepts.class);
// issue #251
if (interceptsAnnotation == null) {
throw new PluginException("No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());
}
Signature[] sigs = interceptsAnnotation.value();
Map<Class<?>, Set<Method>> signatureMap = new HashMap<Class<?>, Set<Method>>();
for (Signature sig : sigs) {
Set<Method> methods = signatureMap.get(sig.type());
if (methods == null) {
methods = new HashSet<Method>();
signatureMap.put(sig.type(), methods);
}
try {
Method method = sig.type().getMethod(sig.method(), sig.args());
methods.add(method);
} catch (NoSuchMethodException e) {
throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e, e);
}
}
return signatureMap;
}
private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
Set<Class<?>> interfaces = new HashSet<Class<?>>();
while (type != null) {
for (Class<?> c : type.getInterfaces()) {
if (signatureMap.containsKey(c)) {
interfaces.add(c);
}
}
type = type.getSuperclass();
}
return interfaces.toArray(new Class<?>[interfaces.size()]);
}
Plugin的invoke()方法
如果拦截对象调用的方法属于拦截范围,则调用拦截器对象的intercept()方法,进行分页相关的sql语句包装
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
try {
Set<Method> methods = signatureMap.get(method.getDeclaringClass());
if (methods != null && methods.contains(method)) {
return interceptor.intercept(new Invocation(target, method, args));
}
return method.invoke(target, args);
} catch (Exception e) {
throw ExceptionUtil.unwrapThrowable(e);
}
}
@Override
public Object intercept(Invocation invocation) throws Throwable {
try {
Object[] args = invocation.getArgs();
MappedStatement ms = (MappedStatement) args[0];
Object parameter = args[1];
RowBounds rowBounds = (RowBounds) args[2];
ResultHandler resultHandler = (ResultHandler) args[3];
Executor executor = (Executor) invocation.getTarget();
CacheKey cacheKey;
BoundSql boundSql;
//由于逻辑关系,只会进入一次
if(args.length == 4){
//4 个参数时
boundSql = ms.getBoundSql(parameter);
cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql);
} else {
//6 个参数时
cacheKey = (CacheKey) args[4];
boundSql = (BoundSql) args[5];
}
List resultList;
//调用方法判断是否需要进行分页,如果不需要,直接返回结果
if (!dialect.skip(ms, parameter, rowBounds)) {
//反射获取动态参数
String msId = ms.getId();
Configuration configuration = ms.getConfiguration();
Map<String, Object> additionalParameters = (Map<String, Object>) additionalParametersField.get(boundSql);
//判断是否需要进行 count 查询
if (dialect.beforeCount(ms, parameter, rowBounds)) {
String countMsId = msId + countSuffix;
Long count;
//先判断是否存在手写的 count 查询
MappedStatement countMs = getExistedMappedStatement(configuration, countMsId);
if(countMs != null){
count = executeManualCount(executor, countMs, parameter, boundSql, resultHandler);
} else {
countMs = msCountMap.get(countMsId);
//自动创建
if (countMs == null) {
//根据当前的 ms 创建一个返回值为 Long 类型的 ms
countMs = MSUtils.newCountMappedStatement(ms, countMsId);
msCountMap.put(countMsId, countMs);
}
count = executeAutoCount(executor, countMs, parameter, boundSql, rowBounds, resultHandler);
}
//处理查询总数
//返回 true 时继续分页查询,false 时直接返回
if (!dialect.afterCount(count, parameter, rowBounds)) {
//当查询总数为 0 时,直接返回空的结果
return dialect.afterPage(new ArrayList(), parameter, rowBounds);
}
}
//判断是否需要进行分页查询
if (dialect.beforePage(ms, parameter, rowBounds)) {
//生成分页的缓存 key
CacheKey pageKey = cacheKey;
//处理参数对象
parameter = dialect.processParameterObject(ms, parameter, boundSql, pageKey);
//调用方言获取分页 sql
String pageSql = dialect.getPageSql(ms, boundSql, parameter, rowBounds, pageKey);
BoundSql pageBoundSql = new BoundSql(configuration, pageSql, boundSql.getParameterMappings(), parameter);
//设置动态参数
for (String key : additionalParameters.keySet()) {
pageBoundSql.setAdditionalParameter(key, additionalParameters.get(key));
}
//执行分页查询
resultList = executor.query(ms, parameter, RowBounds.DEFAULT, resultHandler, pageKey, pageBoundSql);
} else {
//不执行分页的情况下,也不执行内存分页
resultList = executor.query(ms, parameter, RowBounds.DEFAULT, resultHandler, cacheKey, boundSql);
}
} else {
//rowBounds用参数值,不使用分页插件处理时,仍然支持默认的内存分页
resultList = executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql);
}
return dialect.afterPage(resultList, parameter, rowBounds);
} finally {
dialect.afterAll();
}
}
Dialect是数据库方言接口,其中定义了分页查询前中后等操作的方法
/**
* 数据库方言,针对不同数据库进行实现
*
* @author liuzh
*/
public interface Dialect {
/**
* 跳过 count 和 分页查询
*
* @param ms MappedStatement
* @param parameterObject 方法参数
* @param rowBounds 分页参数
* @return true 跳过,返回默认查询结果,false 执行分页查询
*/
boolean skip(MappedStatement ms, Object parameterObject, RowBounds rowBounds);
/**
* 执行分页前,返回 true 会进行 count 查询,false 会继续下面的 beforePage 判断
*
* @param ms MappedStatement
* @param parameterObject 方法参数
* @param rowBounds 分页参数
* @return
*/
boolean beforeCount(MappedStatement ms, Object parameterObject, RowBounds rowBounds);
/**
* 生成 count 查询 sql
*
* @param ms MappedStatement
* @param boundSql 绑定 SQL 对象
* @param parameterObject 方法参数
* @param rowBounds 分页参数
* @param countKey count 缓存 key
* @return
*/
String getCountSql(MappedStatement ms, BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey countKey);
/**
* 执行完 count 查询后
*
* @param count 查询结果总数
* @param parameterObject 接口参数
* @param rowBounds 分页参数
* @return true 继续分页查询,false 直接返回
*/
boolean afterCount(long count, Object parameterObject, RowBounds rowBounds);
/**
* 处理查询参数对象
*
* @param ms MappedStatement
* @param parameterObject
* @param boundSql
* @param pageKey
* @return
*/
Object processParameterObject(MappedStatement ms, Object parameterObject, BoundSql boundSql, CacheKey pageKey);
/**
* 执行分页前,返回 true 会进行分页查询,false 会返回默认查询结果
*
* @param ms MappedStatement
* @param parameterObject 方法参数
* @param rowBounds 分页参数
* @return
*/
boolean beforePage(MappedStatement ms, Object parameterObject, RowBounds rowBounds);
/**
* 生成分页查询 sql
*
* @param ms MappedStatement
* @param boundSql 绑定 SQL 对象
* @param parameterObject 方法参数
* @param rowBounds 分页参数
* @param pageKey 分页缓存 key
* @return
*/
String getPageSql(MappedStatement ms, BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey pageKey);
/**
* 分页查询后,处理分页结果,拦截器中直接 return 该方法的返回值
*
* @param pageList 分页查询结果
* @param parameterObject 方法参数
* @param rowBounds 分页参数
* @return
*/
Object afterPage(List pageList, Object parameterObject, RowBounds rowBounds);
/**
* 完成所有任务后
*/
void afterAll();
/**
* 设置参数
*
* @param properties 插件属性
*/
void setProperties(Properties properties);
}