Mybatis源码解析之拦截器篇

阅读须知

  • Mybatis源码版本:3.4.4
  • 文章中使用/* */注释的方法会做深入分析

分页拦截器DEMO

本篇文章我们来分析Mybatis拦截器的源码,进入源码分析之前,我们先来看一个Mybatis拦截器实际应用的小例子 — 分页拦截器:

/**
 * 分页查询对象,走分页拦截器时使用
 */
public class PageCondition {

    private int totalCount;
    private int totalPage;
    private int currentPage;
    private int pageSize;

    public PageCondition() {

    }

    public PageCondition(int currentPage, int pageSize) {
        this.currentPage =  transferCurrentPage(currentPage);
        this.pageSize = pageSize;
    }

    /**
     * 转换当前页,页面当前页1开始,sql当前页0开始
     */
    private int transferCurrentPage(int currentPage){
        if(currentPage > 0){
            currentPage -= 1;
        }
        return currentPage;
    }

    /**
     * 恢复currentPage
     */
    public void recoveryCurrentPage(){
        currentPage += 1;
    }

    public int getTotalCount() {
        return totalCount;
    }

    public void setTotalCount(int totalCount) {
        this.totalCount = totalCount;
        this.totalPage = this.totalCount % this.pageSize == 0 ? this.totalCount / this.pageSize : this.totalCount / this.pageSize + 1;
    }

    public int getTotalPage() {
        return totalPage;
    }

    public int getCurrentPage() {
        return currentPage;
    }

    public void setCurrentPage(int currentPage) {
        this.currentPage = transferCurrentPage(currentPage);
    }

    public int getPageSize() {
        return pageSize;
    }

    public void setPageSize(int pageSize) {
        this.pageSize = pageSize;
    }

    public void setTotalPage(int totalPage) {
        this.totalPage = totalPage;
    }

    @Override
    public String toString() {
        return "PageCondition{" +
                "totalCount=" + totalCount +
                ", totalPage=" + totalPage +
                ", currentPage=" + currentPage +
                ", pageSize=" + pageSize +
                '}';
    }
}

/**
 * 分页拦截器
 */
@Intercepts(@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class}))
public class PageInterceptor implements Interceptor {

    private static final Logger logger = LoggerFactory.getLogger(PageInterceptor.class);

    private static final String BOUND_SQL_KEY = "delegate.boundSql.sql";

    private static final String PARAMETER_HANDLER_KEY = "delegate.parameterHandler";

    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        BoundSql boundSql = statementHandler.getBoundSql();
        PageCondition page = getPage(boundSql);
        if (page == null) {
            return invocation.proceed();
        }
        String originSql = boundSql.getSql();
        String pageSql = getPageSql(originSql, page);
        // 反射修改原sql
        metaObject.setValue(BOUND_SQL_KEY, pageSql);
        setPageTotalCount(originSql, metaObject, invocation, page);
        return invocation.proceed();
    }

    /**
     * 设置分页对象的总数
     */
    private void setPageTotalCount(String originSql, MetaObject metaObject, Invocation invocation, PageCondition page) throws Exception {
        PreparedStatement countStatement = null;
        ResultSet resultSet = null;
        try {
            String countSql = getCountSql(originSql);
            // 这个数据库连接不能关,后面的分页sql也要使用这个connection
            Connection connection = (Connection) invocation.getArgs()[0];
            countStatement = connection.prepareStatement(countSql);
            ParameterHandler parameterHandler = (ParameterHandler) metaObject.getValue(PARAMETER_HANDLER_KEY);
            parameterHandler.setParameters(countStatement);
            resultSet = countStatement.executeQuery();
            int totalCount = 0;
            if (resultSet.next()) {
                totalCount = resultSet.getInt(1);
            }
            page.setTotalCount(totalCount);
        } finally {
            if (resultSet != null) {
                resultSet.close();
            }
            if (countStatement != null) {
                countStatement.close();
            }
        }
    }

    /**
     * 这种方式比较简陋,使用时注意sql语句末的分号问题
     */
    private String getPageSql(String originSql, PageCondition page) {
        return new StringBuilder().append(originSql).append(" LIMIT ").
                append(page.getCurrentPage() * page.getPageSize()).append(",").append(page.getPageSize()).toString();
    }

    /**
     * 获取分页对象
     */
    private PageCondition getPage(BoundSql boundSql) {
        PageCondition page = null;
        // 获取执行方法的参数
        Object param = boundSql.getParameterObject();
        if (param instanceof Map) {
            Map paramMap = (Map) param;
            for (Iterator iterator = paramMap.values().iterator(); iterator.hasNext(); ) {
                Object obj = iterator.next();
                if (obj instanceof PageCondition) {
                    page = (PageCondition) obj;
                    break;
                }
            }
        } else if (param instanceof PageCondition) {
            page = (PageCondition) param;
        }
        return page;
    }

    /**
     * 这种方式比较简陋,使用时注意sql语句末的分号问题
     */
    private String getCountSql(String originSql) {
        return new StringBuilder().append("select count(*) from (").append(originSql).append(") t").toString();
    }

    public Object plugin(Object o) {
        return Plugin.wrap(o, this);
    }

    public void setProperties(Properties properties) {
        // 这里可以获取到拦截器配置的properties
    }
}

以上是分页拦截器简单的代码实现,它有什么作用呢,我们知道,我们在写分页sql的时候,我们首先要计算出满足查询条件的数据总数用于计算总页数,如果我们使用limit x offset y的写法,还需要计算出本次查询的偏移量,而使用这个分页拦截器,就可以省去这些步骤的编写,由分页拦截器帮助我们完成,这样就省去了不少开发工作量。示例中的写法只适用于MySQL分页的写法,其他数据库分页的写法可能与MySQL的写法不同,例如Oracle,我们可以为示例的分页拦截器扩展支持多方言,当然这不是本文的重点,也比较简单,有兴趣的读者可以自行研究实现。下面我们来看一下这个分页拦截器要怎么使用。

首先,Mybatis配置文件中增加拦截器的配置:

<plugins>
    <plugin interceptor="com.jd.plugin.dao.aop.PageInterceptor"/>
</plugins>

mapper:

public interface UserMapper {
    /**
     * 分页条件查询
     */
    List<User> listByCondition(@Param("page") PageCondition page, @Param("query") UserQuery query);
}

mapper配置文件:

<select id="listByCondition" parameterType="UserQuery" resultType="User">
    SELECT
        id, name, gender, age
    FROM user 
    <where>
        <if test="query.id != null and query.id > 0">
            AND id = #{query.id}
        <if/>
        <!-- 其他条件... -->
    <where/>
    ORDER BY id
</select>

调用UserMapper的listByCondtion方法传入PageCondtion条件(需要传入当前页currentPage和每页的数量pageSize)和其他我们需要的查询条件就可以完成分页功能,这样我们在mapper配置文件中就可以不用写分页语句,只关注我们查询sql自身的编写即可,分页拦截器会自动完成分页语句,并计算总数和总页数放入PageCondition对象中。

源码分析

下面我们正式开始Mybatis拦截器源码的分析,拦截器配置解析和调用入口的源码我们已经在之前文章中分析过,拦截器会被保存在Configuration对象中维护的一个InterceptorChain对象中,调用InterceptorChain的pluginAll方法来实现拦截器的调用:

public Object pluginAll(Object target) {
	for (Interceptor interceptor : interceptors) {
		/* 调用拦截器的plugin方法 */
		target = interceptor.plugin(target);
	}
	return target;
}

我们上文介绍的分页拦截器PageInterceptor实现了Interceptor接口,我们来看它的plugin方法:

public Object plugin(Object o) {
	/* 包装目标对象 */
	return Plugin.wrap(o, this);
}

这里的目标对象是什么呢?读者可以思考一下,我们的分页拦截器是基于修改原sql实现的,我们要为原sql拼接分页语句,我们什么时候修改原sql最合适呢,如果读者熟悉JDBC的API,肯定一下子就想到,在创建PreparedStatement之前,因为创建PreparedStatement需要指定sql语句,所以在创建之前修改最合适,Mybatis什么时候创建PreparedStatement呢,这个我们在之前的文章中已经分析过,在构建StatementHandler之后,调用它的prepare完成PreparedStatement的创建,所以这里的目标对象就是StatementHandler对象,而我们需要拦截的目标方法就是prepare方法。
Plugin:

public static Object wrap(Object target, Interceptor interceptor) {
	/* 获取注解配置 */
	Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
	Class<?> type = target.getClass();
	// 获取目标类的所有接口与signatureMap的key集合的交集
	Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
	if (interfaces.length > 0) {
		// 创建代理对象
		return Proxy.newProxyInstance(
			type.getClassLoader(),
			interfaces,
			new Plugin(target, interceptor, signatureMap));
	}
	return target;
}

Plugin:

private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
	// 获取拦截器的@Intercepts注解
	Intercepts interceptsAnnotation = interceptor.getClass().getAnnotation(Intercepts.class);
	// 没有注解@Intercepts抛出异常
	if (interceptsAnnotation == null) {
		throw new PluginException("No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());      
	}
	// 获取@Intercepts注解的value属性值,也就是@Signature注解数组
	Signature[] sigs = interceptsAnnotation.value();
	Map<Class<?>, Set<Method>> signatureMap = new HashMap<Class<?>, Set<Method>>();
	for (Signature sig : sigs) {
		Set<Method> methods = signatureMap.get(sig.type());
		if (methods == null) {
			methods = new HashSet<Method>();
			signatureMap.put(sig.type(), methods);
		}
		try {
			// type我们指定是的StatementHandler接口,method我们指定的是prepare方法,args我们指定的是数据库连接Connection(StatementHandler的prepare方法的第一个参数就是Connection)
			Method method = sig.type().getMethod(sig.method(), sig.args());
			methods.add(method);
		} catch (NoSuchMethodException e) {
			throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e, e);
		}
	}
	return signatureMap;
}

因为创建代理对象时使用的是JDK动态代理,Plugin类实现了InvocationHandler,所以执行时我们需要分析它的invoke方法:

public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
	try {
		Set<Method> methods = signatureMap.get(method.getDeclaringClass());
		if (methods != null && methods.contains(method)) {
			// 如果拦截器配置的方法与目标方法相匹配,则调用拦截器的intercept方法
			return interceptor.intercept(new Invocation(target, method, args));
		}
		// 不匹配的直接执行目标方法
		return method.invoke(target, args);
	} catch (Exception e) {
		throw ExceptionUtil.unwrapThrowable(e);
	}
}

分析到这里,我们再回头看我们的分页拦截器对intercept方法实现,相信读者已经可以清楚的理解方法的内容了。到这里,整个Mybatis拦截器的源码分析就完成了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值