MyBatis Source Code (4): Interceptors

I. Introduction

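As the name suggests, an interceptor intercepts calls and can modify method arguments, return values, and so on. MyBatis provides the Interceptor interface so that developers can intercept calls and build custom features such as pagination.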

Interceptor is defined as follows:

public interface Interceptor {

  // invocation wraps the intercepted target object, the intercepted method, and the method arguments
  Object intercept(Invocation invocation) throws Throwable;

  default Object plugin(Object target) {
    return Plugin.wrap(target, this);
  }

  default void setProperties(Properties properties) {
    // NOP
  }

}
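To make the invocation argument concrete, the sketch below is a hypothetical, JDK-only stand-in for org.apache.ibatis.plugin.Invocation: it holds the target, the Method and the argument array, and proceed() simply calls method.invoke(target, args), which as far as I know is what the real class does. The class name InvocationSketch and the String.concat demo are illustrative assumptions. The demo also shows an interceptor rewriting an argument before letting the call continue.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

// Hypothetical stand-in for org.apache.ibatis.plugin.Invocation (not the real class).
final class InvocationSketch {
    private final Object target;   // the object being intercepted
    private final Method method;   // the intercepted method
    private final Object[] args;   // the arguments of this particular call

    InvocationSketch(Object target, Method method, Object[] args) {
        this.target = target;
        this.method = method;
        this.args = args;
    }

    Object getTarget() { return target; }
    Method getMethod() { return method; }
    Object[] getArgs() { return args; }

    // Continue with the original call by reflectively invoking the method on the target.
    Object proceed() throws InvocationTargetException, IllegalAccessException {
        return method.invoke(target, args);
    }

    static String demo() {
        try {
            Method concat = String.class.getMethod("concat", String.class);
            Object[] args = {"bar"};
            InvocationSketch inv = new InvocationSketch("foo", concat, args);
            inv.getArgs()[0] = "baz";        // an interceptor may rewrite an argument in place...
            return (String) inv.proceed();   // ...before the original call runs: "foobaz"
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

Because getArgs() hands back the live array, whatever the interceptor writes into it before proceed() is what the real method receives.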

The figure below shows the core flow of a MyBatis query.

Executor.query() consists of four main steps:

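  1. The StatementHandler builds the Statement and initializes it;
  2. The ParameterHandler sets the parameter values on the Statement;
  3. The database is queried;
  4. The raw ResultSet returned by the SQL execution is converted into the user-defined type.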

MyBatis interceptors can intercept calls to Executor methods as well as steps 1, 2 and 4 above; see org.apache.ibatis.session.Configuration:

public class Configuration {

  public ParameterHandler newParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
    ParameterHandler parameterHandler = mappedStatement.getLang().createParameterHandler(mappedStatement, parameterObject, boundSql);
    parameterHandler = (ParameterHandler) interceptorChain.pluginAll(parameterHandler);
    return parameterHandler;
  }

  public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, ParameterHandler parameterHandler,
      ResultHandler resultHandler, BoundSql boundSql) {
    ResultSetHandler resultSetHandler = new DefaultResultSetHandler(executor, mappedStatement, parameterHandler, resultHandler, boundSql, rowBounds);
    resultSetHandler = (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler);
    return resultSetHandler;
  }

  public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
    StatementHandler statementHandler = new RoutingStatementHandler(executor, mappedStatement, parameterObject, rowBounds, resultHandler, boundSql);
    statementHandler = (StatementHandler) interceptorChain.pluginAll(statementHandler);
    return statementHandler;
  }

  public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
    executorType = executorType == null ? defaultExecutorType : executorType;
    executorType = executorType == null ? ExecutorType.SIMPLE : executorType;
    Executor executor;
    if (ExecutorType.BATCH == executorType) {
      executor = new BatchExecutor(this, transaction);
    } else if (ExecutorType.REUSE == executorType) {
      executor = new ReuseExecutor(this, transaction);
    } else {
      executor = new SimpleExecutor(this, transaction);
    }
    if (cacheEnabled) {
      executor = new CachingExecutor(executor);
    }
    executor = (Executor) interceptorChain.pluginAll(executor);
    return executor;
  }
}

II. Execution Flow Analysis

InterceptorChain: wraps the target object with each configured Interceptor in turn, producing the proxied object.

public class InterceptorChain {

  private final List<Interceptor> interceptors = new ArrayList<>();

  public Object pluginAll(Object target) { // wrap the target object with every interceptor in turn
    for (Interceptor interceptor : interceptors) {
      target = interceptor.plugin(target);
    }
    return target;
  }

  public void addInterceptor(Interceptor interceptor) {
    interceptors.add(interceptor);
  }

  public List<Interceptor> getInterceptors() {
    return Collections.unmodifiableList(interceptors);
  }

}
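Because pluginAll() passes each wrapped result on to the next interceptor, the interceptor registered last becomes the outermost proxy and therefore runs first. The JDK-only sketch below models that ordering; the Service interface and the lambda-based "interceptors" are hypothetical stand-ins for MyBatis's Plugin proxies.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

// Hypothetical model of InterceptorChain.pluginAll(): each "interceptor" is a
// function that takes a target and returns a wrapper around it.
final class ChainOrderSketch {
    interface Service { String call(String in); }

    // Mirrors pluginAll(): target = interceptor.plugin(target) for each interceptor, in order.
    static Service pluginAll(Service target, List<UnaryOperator<Service>> interceptors) {
        for (UnaryOperator<Service> interceptor : interceptors) {
            target = interceptor.apply(target);
        }
        return target;
    }

    // Builds an interceptor that records its name, then delegates to the wrapped object.
    static UnaryOperator<Service> named(String name, List<String> log) {
        return next -> in -> {
            log.add(name);
            return next.call(in);
        };
    }

    static List<String> demo() {
        List<String> log = new ArrayList<>();
        Service raw = in -> { log.add("target"); return in; };
        Service wrapped = pluginAll(raw, List.of(named("A", log), named("B", log)));
        wrapped.call("x");
        return log; // B was registered last, so it is the outermost wrapper and runs first
    }
}
```

ChainOrderSketch.demo() records B, then A, then the target: registration order is the reverse of execution order.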

Interceptor: the interceptor itself; users must implement its intercept() method.

public interface Interceptor {

  Object intercept(Invocation invocation) throws Throwable;

  default Object plugin(Object target) {  // wrap the target object in a proxy
    return Plugin.wrap(target, this);
  }

  default void setProperties(Properties properties) {
    // NOP
  }

}

Plugin: implements InvocationHandler, so MyBatis interceptors work by using JDK dynamic proxies to wrap calls to the target object.

  public static Object wrap(Object target, Interceptor interceptor) {
    Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor); // parse the interceptor's @Intercepts annotation: key = intercepted interface, value = methods of that interface to intercept
    Class<?> type = target.getClass();
    Class<?>[] interfaces = getAllInterfaces(type, signatureMap); // collect the interfaces of the target class (and its superclasses) that this interceptor declares
    if (interfaces.length > 0) {
      return Proxy.newProxyInstance(
          type.getClassLoader(),
          interfaces,
          new Plugin(target, interceptor, signatureMap));
    }
    return target;
  }
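The two helpers that wrap() relies on can be approximated as follows. This is a sketch, not MyBatis code: addSignature() is a hypothetical helper playing the role of one @Signature entry, and the demo declares List.size() and Runnable.run() as the intercepted methods. Note that Class.getInterfaces() returns only the interfaces a class declares directly, which is why the loop walks up the superclass chain.

```java
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Sketch of what Plugin.getSignatureMap()/getAllInterfaces() compute, using only the JDK.
final class SignatureSketch {
    // Like one @Signature(type = ..., method = ..., args = ...): resolve the Method and group it by interface.
    static void addSignature(Map<Class<?>, Set<Method>> map, Class<?> type, String name, Class<?>... args)
            throws NoSuchMethodException {
        map.computeIfAbsent(type, k -> new HashSet<>()).add(type.getMethod(name, args));
    }

    // Walk the target class and its superclasses, keeping only the directly declared
    // interfaces that appear as keys in the signature map (super-interfaces are not expanded).
    static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
        Set<Class<?>> interfaces = new HashSet<>();
        while (type != null) {
            for (Class<?> c : type.getInterfaces()) {
                if (signatureMap.containsKey(c)) {
                    interfaces.add(c);
                }
            }
            type = type.getSuperclass();
        }
        return interfaces.toArray(new Class<?>[0]);
    }

    static Map<Class<?>, Set<Method>> demoMap() {
        Map<Class<?>, Set<Method>> map = new HashMap<>();
        try {
            addSignature(map, List.class, "size");
            addSignature(map, Runnable.class, "run");
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
        return map;
    }
}
```

For an ArrayList target only List matches, so a proxy would implement List alone; for a String target nothing matches, the array is empty, and wrap() returns the target unchanged.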

Next, let's look at Plugin.invoke():

  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    try {
      Set<Method> methods = signatureMap.get(method.getDeclaringClass());
      if (methods != null && methods.contains(method)) {
        return interceptor.intercept(new Invocation(target, method, args)); // call the interceptor
      }
      return method.invoke(target, args);
    } catch (Exception e) {
      throw ExceptionUtil.unwrapThrowable(e);
    }
  }
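Putting wrap() and invoke() together, here is a self-contained imitation built on java.lang.reflect.Proxy. Greeter, SimpleGreeter and MiniInterceptor are hypothetical names; to keep the sketch short, the "interceptor" receives the target, method and arguments directly instead of an Invocation object.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Set;

// Self-contained imitation of Plugin.wrap()/Plugin.invoke() built on a plain JDK proxy.
// Greeter, SimpleGreeter and MiniInterceptor are hypothetical names, not MyBatis types.
final class MiniPluginSketch {
    interface Greeter {
        String greet(String name);
        String farewell(String name);
    }

    static final class SimpleGreeter implements Greeter {
        @Override public String greet(String name) { return "hello " + name; }
        @Override public String farewell(String name) { return "bye " + name; }
    }

    // Plays the role of Interceptor.intercept(Invocation).
    interface MiniInterceptor {
        Object intercept(Object target, Method method, Object[] args) throws Exception;
    }

    // Like Plugin.wrap() + Plugin.invoke(): only methods in 'signature' reach the interceptor.
    static <T> T wrap(T target, Class<T> iface, Set<Method> signature, MiniInterceptor interceptor) {
        Object proxy = Proxy.newProxyInstance(
                iface.getClassLoader(),
                new Class<?>[] {iface},
                (p, method, args) -> {
                    try {
                        if (signature.contains(method)) {
                            return interceptor.intercept(target, method, args);
                        }
                        return method.invoke(target, args); // not declared: call straight through
                    } catch (InvocationTargetException e) {
                        throw e.getTargetException(); // unwrap, similar to ExceptionUtil.unwrapThrowable
                    }
                });
        return iface.cast(proxy);
    }

    static String[] demo() {
        try {
            Set<Method> signature = Set.of(Greeter.class.getMethod("greet", String.class));
            // Upper-cases the return value of intercepted calls only.
            MiniInterceptor upper = (t, m, a) -> ((String) m.invoke(t, a)).toUpperCase();
            Greeter g = wrap(new SimpleGreeter(), Greeter.class, signature, upper);
            return new String[] {g.greet("bob"), g.farewell("bob")};
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

Calling g.greet("bob") goes through the interceptor and returns "HELLO BOB", while g.farewell("bob") is not in the signature set and returns "bye bob" unchanged, which mirrors the methods.contains(method) check in Plugin.invoke().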

This completes the analysis of the MyBatis interceptor execution flow. Below is a fairly simple demo of a MyBatis pagination plugin:

package com.rango.interceptor;

import com.rango.dto.PageRequest;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;

import java.lang.reflect.InvocationTargetException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Map;

@Slf4j
@Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class, Integer.class }) })
public class PageInterceptor implements Interceptor {
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = statementHandler.getBoundSql();
        PageRequest pageRequest = getPageRequest(boundSql);
        if (pageRequest == null) { // not a paginated request: continue with the original call
            return invocation.proceed();
        }
        setCount2PageRequest(invocation, boundSql.getSql(), pageRequest); // run the count query and write the total into the pageRequest argument
        return executePageQuery(statementHandler, invocation, pageRequest); // rewrite originalSql -> pageSql, then proceed
    }

    private Object executePageQuery(StatementHandler statementHandler, Invocation invocation, PageRequest pageRequest) throws InvocationTargetException, IllegalAccessException {
        String pageSql = buildPageSql(statementHandler.getBoundSql().getSql(), pageRequest);

        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        metaObject.setValue("boundSql.sql", pageSql);
        return invocation.proceed();
    }

    private String buildPageSql(String originalSql, PageRequest pageRequest) {
        return new StringBuilder().append(originalSql).append(" limit ")
                .append(pageRequest.getPageSize() * (pageRequest.getPageNo() - 1))
                .append(",").append(pageRequest.getPageSize())
                .toString();
    }

    private void setCount2PageRequest(Invocation invocation, String originalSql, PageRequest pageRequest) throws SQLException {
        int count = getCount(invocation, originalSql);
        pageRequest.setTotal(count);
        // ceiling division: 40 rows at 20 per page is 2 pages, not 3
        pageRequest.setTotalPage((count + pageRequest.getPageSize() - 1) / pageRequest.getPageSize());
    }

    private int getCount(Invocation invocation, String originSql) throws SQLException {
        // The Connection is owned by MyBatis, so it must not be closed here
        Connection connection = (Connection) invocation.getArgs()[0];
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        ParameterHandler parameterHandler = (ParameterHandler) metaObject.getValue("parameterHandler");
        // try-with-resources closes the ResultSet first, then the PreparedStatement,
        // and lets SQL errors propagate instead of silently returning 0
        try (PreparedStatement ps = connection.prepareStatement(buildCountSql(originSql))) {
            parameterHandler.setParameters(ps); // bind the same parameters as the original query
            try (ResultSet resultSet = ps.executeQuery()) {
                return resultSet.next() ? resultSet.getInt(1) : 0;
            }
        }
    }

    private String buildCountSql(String originalSql) {
        return "select count(*) from (" + originalSql + " ) as temp";
    }

    private PageRequest getPageRequest(BoundSql boundSql) {
        Object parameterObject = boundSql.getParameterObject();
        if (parameterObject instanceof Map) {
            Map paramMap = (Map) parameterObject;
            for (Object value : paramMap.values()) {
                if (value instanceof PageRequest) {
                    return (PageRequest) value;
                }
            }
        } else if (parameterObject instanceof PageRequest) {
            return (PageRequest) parameterObject;
        }
        return null;
    }
}

The PageRequest parameter class used by the interceptor:

package com.rango.dto;

import lombok.Data;

@Data
public class PageRequest {

    private int pageNo = 1;

    private int pageSize = 20;

    private int total;     // total number of records

    private int totalPage; // total number of pages

    public PageRequest() {
    }

    public PageRequest(int pageNo, int pageSize) {
        this.pageNo = pageNo;
        this.pageSize = pageSize;
    }

}
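The pagination arithmetic used by the demo can be checked in isolation. The PageMath class below is a hypothetical helper, not part of the demo; it assumes MySQL-style "limit offset,count" syntax and 1-based page numbers. For the page count it uses ceiling division, because the simpler count / pageSize + 1 reports one page too many whenever the total is an exact multiple of pageSize (40 rows at 20 per page would give 3 instead of 2).

```java
// Hypothetical helper mirroring the SQL and arithmetic in PageInterceptor
// (MySQL-style "limit offset,count", 1-based page numbers).
final class PageMath {
    // Offset of the first row on page pageNo.
    static int offset(int pageNo, int pageSize) {
        return pageSize * (pageNo - 1);
    }

    // Ceiling division: 40 rows -> 2 pages, 41 rows -> 3 pages, 0 rows -> 0 pages (pageSize 20).
    static int totalPages(int total, int pageSize) {
        return (total + pageSize - 1) / pageSize;
    }

    // Same shape as buildPageSql(): append the limit clause to the original SQL.
    static String pageSql(String originalSql, int pageNo, int pageSize) {
        return originalSql + " limit " + offset(pageNo, pageSize) + "," + pageSize;
    }

    // Same shape as buildCountSql(): count the rows of the original query via a subquery.
    static String countSql(String originalSql) {
        return "select count(*) from (" + originalSql + " ) as temp";
    }
}
```

For example, page 3 with 20 rows per page starts at offset 40, so pageSql("select * from user", 3, 20) yields "select * from user limit 40,20".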
