Mybatis中插件的实现原理

// 这是一个插件的实现原理案例
public class InterceptorDemo {

    public static void main(String[] args) {
        // 创建SuperMan的代理对象,通过FlyInterceptor对它进行增强
        Flyable flyable = (Flyable) Plugin.wrap(new SuperMan(), new FlyInterceptor());
        flyable.fly();
    }

    @Intercepts({
            @Signature(type = Flyable.class, method = "fly", args = {})
    })
    static class FlyInterceptor implements Interceptor {
        @Override
        public Object intercept(Invocation invocation) throws Throwable {
            System.out.println("begin fly");
            Object target = invocation.getTarget();
            if (target instanceof Flyable) {
                Method method = invocation.getMethod();
                method.invoke(target);
            }
            System.out.println("stop fly");
            return null;
        }
    }

    interface Flyable {
        void fly();
    }

    static class SuperMan implements Flyable {

        @Override
        public void fly() {
            System.out.println("flying");
        }
    }
}

// 自定义一个慢SQL查询插件,拦截StatementHandler的查询方法
@Intercepts({
        @Signature(type = StatementHandler.class, method = "update", args = {Statement.class}),
        @Signature(type = StatementHandler.class, method = "query", args = {Statement.class, ResultHandler.class})
})
@Slf4j
public class SlowSqlInterceptor implements Interceptor {
    private final long slowSqlTime = 10;

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = statementHandler.getBoundSql();
        String sql = boundSql.getSql();
        long startTime = System.currentTimeMillis();
        Object result = invocation.proceed();
        long endTime = System.currentTimeMillis();
        long executeTime = endTime - startTime;
        if (executeTime > slowSqlTime) {
            log.debug("{}", sql);
            log.warn("slow sql,execute time:{}ms", executeTime);
        }
        return result;
    }
}

// 拦截mybatis四大组件的类
public interface Interceptor {

    // 拦截方法
    Object intercept(Invocation invocation) throws Throwable;

    /**
     * 将目标对象包装,根据目当前拦截器中的注解信息来给目标对象创建代理对象
     * 目标对象的类型为: ParameterHandler ResultSetHandler StatementHandler Executor
     */
    default Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    // 给拦截器提供的属性
    default void setProperties(Properties properties) {
    }

    // 方法调用器
    public class Invocation {
        // 目标对象
        private final Object target;
        // 目标方法
        private final Method method;
        // 方法参数
        private final Object[] args;

        public Object proceed() throws InvocationTargetException, IllegalAccessException {
            return method.invoke(target, args);
        }

    }

}

// 插件的工具类
public class Plugin implements InvocationHandler {
    // 包装的目标对象
    private final Object target;
    // 拦截器对象
    private final Interceptor interceptor;
    // 所有方法的签名信息
    private final Map<Class<?>, Set<Method>> signatureMap;


    // 使用指定的拦截器增强目标对象
    public static Object wrap(Object target, Interceptor interceptor) {
        // 获取拦截器中标注的方法签名信息
        // 获取@Intercepts注解中标注的方法签名信息(具体拦截的方法信息)
        Map<Class<?>, Set<Method>> signatureMap = this.getSignatureMap(interceptor);
        // 获取目标类型
        Class<?> type = target.getClass();
        // 获取增强的目标类实现的接口,如果方法签名中包含这些接口,表示要给这些接口生成代理对象
        Class<?>[] interfaces = this.getAllInterfaces(type, signatureMap);
        // 生成代理对象
        if (interfaces.length > 0) {
            return Proxy.newProxyInstance(type.getClassLoader(), interfaces, new Plugin(target, interceptor, signatureMap));
        }
        // 如果不存在接口,则不需要创建代理对象
        return target;
    }

    // 代理对象需要执行的方法,当前类是一个InvocationHandler
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        try {
            // 获取该目标类要增强的所有方法
            Set<Method> methods = signatureMap.get(method.getDeclaringClass());
            // 如果需要增强的方法中包含当前正在执行的方法
            if (methods != null && methods.contains(method)) {
                // 就要对当前方法进行增强拦截
                return interceptor.intercept(new Invocation(target, method, args));
            }
            // 如果当前执行的方法不需要增强,直接直接方法
            return method.invoke(target, args);
        } catch (Exception e) {
            throw ExceptionUtil.unwrapThrowable(e);
        }
    }

    // 获取@Intercepts注解中标注的方法签名信息(具体拦截的方法信息)
    private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
        // 获取拦截器中的注解信息
        Intercepts interceptsAnnotation = interceptor.getClass().getAnnotation(Intercepts.class);
        // 拦截器中必须包含@Intercepts注解
        if (interceptsAnnotation == null) {
            throw new PluginException("No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());
        }
        // 获取注解中的签名信息
        Signature[] sigs = interceptsAnnotation.value();
        Map<Class<?>, Set<Method>> signatureMap = new HashMap<>();
        // 遍历所有方法签名
        for (Signature sig : sigs) {
            // 为什么要用Set集合,因为@Signature注解可以重复,需要拦截相同类的多个方法的情况就需要Set
            Set<Method> methods = MapUtil.computeIfAbsent(signatureMap, sig.type(), k -> new HashSet<>());
            try {
                // 获取到具体的方法对象
                Method method = sig.type().getMethod(sig.method(), sig.args());
                // 保存拦截的方法对象
                methods.add(method);
            }
            // 如果该对象中不存在指定的方法,抛出异常
            catch (NoSuchMethodException e) {
                throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e, e);
            }
        }
        return signatureMap;
    }

    // 获取增强的目标类实现的接口,如果方法签名中包含这些接口,表示要给这些接口生成代理对象
    private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
        Set<Class<?>> interfaces = new HashSet<>();
        while (type != null) {
            // 获取需要拦截的目标类实现的所有接口
            for (Class<?> c : type.getInterfaces()) {
                // 在方法签名Map中存在这个接口(表示要对这个接口增强)
                if (signatureMap.containsKey(c)) {
                    // 将该接口保存下来,用给它生成代理对象
                    interfaces.add(c);
                }
            }
            // 继续找父接口
            type = type.getSuperclass();
        }
        return interfaces.toArray(new Class<?>[0]);
    }

    // 方法签名注解,就是用于描述一个具体的方法
    public @interface Signature {
        // 拦截器的目标类
        Class<?> type();

        // 拦截的目标类中的方法
        String method();

        // 拦截的方法的参数类型
        Class<?>[] args();
    }
}

// 拦截器,执行链
public class InterceptorChain {
    // 所有的拦截器
    public List<Interceptor> interceptors = new ArrayList<>();

    // 执行所有拦截器,最终返回方法执行结果
    public Object pluginAll(Object target) {
        for (Interceptor interceptor : interceptors) {
            // 将目标方法进行增强,创建目标对象的代理对象
            // 通过代理对象进行拦截目标方法,符合条件则执行拦截器
            // 执行目标方法,返回最终结果
            target = interceptor.plugin(target);
        }
        // 执行的结果
        return target;
    }

    // 添加拦截器
    public void addInterceptor(Interceptor interceptor) {
        interceptors.add(interceptor);
    }

    // 获取所有拦截器
    public List<Interceptor> getInterceptors() {
        return Collections.unmodifiableList(interceptors);
    }

}

// 核心配置类组件
public class Configuration {
    // 拦截器的执行链
    public InterceptorChain interceptorChain = new InterceptorChain();

    // 由下面的方法可知,插件可以拦截的对象由四个
    // 就是Mybatis的四大组件,ParameterHandler,ResultSetHandler,StatementHandler,Executor
    // 创建Parameter参数处理器
    public ParameterHandler newParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
        ParameterHandler parameterHandler = mappedStatement.getLang().createParameterHandler(mappedStatement, parameterObject, boundSql);
        // 执行所有拦截器
        return (ParameterHandler) interceptorChain.pluginAll(parameterHandler);
    }

    // 创建ResultSet结果集处理器
    public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, ParameterHandler parameterHandler, ResultHandler resultHandler, BoundSql boundSql) {
        ResultSetHandler resultSetHandler = new DefaultResultSetHandler(executor, mappedStatement, parameterHandler, resultHandler, boundSql, rowBounds);
        // 执行所有拦截器
        return (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler);
    }

    // 创建Statement处理器
    public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
        StatementHandler statementHandler = new RoutingStatementHandler(executor, mappedStatement, parameterObject, rowBounds, resultHandler, boundSql);
        // 执行所有拦截器
        return (StatementHandler) interceptorChain.pluginAll(statementHandler);
    }

    // 创建默认的执行器
    public Executor newExecutor(Transaction transaction) {
        return newExecutor(transaction, defaultExecutorType);
    }

    // 创建指定类型的执行器
    public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
        executorType = executorType == null ? defaultExecutorType : executorType;
        Executor executor;
        if (ExecutorType.BATCH == executorType) {
            executor = new Executor.BatchExecutor(this, transaction);
        } else if (ExecutorType.REUSE == executorType) {
            executor = new Executor.ReuseExecutor(this, transaction);
        } else {
            executor = new Executor.SimpleExecutor(this, transaction);
        }
        // 如果开启了二级缓存,使用了装饰器模式
        if (cacheEnabled) {
            executor = new Executor.CachingExecutor(executor);
        }
        // 执行所有拦截器
        return (Executor) interceptorChain.pluginAll(executor);
    }
}

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值