阅读须知
- Mybatis源码版本:3.4.4
- 文章中使用/* */注释的方法会做深入分析
分页拦截器DEMO
本篇文章我们来分析Mybatis拦截器的源码,进入源码分析之前,我们先来看一个Mybatis拦截器实际应用的小例子 — 分页拦截器:
/**
* 分页查询对象,走分页拦截器时使用
*/
public class PageCondition {
private int totalCount;
private int totalPage;
private int currentPage;
private int pageSize;
public PageCondition() {
}
public PageCondition(int currentPage, int pageSize) {
this.currentPage = transferCurrentPage(currentPage);
this.pageSize = pageSize;
}
/**
* 转换当前页,页面当前页1开始,sql当前页0开始
*/
private int transferCurrentPage(int currentPage){
if(currentPage > 0){
currentPage -= 1;
}
return currentPage;
}
/**
* 恢复currentPage
*/
public void recoveryCurrentPage(){
currentPage += 1;
}
public int getTotalCount() {
return totalCount;
}
public void setTotalCount(int totalCount) {
this.totalCount = totalCount;
this.totalPage = this.totalCount % this.pageSize == 0 ? this.totalCount / this.pageSize : this.totalCount / this.pageSize + 1;
}
public int getTotalPage() {
return totalPage;
}
public int getCurrentPage() {
return currentPage;
}
public void setCurrentPage(int currentPage) {
this.currentPage = transferCurrentPage(currentPage);
}
public int getPageSize() {
return pageSize;
}
public void setPageSize(int pageSize) {
this.pageSize = pageSize;
}
public void setTotalPage(int totalPage) {
this.totalPage = totalPage;
}
@Override
public String toString() {
return "PageCondition{" +
"totalCount=" + totalCount +
", totalPage=" + totalPage +
", currentPage=" + currentPage +
", pageSize=" + pageSize +
'}';
}
}
/**
* 分页拦截器
*/
@Intercepts(@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class}))
public class PageInterceptor implements Interceptor {
private static final Logger logger = LoggerFactory.getLogger(PageInterceptor.class);
private static final String BOUND_SQL_KEY = "delegate.boundSql.sql";
private static final String PARAMETER_HANDLER_KEY = "delegate.parameterHandler";
public Object intercept(Invocation invocation) throws Throwable {
StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
BoundSql boundSql = statementHandler.getBoundSql();
PageCondition page = getPage(boundSql);
if (page == null) {
return invocation.proceed();
}
String originSql = boundSql.getSql();
String pageSql = getPageSql(originSql, page);
// 反射修改原sql
metaObject.setValue(BOUND_SQL_KEY, pageSql);
setPageTotalCount(originSql, metaObject, invocation, page);
return invocation.proceed();
}
/**
* 设置分页对象的总数
*/
private void setPageTotalCount(String originSql, MetaObject metaObject, Invocation invocation, PageCondition page) throws Exception {
PreparedStatement countStatement = null;
ResultSet resultSet = null;
try {
String countSql = getCountSql(originSql);
// 这个数据库连接不能关,后面的分页sql也要使用这个connection
Connection connection = (Connection) invocation.getArgs()[0];
countStatement = connection.prepareStatement(countSql);
ParameterHandler parameterHandler = (ParameterHandler) metaObject.getValue(PARAMETER_HANDLER_KEY);
parameterHandler.setParameters(countStatement);
resultSet = countStatement.executeQuery();
int totalCount = 0;
if (resultSet.next()) {
totalCount = resultSet.getInt(1);
}
page.setTotalCount(totalCount);
} finally {
if (resultSet != null) {
resultSet.close();
}
if (countStatement != null) {
countStatement.close();
}
}
}
/**
* 这种方式比较简陋,使用时注意sql语句末的分号问题
*/
private String getPageSql(String originSql, PageCondition page) {
return new StringBuilder().append(originSql).append(" LIMIT ").
append(page.getCurrentPage() * page.getPageSize()).append(",").append(page.getPageSize()).toString();
}
/**
* 获取分页对象
*/
private PageCondition getPage(BoundSql boundSql) {
PageCondition page = null;
// 获取执行方法的参数
Object param = boundSql.getParameterObject();
if (param instanceof Map) {
Map paramMap = (Map) param;
for (Iterator iterator = paramMap.values().iterator(); iterator.hasNext(); ) {
Object obj = iterator.next();
if (obj instanceof PageCondition) {
page = (PageCondition) obj;
break;
}
}
} else if (param instanceof PageCondition) {
page = (PageCondition) param;
}
return page;
}
/**
* 这种方式比较简陋,使用时注意sql语句末的分号问题
*/
private String getCountSql(String originSql) {
return new StringBuilder().append("select count(*) from (").append(originSql).append(") t").toString();
}
public Object plugin(Object o) {
return Plugin.wrap(o, this);
}
public void setProperties(Properties properties) {
// 这里可以获取到拦截器配置的properties
}
}
以上是分页拦截器简单的代码实现,它有什么作用呢,我们知道,我们在写分页sql的时候,我们首先要计算出满足查询条件的数据总数用于计算总页数,如果我们使用limit x offset y的写法,还需要计算出本次查询的偏移量,而使用这个分页拦截器,就可以省去这些步骤的编写,由分页拦截器帮助我们完成,这样就省去了不少开发工作量。示例中的写法只适用于MySQL分页的写法,其他数据库分页的写法可能与MySQL的写法不同,例如Oracle,我们可以为示例的分页拦截器扩展支持多方言,当然这不是本文的重点,也比较简单,有兴趣的读者可以自行研究实现。下面我们来看一下这个分页拦截器要怎么使用。
首先,Mybatis配置文件中增加拦截器的配置:
<plugins>
<plugin interceptor="com.jd.plugin.dao.aop.PageInterceptor"/>
</plugins>
mapper:
public interface UserMapper {
/**
* 分页条件查询
*/
List<User> listByCondition(@Param("page") PageCondition page, @Param("query") UserQuery query);
}
mapper配置文件:
<select id="listByCondition" parameterType="UserQuery" resultType="User">
SELECT
id, name, gender, age
FROM user
<where>
<if test="query.id != null and query.id > 0">
AND id = #{query.id}
<if/>
<!-- 其他条件... -->
<where/>
ORDER BY id
</select>
调用UserMapper的listByCondtion方法传入PageCondtion条件(需要传入当前页currentPage和每页的数量pageSize)和其他我们需要的查询条件就可以完成分页功能,这样我们在mapper配置文件中就可以不用写分页语句,只关注我们查询sql自身的编写即可,分页拦截器会自动完成分页语句,并计算总数和总页数放入PageCondition对象中。
源码分析
下面我们正式开始Mybatis拦截器源码的分析,拦截器配置解析和调用入口的源码我们已经在之前文章中分析过,拦截器会被保存在Configuration对象中维护的一个InterceptorChain对象中,调用InterceptorChain的pluginAll方法来实现拦截器的调用:
public Object pluginAll(Object target) {
for (Interceptor interceptor : interceptors) {
/* 调用拦截器的plugin方法 */
target = interceptor.plugin(target);
}
return target;
}
我们上文介绍的分页拦截器PageInterceptor实现了Interceptor接口,我们来看它的plugin方法:
public Object plugin(Object o) {
/* 包装目标对象 */
return Plugin.wrap(o, this);
}
这里的目标对象是什么呢?读者可以思考一下,我们的分页拦截器是基于修改原sql实现的,我们要为原sql拼接分页语句,我们什么时候修改原sql最合适呢,如果读者熟悉JDBC的API,肯定一下子就想到,在创建PreparedStatement之前,因为创建PreparedStatement需要指定sql语句,所以在创建之前修改最合适,Mybatis什么时候创建PreparedStatement呢,这个我们在之前的文章中已经分析过,在构建StatementHandler之后,调用它的prepare完成PreparedStatement的创建,所以这里的目标对象就是StatementHandler对象,而我们需要拦截的目标方法就是prepare方法。
Plugin:
public static Object wrap(Object target, Interceptor interceptor) {
/* 获取注解配置 */
Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
Class<?> type = target.getClass();
// 获取目标类的所有接口与signatureMap的key集合的交集
Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
if (interfaces.length > 0) {
// 创建代理对象
return Proxy.newProxyInstance(
type.getClassLoader(),
interfaces,
new Plugin(target, interceptor, signatureMap));
}
return target;
}
Plugin:
private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
// 获取拦截器的@Intercepts注解
Intercepts interceptsAnnotation = interceptor.getClass().getAnnotation(Intercepts.class);
// 没有注解@Intercepts抛出异常
if (interceptsAnnotation == null) {
throw new PluginException("No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());
}
// 获取@Intercepts注解的value属性值,也就是@Signature注解数组
Signature[] sigs = interceptsAnnotation.value();
Map<Class<?>, Set<Method>> signatureMap = new HashMap<Class<?>, Set<Method>>();
for (Signature sig : sigs) {
Set<Method> methods = signatureMap.get(sig.type());
if (methods == null) {
methods = new HashSet<Method>();
signatureMap.put(sig.type(), methods);
}
try {
// type我们指定是的StatementHandler接口,method我们指定的是prepare方法,args我们指定的是数据库连接Connection(StatementHandler的prepare方法的第一个参数就是Connection)
Method method = sig.type().getMethod(sig.method(), sig.args());
methods.add(method);
} catch (NoSuchMethodException e) {
throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e, e);
}
}
return signatureMap;
}
因为创建代理对象时使用的是JDK动态代理,Plugin类实现了InvocationHandler,所以执行时我们需要分析它的invoke方法:
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
try {
Set<Method> methods = signatureMap.get(method.getDeclaringClass());
if (methods != null && methods.contains(method)) {
// 如果拦截器配置的方法与目标方法相匹配,则调用拦截器的intercept方法
return interceptor.intercept(new Invocation(target, method, args));
}
// 不匹配的直接执行目标方法
return method.invoke(target, args);
} catch (Exception e) {
throw ExceptionUtil.unwrapThrowable(e);
}
}
分析到这里,我们再回头看我们的分页拦截器对intercept方法实现,相信读者已经可以清楚的理解方法的内容了。到这里,整个Mybatis拦截器的源码分析就完成了。