Chapter 7 MybatisPlus Plugins


1. Overview of the Plugin Mechanism

MybatisPlusInterceptor: the core plugin

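Through plugins (Interceptor), MybatisPlus can intercept calls to methods of MyBatis's four core objects (Executor, StatementHandler, ParameterHandler, ResultSetHandler) and dynamically modify the data involved as needed.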

MybatisPlusInterceptor is the core plugin. It currently proxies the Executor#query, Executor#update and StatementHandler#prepare methods.

MybatisPlusInterceptor has the field private List<InnerInterceptor> interceptors = new ArrayList<>(); this collection stores the inner interceptor plugins registered by the user.

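MybatisPlusInterceptor#addInnerInterceptor() registers an inner interceptor. MybatisPlusInterceptor#plugin() creates a proxy for the target object: Plugin.wrap() wraps the target object in a proxy and returns that proxy to the caller (only Executor and StatementHandler targets are wrapped). MybatisPlusInterceptor#intercept() intercepts the target method and runs the inner interceptors' hooks.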

MybatisPlusInterceptor#intercept() handles three kinds of operations:

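  1. For a query, it iterates over all inner interceptors and, for each one, first evaluates !query.willDoQuery(executor, ms, parameter, rowBounds, resultHandler, boundSql) to decide whether the query should go ahead
    1. If query.willDoQuery() returns false, the query is skipped and Collections.emptyList() is returned immediately
    2. Otherwise that interceptor's query.beforeQuery(executor, ms, parameter, rowBounds, resultHandler, boundSql); is executed
  2. For an update, it iterates over all inner interceptors and, for each one, first evaluates !update.willDoUpdate(executor, ms, parameter)
    1. If update.willDoUpdate() returns false, the update is skipped and -1 is returned immediately
    2. Otherwise that interceptor's update.beforeUpdate(executor, ms, parameter); is executed
  3. Otherwise the target is a StatementHandler: it iterates over all inner interceptors and calls innerInterceptor.beforePrepare(sh, connections, transactionTimeout);
  4. The update and StatementHandler branches end with return invocation.proceed();, which runs the intercepted method and returns its result to the caller. The query branch returns earlier: after the loop it builds the CacheKey itself and returns executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql) directly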
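To make the control flow of the query branch concrete, here is a dependency-free sketch under simplifying assumptions: `MiniInnerInterceptor` and `QueryDispatch` are hypothetical stand-ins for `InnerInterceptor` and `MybatisPlusInterceptor#intercept()`, the `BoundSql` is reduced to a `StringBuilder`, and the `Executor` to a `Function`. It models only the short-circuit loop, not the real signatures.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;

// Hypothetical stand-in for InnerInterceptor: only the two query hooks are modeled,
// and the "BoundSql" is reduced to a mutable StringBuilder holding the SQL text.
interface MiniInnerInterceptor {
    // Return false to skip the real query entirely.
    default boolean willDoQuery(StringBuilder sql) {
        return true;
    }

    // Chance to rewrite the SQL before the query runs.
    default void beforeQuery(StringBuilder sql) {
        // do nothing
    }
}

class QueryDispatch {
    private final List<MiniInnerInterceptor> interceptors = new ArrayList<>();

    void addInnerInterceptor(MiniInnerInterceptor interceptor) {
        interceptors.add(interceptor);
    }

    // Same loop shape as the SELECT branch of intercept(): each interceptor is asked
    // willDoQuery() and, if it agrees, immediately gets its beforeQuery() call.
    List<String> query(String originalSql, Function<String, List<String>> executor) {
        StringBuilder sql = new StringBuilder(originalSql);
        for (MiniInnerInterceptor interceptor : interceptors) {
            if (!interceptor.willDoQuery(sql)) {
                return Collections.emptyList(); // short-circuit: the executor is never called
            }
            interceptor.beforeQuery(sql);
        }
        return executor.apply(sql.toString());
    }

    public static void main(String[] args) {
        QueryDispatch dispatch = new QueryDispatch();
        dispatch.addInnerInterceptor(new MiniInnerInterceptor() {
            @Override
            public void beforeQuery(StringBuilder sql) {
                sql.append(" LIMIT 3"); // rewrite the SQL, as a pagination plugin would
            }
        });
        System.out.println(dispatch.query("SELECT * FROM t_user", Collections::singletonList));
    }
}
```

Note the interleaving: each interceptor's beforeQuery() runs before the next interceptor's willDoQuery() is consulted, so the order in which inner interceptors are registered matters.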
/**
 * @author miemie
 * @since 3.4.0
 */
@SuppressWarnings({"rawtypes"})
@Intercepts(
    {
        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class}),
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
    }
)
public class MybatisPlusInterceptor implements Interceptor {

    @Setter
    private List<InnerInterceptor> interceptors = new ArrayList<>();

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object target = invocation.getTarget();
        Object[] args = invocation.getArgs();
        if (target instanceof Executor) {
            final Executor executor = (Executor) target;
            Object parameter = args[1];
            boolean isUpdate = args.length == 2;
            MappedStatement ms = (MappedStatement) args[0];
            if (!isUpdate && ms.getSqlCommandType() == SqlCommandType.SELECT) {
                RowBounds rowBounds = (RowBounds) args[2];
                ResultHandler resultHandler = (ResultHandler) args[3];
                BoundSql boundSql;
                if (args.length == 4) {
                    boundSql = ms.getBoundSql(parameter);
                } else {
                    // Almost never reached, unless query[args[6]] is called through an Executor proxy
                    boundSql = (BoundSql) args[5];
                }
                for (InnerInterceptor query : interceptors) {
                    if (!query.willDoQuery(executor, ms, parameter, rowBounds, resultHandler, boundSql)) {
                        return Collections.emptyList();
                    }
                    query.beforeQuery(executor, ms, parameter, rowBounds, resultHandler, boundSql);
                }
                CacheKey cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql);
                return executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql);
            } else if (isUpdate) {
                for (InnerInterceptor update : interceptors) {
                    if (!update.willDoUpdate(executor, ms, parameter)) {
                        return -1;
                    }
                    update.beforeUpdate(executor, ms, parameter);
                }
            }
        } else {
            // StatementHandler
            final StatementHandler sh = (StatementHandler) target;
            Connection connections = (Connection) args[0];
            Integer transactionTimeout = (Integer) args[1];
            for (InnerInterceptor innerInterceptor : interceptors) {
                innerInterceptor.beforePrepare(sh, connections, transactionTimeout);
            }
        }
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        if (target instanceof Executor || target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    public void addInnerInterceptor(InnerInterceptor innerInterceptor) {
        this.interceptors.add(innerInterceptor);
    }

    public List<InnerInterceptor> getInterceptors() {
        return Collections.unmodifiableList(interceptors);
    }

    /**
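     * Uses internal rules; taking the pagination plugin as an example:
     * <p>
     * - key: "@page" ,value: "com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor"
     * - key: "page:limit" ,value: "100"
     * <p>
     * Note 1: a key starting with "@" declares an `InnerInterceptor` to be assembled, and the part after "@" (here "page") is its alias;
     * the value is the fully qualified class name of that `InnerInterceptor`
     * Note 2: a key starting with the alias defined above plus ':' means its `value` is a property value to set on that `InnerInterceptor`
     * <p>
     * An `InnerInterceptor` needs an alias even if it has no properties to configure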
     */
    @Override
    public void setProperties(Properties properties) {
        PropertyMapper pm = PropertyMapper.newInstance(properties);
        Map<String, Properties> group = pm.group(StringPool.AT);
        group.forEach((k, v) -> {
            InnerInterceptor innerInterceptor = ClassUtils.newInstance(k);
            innerInterceptor.setProperties(v);
            addInnerInterceptor(innerInterceptor);
        });
    }
}
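The `@alias` / `alias:property` convention from the setProperties() Javadoc can be illustrated with plain `java.util.Properties`. This is a hypothetical sketch: `AliasGrouping` is not MybatisPlus's `PropertyMapper`, and its behavior is inferred from the Javadoc above. It returns, for every interceptor class name, the properties addressed to its alias, which is the shape setProperties() iterates over before instantiating each class with ClassUtils.newInstance().

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;

class AliasGrouping {
    // Turns "@alias=className" and "alias:prop=value" entries into className -> properties.
    static Map<String, Properties> group(Properties source) {
        // First pass: collect the aliases declared with a leading "@".
        Map<String, String> aliasToClass = new HashMap<>();
        for (String key : source.stringPropertyNames()) {
            if (key.startsWith("@")) {
                aliasToClass.put(key.substring(1), source.getProperty(key));
            }
        }
        // Second pass: attach every "alias:prop" entry to the class behind that alias.
        Map<String, Properties> result = new HashMap<>();
        aliasToClass.forEach((alias, className) -> {
            Properties props = new Properties();
            String prefix = alias + ":";
            for (String key : source.stringPropertyNames()) {
                if (key.startsWith(prefix)) {
                    props.setProperty(key.substring(prefix.length()), source.getProperty(key));
                }
            }
            result.put(className, props); // an alias without properties still yields an (empty) entry
        });
        return result;
    }

    public static void main(String[] args) {
        Properties p = new Properties();
        p.setProperty("@page", "com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor");
        p.setProperty("page:limit", "100");
        System.out.println(group(p));
    }
}
```

This also shows why the Javadoc insists on an alias even without properties: the "@alias" entry is the only place the class name is declared.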

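Plugin#wrap() takes two parameters: Object target, the target object, and Interceptor interceptor, the plugin. wrap() combines the two into a proxy object and returns it to us (or returns the target unchanged if none of its interfaces appears in the plugin's signature map).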

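Plugin#invoke() takes three parameters: Object proxy, the proxy object; Method method, the method being called; Object[] args, its arguments. If the method is listed in the signature map, invoke() hands the call to interceptor.intercept(new Invocation(target, method, args)); otherwise it calls the target method directly. Either way, it returns the result to the caller.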

/**
 * @author Clinton Begin
 */
public class Plugin implements InvocationHandler {

  private final Object target;
  private final Interceptor interceptor;
  private final Map<Class<?>, Set<Method>> signatureMap;

  private Plugin(Object target, Interceptor interceptor, Map<Class<?>, Set<Method>> signatureMap) {
    this.target = target;
    this.interceptor = interceptor;
    this.signatureMap = signatureMap;
  }

  public static Object wrap(Object target, Interceptor interceptor) {
    Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
    Class<?> type = target.getClass();
    Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
    if (interfaces.length > 0) {
      return Proxy.newProxyInstance(
          type.getClassLoader(),
          interfaces,
          new Plugin(target, interceptor, signatureMap));
    }
    return target;
  }

  @Override
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    try {
      Set<Method> methods = signatureMap.get(method.getDeclaringClass());
      if (methods != null && methods.contains(method)) {
        return interceptor.intercept(new Invocation(target, method, args));
      }
      return method.invoke(target, args);
    } catch (Exception e) {
      throw ExceptionUtil.unwrapThrowable(e);
    }
  }
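Plugin#wrap() and Plugin#invoke() are a standard JDK dynamic proxy. The following is a stripped-down, hypothetical model (`Greeter`, `MiniPlugin` and `MiniInterceptor` are illustrative names): the intercepted methods are passed in as a `Set<Method>` instead of being read from `@Intercepts`/`@Signature`, the proxied interface is given explicitly instead of being collected by getAllInterfaces(), and exception unwrapping is reduced to unwrapping `InvocationTargetException`.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Collections;
import java.util.Set;

// Hypothetical target interface (MyBatis would use Executor or StatementHandler here).
interface Greeter {
    String greet(String name);

    String farewell(String name);
}

// Hypothetical, stripped-down model of org.apache.ibatis.plugin.Plugin.
class MiniPlugin implements InvocationHandler {

    // Plays the role of Interceptor#intercept(Invocation).
    interface MiniInterceptor {
        Object intercept(Object target, Method method, Object[] args) throws Throwable;
    }

    private final Object target;
    private final MiniInterceptor interceptor;
    private final Set<Method> signatures; // methods to intercept, like the @Signature list

    private MiniPlugin(Object target, MiniInterceptor interceptor, Set<Method> signatures) {
        this.target = target;
        this.interceptor = interceptor;
        this.signatures = signatures;
    }

    // Like Plugin.wrap(): target + interceptor in, JDK dynamic proxy out.
    static <T> T wrap(T target, Class<T> iface, MiniInterceptor interceptor, Set<Method> signatures) {
        Object proxy = Proxy.newProxyInstance(
                iface.getClassLoader(), new Class<?>[]{iface}, new MiniPlugin(target, interceptor, signatures));
        return iface.cast(proxy);
    }

    // Resolves a method the way a @Signature(type, method, args) entry describes it.
    static Method signature(Class<?> type, String name, Class<?>... args) {
        try {
            return type.getMethod(name, args);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException(e);
        }
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        try {
            if (signatures.contains(method)) {
                return interceptor.intercept(target, method, args); // listed: go through the interceptor
            }
            return method.invoke(target, args); // not listed: call the target directly
        } catch (InvocationTargetException e) {
            throw e.getCause(); // simplified version of ExceptionUtil.unwrapThrowable
        }
    }

    public static void main(String[] args) {
        Greeter plain = new Greeter() {
            @Override public String greet(String name) { return "hi " + name; }
            @Override public String farewell(String name) { return "bye " + name; }
        };
        Greeter proxied = wrap(plain, Greeter.class,
                (t, m, a) -> "[intercepted] " + m.invoke(t, a),
                Collections.singleton(signature(Greeter.class, "greet", String.class)));
        System.out.println(proxied.greet("Tom"));    // [intercepted] hi Tom
        System.out.println(proxied.farewell("Tom")); // bye Tom
    }
}
```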

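When MyBatis creates each of the four core objects, it calls interceptorChain.pluginAll(), which iterates over all Interceptor plugins and calls interceptor.plugin(target) on each one, so that the current object gets wrapped in a proxy.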

After this, every call to a method of the four objects goes through the proxy first, which is how the plugins get to intercept those methods.

/**
 * @author Clinton Begin
 */
public class InterceptorChain {

  private final List<Interceptor> interceptors = new ArrayList<>();

  public Object pluginAll(Object target) {
    for (Interceptor interceptor : interceptors) {
      target = interceptor.plugin(target);
    }
    return target;
  }

  public void addInterceptor(Interceptor interceptor) {
    interceptors.add(interceptor);
  }

  public List<Interceptor> getInterceptors() {
    return Collections.unmodifiableList(interceptors);
  }

}
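Because pluginAll() feeds each returned proxy back in as the next `target`, the interceptor registered last becomes the outermost layer and therefore runs first. The sketch below shows this with plain function wrapping; `ChainOrder`, `Target` and `layer()` are hypothetical names, not MyBatis types.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

class ChainOrder {
    // Hypothetical stand-in for one of the four core objects: input in, trace string out.
    interface Target {
        String call(String input);
    }

    private final List<UnaryOperator<Target>> interceptors = new ArrayList<>();

    void addInterceptor(UnaryOperator<Target> plugin) {
        interceptors.add(plugin);
    }

    // Same loop as InterceptorChain.pluginAll(): each plugin wraps the result of the previous one.
    Target pluginAll(Target target) {
        for (UnaryOperator<Target> plugin : interceptors) {
            target = plugin.apply(target);
        }
        return target;
    }

    // A wrapper that records its name around whatever it wraps.
    static UnaryOperator<Target> layer(String name) {
        return inner -> input -> name + "(" + inner.call(input) + ")";
    }

    public static void main(String[] args) {
        ChainOrder chain = new ChainOrder();
        chain.addInterceptor(layer("A")); // registered first -> innermost
        chain.addInterceptor(layer("B")); // registered last -> outermost
        System.out.println(chain.pluginAll(input -> input).call("exec")); // B(A(exec))
    }
}
```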

InnerInterceptor

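All plugins provided by MybatisPlus implement the InnerInterceptor interface. Features currently available:

  • Automatic pagination: PaginationInnerInterceptor
  • Multi-tenancy: TenantLineInnerInterceptor
  • Dynamic table names: DynamicTableNameInnerInterceptor
  • Optimistic locking: OptimisticLockerInnerInterceptor
  • SQL performance rules: IllegalSQLInnerInterceptor
  • Blocking full-table update and delete: BlockAttackInnerInterceptor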
/**
 * @author miemie
 * @since 3.4.0
 */
@SuppressWarnings({"rawtypes"})
public interface InnerInterceptor {

    /**
     * Decides whether to execute {@link Executor#query(MappedStatement, Object, RowBounds, ResultHandler, CacheKey, BoundSql)}
     * <p>
     * If the query is not executed, {@link Collections#emptyList()} is returned
     *
     * @param executor      Executor (may be a proxy)
     * @param ms            MappedStatement
     * @param parameter     parameter
     * @param rowBounds     rowBounds
     * @param resultHandler resultHandler
     * @param boundSql      boundSql
     * @return whether to execute the query
     */
    default boolean willDoQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        return true;
    }

    /**
     * Pre-processing hook for {@link Executor#query(MappedStatement, Object, RowBounds, ResultHandler, CacheKey, BoundSql)}
     * <p>
     * e.g. tweak the SQL
     *
     * @param executor      Executor (may be a proxy)
     * @param ms            MappedStatement
     * @param parameter     parameter
     * @param rowBounds     rowBounds
     * @param resultHandler resultHandler
     * @param boundSql      boundSql
     */
    default void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        // do nothing
    }

    /**
     * Decides whether to execute {@link Executor#update(MappedStatement, Object)}
     * <p>
     * If the update is not executed, the affected row count is -1
     *
     * @param executor  Executor (may be a proxy)
     * @param ms        MappedStatement
     * @param parameter parameter
     */
    default boolean willDoUpdate(Executor executor, MappedStatement ms, Object parameter) throws SQLException {
        return true;
    }

    /**
     * Pre-processing hook for {@link Executor#update(MappedStatement, Object)}
     * <p>
     * e.g. tweak the SQL
     *
     * @param executor  Executor (may be a proxy)
     * @param ms        MappedStatement
     * @param parameter parameter
     */
    default void beforeUpdate(Executor executor, MappedStatement ms, Object parameter) throws SQLException {
        // do nothing
    }

    /**
     * Pre-processing hook for {@link StatementHandler#prepare(Connection, Integer)}
     * <p>
     * e.g. tweak the SQL
     *
     * @param sh                 StatementHandler (may be a proxy)
     * @param connection         Connection
     * @param transactionTimeout transactionTimeout
     */
    default void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
        // do nothing
    }

    default void setProperties(Properties properties) {
        // do nothing
    }
}

2. Automatic Pagination Plugin

Automatic pagination plugin: PaginationInnerInterceptor

Without the pagination plugin registered, we try to page and sort with Page:

/**
 * @Author Oneby
 * @Date 2021/4/18 17:54
 */
@RunWith(SpringRunner.class)
@SpringBootTest
public class MybatisPlusPluginTest {

    @Autowired
    private UserMapper userMapper;

    @Test
    public void testPaginationInnerInterceptor() {
        Page<User> page = new Page<>(1,3);
        page.addOrder(OrderItem.asc("age"));
        Page<User> userPage = userMapper.selectPage(page, null);
        for (User user : userPage.getRecords()) {
            System.out.println(user);
        }
    }

}

SQL log: neither pagination nor sorting takes effect; the SQL has no ORDER BY or LIMIT, and all 6 rows are returned

Creating a new SqlSession
SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSession@1a2bcd56] was not registered for synchronization because synchronization is not active
2021-04-24 18:56:21.535  INFO 19080 --- [           main] com.zaxxer.hikari.HikariDataSource       : HikariPool-1 - Starting...
2021-04-24 18:56:21.681  INFO 19080 --- [           main] com.zaxxer.hikari.HikariDataSource       : HikariPool-1 - Start completed.
JDBC Connection [HikariProxyConnection@662000775 wrapping com.mysql.cj.jdbc.ConnectionImpl@4f0cab0a] will not be managed by Spring
==>  Preparing: SELECT id,username AS name,age,email FROM t_user
==> Parameters: 
<==    Columns: id, name, age, email
<==        Row: 1, Jone, 18, test1@baomidou.com
<==        Row: 2, Jack, 20, test2@baomidou.com
<==        Row: 3, Tom, 28, test3@baomidou.com
<==        Row: 4, Sandy, 21, test4@baomidou.com
<==        Row: 5, Billie, 24, test5@baomidou.com
<==        Row: 6, Oneby, 21, Oneby@baomidou.com
<==      Total: 6
Closing non transactional SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSession@1a2bcd56]
User(id=1, name=Jone, age=18, email=test1@baomidou.com)
User(id=2, name=Jack, age=20, email=test2@baomidou.com)
User(id=3, name=Tom, age=28, email=test3@baomidou.com)
User(id=4, name=Sandy, age=21, email=test4@baomidou.com)
User(id=5, name=Billie, age=24, email=test5@baomidou.com)
User(id=6, name=Oneby, age=21, email=Oneby@baomidou.com)

We add the automatic pagination plugin to MybatisPlus and specify MySQL as the database type:

/**
 * @Author Oneby
 * @Date 2021/4/24 18:42
 */
@Configuration
@MapperScan("com.oneby.mapper")
public class MybatisPlusConfig {

    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));
        return interceptor;
    }

}

The log shows that the MybatisPlusInterceptor plugin has been registered successfully.

Why is SELECT COUNT(*) FROM t_user executed? Because userMapper.selectPage(page, null) returns a Page<User> userPage object, and its total field (the total row count) has to be filled in.

Registered plugin: 'com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor@70e889e9'

Creating a new SqlSession
SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSession@44536de4] was not registered for synchronization because synchronization is not active
2021-04-24 18:57:50.254  INFO 8344 --- [           main] com.zaxxer.hikari.HikariDataSource       : HikariPool-1 - Starting...
2021-04-24 18:57:50.379  INFO 8344 --- [           main] com.zaxxer.hikari.HikariDataSource       : HikariPool-1 - Start completed.
JDBC Connection [HikariProxyConnection@418731780 wrapping com.mysql.cj.jdbc.ConnectionImpl@67cefd84] will not be managed by Spring
==>  Preparing: SELECT COUNT(*) FROM t_user
==> Parameters: 
<==    Columns: COUNT(*)
<==        Row: 6
<==      Total: 1
==>  Preparing: SELECT id, username AS name, age, email FROM t_user ORDER BY age ASC LIMIT ?
==> Parameters: 3(Long)
<==    Columns: id, name, age, email
<==        Row: 1, Jone, 18, test1@baomidou.com
<==        Row: 2, Jack, 20, test2@baomidou.com
<==        Row: 6, Oneby, 21, Oneby@baomidou.com
<==      Total: 3
Closing non transactional SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSession@44536de4]
User(id=1, name=Jone, age=18, email=test1@baomidou.com)
User(id=2, name=Jack, age=20, email=test2@baomidou.com)
User(id=6, name=Oneby, age=21, email=Oneby@baomidou.com)
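The numbers in this log follow from simple page arithmetic: page 1 with size 3 starts at offset 0, and 6 rows make 2 pages. With offset 0, a MySQL-style dialect needs only a single `LIMIT ?` placeholder, which matches the `LIMIT ?` and `3(Long)` lines above. `PageMath` is a hypothetical helper; the rule "one placeholder when the offset is 0, `LIMIT ?,?` otherwise" is assumed here to match MybatisPlus's MySQL dialect, and the real dialect binds offset and size as parameters rather than inlining them.

```java
class PageMath {
    // Offset of the first row on the 1-based page `current`.
    static long offset(long current, long size) {
        return current > 0 ? (current - 1) * size : 0;
    }

    // Total number of pages, rounding up.
    static long pages(long total, long size) {
        if (size == 0) {
            return 0;
        }
        long pages = total / size;
        return total % size == 0 ? pages : pages + 1;
    }

    // Assumed MySQL-style pagination: one placeholder when offset == 0, two otherwise.
    static String mysqlLimit(String sql, long offset) {
        return offset == 0 ? sql + " LIMIT ?" : sql + " LIMIT ?,?";
    }

    public static void main(String[] args) {
        String base = "SELECT id, username AS name, age, email FROM t_user ORDER BY age ASC";
        System.out.println(pages(6, 3));                    // 2
        System.out.println(mysqlLimit(base, offset(1, 3))); // ... ORDER BY age ASC LIMIT ?
        System.out.println(mysqlLimit(base, offset(2, 3))); // ... ORDER BY age ASC LIMIT ?,?
    }
}
```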

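Properties of PaginationInnerInterceptor

| Property | Type | Default | Description |
| --- | --- | --- | --- |
| overflow | boolean | false | Whether to handle a current page that exceeds the total page count (not handled by default; see the plugin's continuePage method) |
| maxLimit | Long | | Maximum number of rows per page (unlimited by default; see the plugin's handlerLimit method) |
| dbType | DbType | | Database type, used to choose the pagination dialect (see the plugin's findIDialect method) |
| dialect | IDialect | | Dialect implementation (see the plugin's findIDialect method) |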

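How PaginationInnerInterceptor works

PaginationInnerInterceptor implements the InnerInterceptor interface; the methods to look at are its overrides of willDoQuery() and beforeQuery()

  1. willDoQuery() (returns true right away when the parameter carries no IPage, page.getSize() < 0, or page.isSearchCount() is false):
    1. MappedStatement countMs = buildCountMappedStatement(ms, page.countId()); uses the original MappedStatement ms and page.countId() to obtain a dedicated count MappedStatement countMs
    2. If countMs is not null, countSql = countMs.getBoundSql(parameter); builds the BoundSql countSql; if countMs is null, MP builds the count statement and SQL itself via buildAutoCountMappedStatement() and autoCountSql()
    3. Object result = executor.query(countMs, parameter, rowBounds, resultHandler, cacheKey, countSql).get(0); queries the total row count, and page.setTotal(result == null ? 0L : Long.parseLong(result.toString())); stores it in page.total
    4. continuePage(page); decides whether the paginated list query should still run
  2. beforeQuery():
    1. String buildSql = boundSql.getSql(); gets the original SQL
    2. If page.orders() is not empty, sorting is required, and buildSql = this.concatOrderBy(buildSql, orders); appends the ORDER BY clause
    3. handlerLimit(page); clamps the page size to the configured limit when it is exceeded
    4. IDialect dialect = findIDialect(executor); obtains the database dialect that knows the pagination syntax
    5. DialectModel model = dialect.buildPaginationSql(buildSql, page.offset(), page.getSize()); builds a DialectModel carrying the paginated SQL
    6. Finally, the paginated SQL and its parameter mappings are written into the mpBoundSql object. This looks odd at first: mpBoundSql is neither a shared static object nor returned, so what does writing into it achieve? The point is that PluginUtils.mpBoundSql(boundSql) is only a thin wrapper around the very same boundSql instance and writes its fields through MyBatis's MetaObject reflection. The changes therefore land on the boundSql that MybatisPlusInterceptor#intercept() passes to executor.query(...) once the interceptor loop finishes.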
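Why writing into a throw-away wrapper still changes the SQL can be shown with plain reflection. In the sketch below, `FakeBoundSql` and `SqlWriter` are hypothetical stand-ins: `FakeBoundSql` exposes its SQL without a public setter, and `SqlWriter` holds a reference to that same object and overwrites the private field reflectively. This is the role `PluginUtils.MPBoundSql` plays, though it goes through MyBatis's `MetaObject` rather than raw `Field` access.

```java
import java.lang.reflect.Field;

// Hypothetical stand-in for BoundSql: the SQL can be read, but there is no public setter.
class FakeBoundSql {
    private String sql;

    FakeBoundSql(String sql) {
        this.sql = sql;
    }

    String getSql() {
        return sql;
    }
}

// Hypothetical stand-in for PluginUtils.MPBoundSql: a wrapper around the SAME object.
class SqlWriter {
    private final Object delegate;

    SqlWriter(Object delegate) {
        this.delegate = delegate;
    }

    // Overwrites the private "sql" field of the wrapped object in place.
    void sql(String newSql) {
        try {
            Field field = delegate.getClass().getDeclaredField("sql");
            field.setAccessible(true);
            field.set(delegate, newSql);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        FakeBoundSql boundSql = new FakeBoundSql("SELECT * FROM t_user");
        new SqlWriter(boundSql).sql(boundSql.getSql() + " LIMIT ?"); // wrapper is discarded right away
        System.out.println(boundSql.getSql()); // SELECT * FROM t_user LIMIT ?
    }
}
```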
/**
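 * Pagination interceptor
 * <p>
 * Optimizes LEFT JOIN by default. This can speed up the count, but once pagination is added, a one-to-many join already yields an inaccurate row count anyway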
 *
 * @author hubin
 * @since 3.4.0
 */
@Data
@NoArgsConstructor
@SuppressWarnings({"rawtypes"})
public class PaginationInnerInterceptor implements InnerInterceptor {

    protected static final List<SelectItem> COUNT_SELECT_ITEM = Collections.singletonList(defaultCountSelectItem());
    protected static final Map<String, MappedStatement> countMsCache = new ConcurrentHashMap<>();
    protected final Log logger = LogFactory.getLog(this.getClass());

    /**
     * Builds the COUNT SelectItem for JSqlParser
     */
    private static SelectItem defaultCountSelectItem() {
        Function function = new Function();
        function.setName("COUNT");
        function.setAllColumns(true);
        return new SelectExpressionItem(function);
    }

    /**
     * Whether to handle a current page beyond the total page count
     */
    protected boolean overflow;
    /**
     * Maximum number of rows per page
     */
    protected Long maxLimit;
    /**
     * Database type
     * <p>
     * See the logic in {@link #findIDialect(Executor)}
     */
    private DbType dbType;
    /**
     * Dialect implementation
     * <p>
     * See the logic in {@link #findIDialect(Executor)}
     */
    private IDialect dialect;

    public PaginationInnerInterceptor(DbType dbType) {
        this.dbType = dbType;
    }

    public PaginationInnerInterceptor(IDialect dialect) {
        this.dialect = dialect;
    }

    /**
     * Runs the count here; if the count is 0, returns false (the list SQL is not executed)
     */
    @Override
    public boolean willDoQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        IPage<?> page = ParameterUtils.findPage(parameter).orElse(null);
        if (page == null || page.getSize() < 0 || !page.isSearchCount()) {
            return true;
        }

        BoundSql countSql;
        MappedStatement countMs = buildCountMappedStatement(ms, page.countId());
        if (countMs != null) {
            countSql = countMs.getBoundSql(parameter);
        } else {
            countMs = buildAutoCountMappedStatement(ms);
            String countSqlStr = autoCountSql(page.optimizeCountSql(), boundSql.getSql());
            PluginUtils.MPBoundSql mpBoundSql = PluginUtils.mpBoundSql(boundSql);
            countSql = new BoundSql(countMs.getConfiguration(), countSqlStr, mpBoundSql.parameterMappings(), parameter);
            PluginUtils.setAdditionalParameter(countSql, mpBoundSql.additionalParameters());
        }

        CacheKey cacheKey = executor.createCacheKey(countMs, parameter, rowBounds, countSql);
        Object result = executor.query(countMs, parameter, rowBounds, resultHandler, cacheKey, countSql).get(0);
        page.setTotal(result == null ? 0L : Long.parseLong(result.toString()));
        return continuePage(page);
    }

    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        IPage<?> page = ParameterUtils.findPage(parameter).orElse(null);
        if (null == page) {
            return;
        }

        // Append the ORDER BY clause
        boolean addOrdered = false;
        String buildSql = boundSql.getSql();
        List<OrderItem> orders = page.orders();
        if (!CollectionUtils.isEmpty(orders)) {
            addOrdered = true;
            buildSql = this.concatOrderBy(buildSql, orders);
        }

        // size < 0: do not build paginated SQL
        if (page.getSize() < 0) {
            if (addOrdered) {
                PluginUtils.mpBoundSql(boundSql).sql(buildSql);
            }
            return;
        }

        handlerLimit(page);
        IDialect dialect = findIDialect(executor);

        final Configuration configuration = ms.getConfiguration();
        DialectModel model = dialect.buildPaginationSql(buildSql, page.offset(), page.getSize());
        PluginUtils.MPBoundSql mpBoundSql = PluginUtils.mpBoundSql(boundSql);

        List<ParameterMapping> mappings = mpBoundSql.parameterMappings();
        Map<String, Object> additionalParameter = mpBoundSql.additionalParameters();
        model.consumers(mappings, configuration, additionalParameter);
        mpBoundSql.sql(model.getDialectSql());
        mpBoundSql.parameterMappings(mappings);
    }
    
    /**
     * Whether to continue with the paginated query after the count query
     *
     * @param page the page object
     * @return whether to continue
     */
    protected boolean continuePage(IPage<?> page) {
        if (page.getTotal() <= 0) {
            return false;
        }
        if (page.getCurrent() > page.getPages()) {
            if (overflow) {
                // Handle a current page beyond the total page count
                handlerOverflow(page);
            } else {
                // Beyond the last page with no overflow handling configured: skip the list query
                return false;
            }
        }
        return true;
    }
    
    /**
     * Handles page sizes above the limit by clamping them to the limit
     *
     * @param page IPage
     */
    protected void handlerLimit(IPage<?> page) {
        final long size = page.getSize();
        Long pageMaxLimit = page.maxLimit();
        Long limit = pageMaxLimit != null ? pageMaxLimit : maxLimit;
        if (limit != null && limit > 0 && size > limit) {
            page.setSize(limit);
        }
    }
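Both rules above are plain arithmetic, so they can be checked outside MybatisPlus. `PaginationRules` below is a hypothetical restatement of handlerLimit() and continuePage() on primitive values. It assumes size > 0 and leaves out handlerOverflow(), whose body is not shown above.

```java
class PaginationRules {
    // Mirrors handlerLimit(): the page's own maxLimit wins over the interceptor-level maxLimit,
    // and a size above the chosen limit is clamped down to it.
    static long limitSize(long size, Long pageMaxLimit, Long maxLimit) {
        Long limit = pageMaxLimit != null ? pageMaxLimit : maxLimit;
        if (limit != null && limit > 0 && size > limit) {
            return limit;
        }
        return size;
    }

    // Mirrors continuePage() for size > 0: true means the paginated list query still runs.
    static boolean continuePage(long total, long current, long size, boolean overflow) {
        if (total <= 0) {
            return false; // count is 0: nothing to list
        }
        long pages = (total + size - 1) / size; // ceiling division
        if (current > pages) {
            // Beyond the last page: the real code calls handlerOverflow(page) and continues
            // when overflow is true, and skips the list query otherwise.
            return overflow;
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(limitSize(1000, null, 500L));  // 500
        System.out.println(continuePage(6, 3, 3, false)); // false: only 2 pages exist
        System.out.println(continuePage(6, 3, 3, true));  // true
    }
}
```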

3. Blocking Full-Table Update and Delete

Blocking full-table update and delete: BlockAttackInnerInterceptor

BlockAttackInnerInterceptor uses the StringUtils class from commons-lang3, so we add the commons-lang3 dependency

<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-lang3</artifactId>
</dependency>

MybatisPlusConfig 配置类中注册 BlockAttackInnerInterceptor 防止全表更新与删除插件

/**
 * @Author Oneby
 * @Date 2021/4/24 18:42
 */
@Configuration
@MapperScan("com.oneby.mapper")
public class MybatisPlusConfig {

    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));
        interceptor.addInnerInterceptor(new BlockAttackInnerInterceptor());
        return interceptor;
    }

}

Test code: delete every row in the table

@Test
public void testBlockAttackInnerInterceptor() {
    int count = userMapper.delete(null);
    System.out.println("Deleted " + count + " rows");
}

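The program throws an exception: Error updating database. Cause: com.baomidou.mybatisplus.core.exceptions.MybatisPlusException: Prohibition of full table deletion (full-table deletion is not allowed). The log contains no ==> Preparing: line, so the DELETE statement is rejected before it is ever sent to the database.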

Creating a new SqlSession
SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSession@2aff9dff] was not registered for synchronization because synchronization is not active
2021-04-24 22:39:54.361  INFO 20340 --- [           main] com.zaxxer.hikari.HikariDataSource       : HikariPool-1 - Starting...
2021-04-24 22:39:54.518  INFO 20340 --- [           main] com.zaxxer.hikari.HikariDataSource       : HikariPool-1 - Start completed.
JDBC Connection [HikariProxyConnection@111702054 wrapping com.mysql.cj.jdbc.ConnectionImpl@ef60710] will not be managed by Spring
Closing non transactional SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSession@2aff9dff]

org.mybatis.spring.MyBatisSystemException: nested exception is org.apache.ibatis.exceptions.PersistenceException: 
### Error updating database.  Cause: com.baomidou.mybatisplus.core.exceptions.MybatisPlusException: Prohibition of full table deletion
### The error may exist in com/oneby/mapper/UserMapper.java (best guess)
### The error may involve com.oneby.mapper.UserMapper.delete
### The error occurred while executing an update
### Cause: com.baomidou.mybatisplus.core.exceptions.MybatisPlusException: Prohibition of full table deletion

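Usage notes for BlockAttackInnerInterceptor

  1. It is often described as a SQL execution-analysis interceptor that only supports MySQL 5.6.3 and later. That requirement appears to carry over from the older SqlExplainInterceptor, which ran EXPLAIN; the BlockAttackInnerInterceptor source below only parses the SQL text with JSqlParser
  2. The plugin analyzes DELETE and UPDATE statements to stop inexperienced or malicious users from running DELETE or UPDATE against an entire table
  3. Recommended for development environments only, not for production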

How BlockAttackInnerInterceptor works

When the full-table delete runs, the exception is thrown from Plugin#invoke(), specifically while executing interceptor.intercept(new Invocation(target, method, args));: Cause: com.baomidou.mybatisplus.core.exceptions.MybatisPlusException: Prohibition of full table deletion

/**
 * @author Clinton Begin
 */
public class Plugin implements InvocationHandler {

  private final Object target;
  private final Interceptor interceptor;
  private final Map<Class<?>, Set<Method>> signatureMap;

  private Plugin(Object target, Interceptor interceptor, Map<Class<?>, Set<Method>> signatureMap) {
    this.target = target;
    this.interceptor = interceptor;
    this.signatureMap = signatureMap;
  }

  public static Object wrap(Object target, Interceptor interceptor) {
    Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
    Class<?> type = target.getClass();
    Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
    if (interfaces.length > 0) {
      return Proxy.newProxyInstance(
          type.getClassLoader(),
          interfaces,
          new Plugin(target, interceptor, signatureMap));
    }
    return target;
  }

  @Override
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    try {
      Set<Method> methods = signatureMap.get(method.getDeclaringClass());
      if (methods != null && methods.contains(method)) {
        return interceptor.intercept(new Invocation(target, method, args));
      }
      return method.invoke(target, args);
    } catch (Exception e) {
      throw ExceptionUtil.unwrapThrowable(e);
    }
  }


To see why, go back to MybatisPlusInterceptor#intercept(): for a StatementHandler target, it calls each inner interceptor's innerInterceptor.beforePrepare(sh, connections, transactionTimeout);

public class MybatisPlusInterceptor implements Interceptor {

    @Setter
    private List<InnerInterceptor> interceptors = new ArrayList<>();

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object target = invocation.getTarget();
        Object[] args = invocation.getArgs();
        if (target instanceof Executor) {
            // ...
        } else {
            // StatementHandler
            final StatementHandler sh = (StatementHandler) target;
            Connection connections = (Connection) args[0];
            Integer transactionTimeout = (Integer) args[1];
            for (InnerInterceptor innerInterceptor : interceptors) {
                innerInterceptor.beforePrepare(sh, connections, transactionTimeout);
            }
        }
        return invocation.proceed();
    }

BlockAttackInnerInterceptor overrides the beforePrepare() method of the InnerInterceptor interface. The exception is thrown while executing parserMulti(boundSql.getSql(), null);

/**
 * Attack-SQL blocking parser that prevents full-table update and delete
 *
 * @author hubin
 * @since 3.4.0
 */
public class BlockAttackInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {

    @Override
    public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
        PluginUtils.MPStatementHandler handler = PluginUtils.mpStatementHandler(sh);
        MappedStatement ms = handler.mappedStatement();
        SqlCommandType sct = ms.getSqlCommandType();
        if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
            if (InterceptorIgnoreHelper.willIgnoreBlockAttack(ms.getId())) return;
            BoundSql boundSql = handler.boundSql();
            parserMulti(boundSql.getSql(), null);
        }
    }

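Stepping into JsqlParserSupport, the abstract parent class of BlockAttackInnerInterceptor: parserMulti() parses the SQL and calls processParser(statement, i, sql, obj); for each statement, which dispatches to this.processDelete((Delete) statement, index, sql, obj); and the exception is thrown there, while the parsed DELETE statement is being checked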

/**
 * https://github.com/JSQLParser/JSqlParser
 *
 * @author miemie
 * @since 2020-06-22
 */
public abstract class JsqlParserSupport {

    /**
     * Logger
     */
    protected final Log logger = LogFactory.getLog(this.getClass());

    public String parserSingle(String sql, Object obj) {
        if (logger.isDebugEnabled()) {
            logger.debug("original SQL: " + sql);
        }
        try {
            Statement statement = CCJSqlParserUtil.parse(sql);
            return processParser(statement, 0, sql, obj);
        } catch (JSQLParserException e) {
            throw ExceptionUtils.mpe("Failed to process, Error SQL: %s", e, sql);
        }
    }

    public String parserMulti(String sql, Object obj) {
        if (logger.isDebugEnabled()) {
            logger.debug("original SQL: " + sql);
        }
        try {
            // fixed github pull/295
            StringBuilder sb = new StringBuilder();
            Statements statements = CCJSqlParserUtil.parseStatements(sql);
            int i = 0;
            for (Statement statement : statements.getStatements()) {
                if (i > 0) {
                    sb.append(StringPool.SEMICOLON);
                }
                sb.append(processParser(statement, i, sql, obj));
                i++;
            }
            return sb.toString();
        } catch (JSQLParserException e) {
            throw ExceptionUtils.mpe("Failed to process, Error SQL: %s", e, sql);
        }
    }
    
    /**
     * Processes the parsed SQL statement
     *
     * @param statement JsqlParser Statement
     * @return sql
     */
    protected String processParser(Statement statement, int index, String sql, Object obj) {
        if (logger.isDebugEnabled()) {
            logger.debug("SQL to parse, SQL: " + sql);
        }
        if (statement instanceof Insert) {
            this.processInsert((Insert) statement, index, sql, obj);
        } else if (statement instanceof Select) {
            this.processSelect((Select) statement, index, sql, obj);
        } else if (statement instanceof Update) {
            this.processUpdate((Update) statement, index, sql, obj);
        } else if (statement instanceof Delete) {
            this.processDelete((Delete) statement, index, sql, obj);
        }
        sql = statement.toString();
        if (logger.isDebugEnabled()) {
            logger.debug("parse the finished SQL: " + sql);
        }
        return sql;
    }
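JsqlParserSupport is a template method: processParser() dispatches on the statement type to no-op hooks, and a subclass such as BlockAttackInnerInterceptor overrides only the hooks it needs. The dependency-free sketch below keeps that shape; `MiniStatement`, `MiniSelect`, `MiniDelete`, `MiniParserSupport` and `MiniBlockAttack` are hypothetical stand-ins for the JSqlParser and MybatisPlus types, and the WHERE check is simplified to "is there a WHERE clause at all".

```java
// Hypothetical statement types standing in for JSqlParser's Select / Delete classes.
abstract class MiniStatement { }

class MiniSelect extends MiniStatement { }

class MiniDelete extends MiniStatement {
    final String where; // null means the DELETE has no WHERE clause

    MiniDelete(String where) {
        this.where = where;
    }
}

// Same shape as JsqlParserSupport: one dispatch method plus overridable no-op hooks.
abstract class MiniParserSupport {
    final void processParser(MiniStatement statement) {
        if (statement instanceof MiniSelect) {
            processSelect((MiniSelect) statement);
        } else if (statement instanceof MiniDelete) {
            processDelete((MiniDelete) statement);
        }
    }

    protected void processSelect(MiniSelect select) {
        // do nothing by default
    }

    protected void processDelete(MiniDelete delete) {
        // do nothing by default
    }
}

// Overrides only the hook it cares about, as BlockAttackInnerInterceptor does with processDelete().
class MiniBlockAttack extends MiniParserSupport {
    @Override
    protected void processDelete(MiniDelete delete) {
        if (delete.where == null) {
            throw new IllegalStateException("Prohibition of full table deletion");
        }
    }

    public static void main(String[] args) {
        MiniBlockAttack parser = new MiniBlockAttack();
        parser.processParser(new MiniSelect());         // select hook is a no-op
        parser.processParser(new MiniDelete("id = 1")); // has a WHERE clause: allowed
        try {
            parser.processParser(new MiniDelete(null)); // no WHERE clause: rejected
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());         // Prohibition of full table deletion
        }
    }
}
```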

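Here this is the BlockAttackInnerInterceptor object, so for a DELETE statement the call chain is processDelete() --> checkWhere() --> fullMatch()

fullMatch() inspects the conditions in the DELETE statement's WHERE clause to decide whether the delete would affect the whole table. When it returns true, checkWhere() calls Assert.isFalse(...), which throws a MybatisPlusException with the given message (here "Prohibition of full table deletion"), so full-table deletes are blocked.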

/**
 * Attack-SQL blocking parser that prevents full-table update and delete
 *
 * @author hubin
 * @since 3.4.0
 */
public class BlockAttackInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {

    @Override
    public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
        PluginUtils.MPStatementHandler handler = PluginUtils.mpStatementHandler(sh);
        MappedStatement ms = handler.mappedStatement();
        SqlCommandType sct = ms.getSqlCommandType();
        if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
            if (InterceptorIgnoreHelper.willIgnoreBlockAttack(ms.getId())) return;
            BoundSql boundSql = handler.boundSql();
            parserMulti(boundSql.getSql(), null);
        }
    }

    @Override
    protected void processDelete(Delete delete, int index, String sql, Object obj) {
        this.checkWhere(delete.getTable().getName(), delete.getWhere(), "Prohibition of full table deletion");
    }

    @Override
    protected void processUpdate(Update update, int index, String sql, Object obj) {
        this.checkWhere(update.getTable().getName(), update.getWhere(), "Prohibition of table update operation");
    }

    protected void checkWhere(String tableName, Expression where, String ex) {
        Assert.isFalse(this.fullMatch(where, this.getTableLogicField(tableName)), ex);
    }

    private boolean fullMatch(Expression where, String logicField) {
        if (where == null) {
            return true;
        }
        if (StringUtils.isNotBlank(logicField) && (where instanceof BinaryExpression)) {

            BinaryExpression binaryExpression = (BinaryExpression) where;
            if (StringUtils.equals(binaryExpression.getLeftExpression().toString(), logicField) || StringUtils.equals(binaryExpression.getRightExpression().toString(), logicField)) {
                return true;
            }
        }

        if (where instanceof EqualsTo) {
            // example: 1=1
            EqualsTo equalsTo = (EqualsTo) where;
            return StringUtils.equals(equalsTo.getLeftExpression().toString(), equalsTo.getRightExpression().toString());
        } else if (where instanceof NotEqualsTo) {
            // example: 1 != 2
            NotEqualsTo notEqualsTo = (NotEqualsTo) where;
            return !StringUtils.equals(notEqualsTo.getLeftExpression().toString(), notEqualsTo.getRightExpression().toString());
        } else if (where instanceof OrExpression) {

            OrExpression orExpression = (OrExpression) where;
            return fullMatch(orExpression.getLeftExpression(), logicField) || fullMatch(orExpression.getRightExpression(), logicField);
        } else if (where instanceof AndExpression) {

            AndExpression andExpression = (AndExpression) where;
            return fullMatch(andExpression.getLeftExpression(), logicField) && fullMatch(andExpression.getRightExpression(), logicField);
        } else if (where instanceof Parenthesis) {
            // example: (1 = 1)
            Parenthesis parenthesis = (Parenthesis) where;
            return fullMatch(parenthesis.getExpression(), logicField);
        }

        return false;
    }
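The recursion in fullMatch() is easier to follow on a toy input. Below is a dependency-free sketch that mirrors its branches on a tiny hypothetical expression tree (Atom, Eq, NotEq, And, Or, Paren), used here instead of JSqlParser's Expression classes. Like the listing, it compares operands by their string form, so it is a syntactic check, not a logical one. checkWhere() here throws IllegalStateException in place of MybatisPlus's Assert.isFalse().

```java
import java.util.Objects;

/** Hypothetical, dependency-free re-creation of the fullMatch() branches. */
final class FullMatchSketch {

    /** Stand-in for a parsed WHERE expression (not JSqlParser's type). */
    interface Expr {}

    /** Stand-in for BinaryExpression: anything with a left and right operand. */
    interface Binary extends Expr {
        Expr left();
        Expr right();
    }

    /** A column name or literal; toString() plays the role of Expression.toString(). */
    record Atom(String text) implements Expr {
        @Override
        public String toString() {
            return text;
        }
    }

    record Eq(Expr left, Expr right) implements Binary {}
    record NotEq(Expr left, Expr right) implements Binary {}
    record And(Expr left, Expr right) implements Binary {}
    record Or(Expr left, Expr right) implements Binary {}
    record Paren(Expr inner) implements Expr {}

    /** Returns true when the condition would match every row of the table. */
    static boolean fullMatch(Expr where, String logicField) {
        if (where == null) {
            return true; // no WHERE clause at all
        }
        // A comparison on the logical-delete column alone does not narrow the rows
        if (logicField != null && !logicField.isBlank() && where instanceof Binary b
                && (logicField.equals(b.left().toString()) || logicField.equals(b.right().toString()))) {
            return true;
        }
        if (where instanceof Eq eq) {
            // e.g. 1=1: both sides print the same
            return Objects.equals(eq.left().toString(), eq.right().toString());
        } else if (where instanceof NotEq ne) {
            // e.g. 1 != 2: both sides print differently
            return !Objects.equals(ne.left().toString(), ne.right().toString());
        } else if (where instanceof Or orExpr) {
            return fullMatch(orExpr.left(), logicField) || fullMatch(orExpr.right(), logicField);
        } else if (where instanceof And andExpr) {
            return fullMatch(andExpr.left(), logicField) && fullMatch(andExpr.right(), logicField);
        } else if (where instanceof Paren p) {
            return fullMatch(p.inner(), logicField); // e.g. (1 = 1)
        }
        return false;
    }

    /** Mirrors checkWhere(): reject statements whose WHERE matches the whole table. */
    static void checkWhere(Expr where, String logicField, String message) {
        if (fullMatch(where, logicField)) {
            throw new IllegalStateException(message);
        }
    }
}
```

Copying the NotEqualsTo branch literally has one consequence. Any != whose two sides print differently counts as a full match and is blocked. That includes a real filter such as id != 5, not just tautologies like 1 != 2.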

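4. SQL Performance Rules Plugin

SQL performance rules: IllegalSQLInnerInterceptor

Register the IllegalSQLInnerInterceptor SQL performance rules plugin in the MybatisPlusConfig configuration class

Note: again, this plugin is recommended for development environments only. There it helps uncover SQL statements that break the rules.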

/**
 * @Author Oneby
 * @Date 2021/4/24 18:42
 */
@Configuration
@MapperScan("com.oneby.mapper")
public class MybatisPlusConfig {

    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));
        interceptor.addInnerInterceptor(new BlockAttackInnerInterceptor());
        interceptor.addInnerInterceptor(new IllegalSQLInnerInterceptor());
        return interceptor;
    }

}

Test code: a paginated query without any condition

@Test
public void IllegalSQLInnerInterceptor() {
    Page<User> page = new Page<>(1, 3);
    page.addOrder(OrderItem.asc("age"));
    Page<User> userPage = userMapper.selectPage(page, null);
    for (User user : userPage.getRecords()) {
        System.out.println(user);
    }
}

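The query throws Cause: com.baomidou.mybatisplus.core.exceptions.MybatisPlusException: 非法SQL,必须要有where条件 ("illegal SQL, a WHERE condition is required"). Strict indeed: it won't even let a SELECT without a WHERE run. The log below shows that the failing statement is selectPage_mpCount. That is the COUNT query issued for pagination before the page itself is fetched.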

Creating a new SqlSession
SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSession@77b919a3] was not registered for synchronization because synchronization is not active
2021-04-24 23:24:05.996  INFO 19332 --- [           main] com.zaxxer.hikari.HikariDataSource       : HikariPool-1 - Starting...
2021-04-24 23:24:06.105  INFO 19332 --- [           main] com.zaxxer.hikari.HikariDataSource       : HikariPool-1 - Start completed.
JDBC Connection [HikariProxyConnection@1487059223 wrapping com.mysql.cj.jdbc.ConnectionImpl@48904d5a] will not be managed by Spring
Closing non transactional SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSession@77b919a3]

org.mybatis.spring.MyBatisSystemException: nested exception is org.apache.ibatis.exceptions.PersistenceException: 
### Error querying database.  Cause: com.baomidou.mybatisplus.core.exceptions.MybatisPlusException: 非法SQL,必须要有where条件
### The error may exist in com/oneby/mapper/UserMapper.java (best guess)
### The error may involve com.oneby.mapper.UserMapper.selectPage_mpCount
### The error occurred while executing a query
### Cause: com.baomidou.mybatisplus.core.exceptions.MybatisPlusException: 非法SQL,必须要有where条件

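How IllegalSQLInnerInterceptor works

The IllegalSQLInnerInterceptor plugin intercepts StatementHandler methods: the class overrides the beforePrepare() method of the InnerInterceptor interface

beforePrepare() skips INSERT statements and statements marked to skip this check. For everything else, it calls parserSingle(originalSql, connection); to check whether the SQL statement follows the rules. Each SQL string that passes is cached in cacheValidResult as an MD5 digest (EncryptUtils.md5Base64), so it is validated only once.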

/**
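 * Developers' skill levels vary, and many people ignore development rules even when they exist
 * <p>SQL is the most important factor in system performance, so bad SQL statements are intercepted</p>
 * <br>
 * <p>Scenarios in which SQL is intercepted</p>
 * <p>1. An index must be used, including on left join columns, following the leftmost-prefix rule</p>
 * <p>Benefits of requiring an index:</p>
 * <p>1.1 If a dynamic-SQL bug drops the WHERE condition of an update, tens of thousands of rows in the whole table would otherwise be updated</p>
 * <p>1.2 If an index is confirmed to be used, SQL performance is generally not too bad</p>
 * <br>
 * <p>2. SQL should run against a single table whenever possible. A query containing left join must be explicitly allowed in a comment, otherwise it is intercepted; if a left join statement cannot be split into single-table SQL, discuss it with your lead first</p>
 * <p>https://gaoxianglong.github.io/shark</p>
 * <p>Benefits of single-table SQL</p>
 * <p>2.1 Query conditions are simple and easy to understand and maintain;</p>
 * <p>2.2 Highly extensible (prepares for database and table sharding);</p>
 * <p>2.3 High cache utilization;</p>
 * <p>2. Functions applied to columns</p>
 * <br>
 * <p>3. Empty WHERE condition</p>
 * <p>4. WHERE condition uses !=</p>
 * <p>5. WHERE condition uses the not keyword</p>
 * <p>6. WHERE condition uses the or keyword</p>
 * <p>7. WHERE condition uses a subquery</p>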
 *
 * @author willenfoo
 * @since 3.4.0
 */
public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {

    /**
     * Cache of validation results, to improve performance
     */
    private static final Set<String> cacheValidResult = new HashSet<>();
    /**
     * Cache of table index information
     */
    private static final Map<String, List<IndexInfo>> indexInfoMap = new ConcurrentHashMap<>();

    @Override
    public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
        PluginUtils.MPStatementHandler mpStatementHandler = PluginUtils.mpStatementHandler(sh);
        MappedStatement ms = mpStatementHandler.mappedStatement();
        SqlCommandType sct = ms.getSqlCommandType();
        if (sct == SqlCommandType.INSERT || InterceptorIgnoreHelper.willIgnoreIllegalSql(ms.getId())
            || SqlParserHelper.getSqlParserInfo(ms)) return;
        BoundSql boundSql = mpStatementHandler.boundSql();
        String originalSql = boundSql.getSql();
        logger.debug("检查SQL是否合规,SQL:" + originalSql);
        String md5Base64 = EncryptUtils.md5Base64(originalSql);
        if (cacheValidResult.contains(md5Base64)) {
            logger.debug("该SQL已验证,无需再次验证,,SQL:" + originalSql);
            return;
        }
        parserSingle(originalSql, connection);
        // cache the validation result
        cacheValidResult.add(md5Base64);
    }
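The "validate once" trick in beforePrepare() is a small, reusable idea: key each SQL string by the Base64 of its MD5 digest, and remember only strings that passed. Below is a minimal sketch under stated assumptions. ValidatedSqlCache is a hypothetical class, and its md5Base64() is a JDK stand-in for EncryptUtils.md5Base64, not a claim about how that utility is implemented. The validator is any Consumer<String> that throws to reject a statement.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

/** Hypothetical "validate once" cache modeled on beforePrepare(). */
final class ValidatedSqlCache {

    // Digests of SQL strings that already passed validation
    private final Set<String> validated = ConcurrentHashMap.newKeySet();
    // Throws (unchecked) to reject a statement
    private final Consumer<String> validator;

    ValidatedSqlCache(Consumer<String> validator) {
        this.validator = validator;
    }

    /** MD5 digest of the SQL text, Base64-encoded, using only JDK classes. */
    static String md5Base64(String sql) {
        try {
            byte[] digest = MessageDigest.getInstance("MD5").digest(sql.getBytes(StandardCharsets.UTF_8));
            return Base64.getEncoder().encodeToString(digest);
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("MD5 is not available", e);
        }
    }

    /** Validates sql unless an identical string has already passed. */
    void check(String sql) {
        String key = md5Base64(sql);
        if (validated.contains(key)) {
            return; // already validated: skip the expensive parse
        }
        validator.accept(sql); // a rejected statement throws here and is never cached
        validated.add(key);
    }
}
```

One design difference: the listing keeps cacheValidResult in a plain static HashSet, while this sketch uses ConcurrentHashMap.newKeySet(), since statements can be prepared on many threads at once. As in the listing, a statement that fails validation throws before its digest is added, so it is checked again the next time it runs.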

Stepping into JsqlParserSupport, the abstract parent class of IllegalSQLInnerInterceptor: the exception is thrown while the statement is processed along the call chain parserSingle(originalSql, connection); --> this.processSelect((Select) statement, index, sql, obj);

/**
 * https://github.com/JSQLParser/JSqlParser
 *
 * @author miemie
 * @since 2020-06-22
 */
public abstract class JsqlParserSupport {

    /**
     * Logger
     */
    protected final Log logger = LogFactory.getLog(this.getClass());

    public String parserSingle(String sql, Object obj) {
        if (logger.isDebugEnabled()) {
            logger.debug("original SQL: " + sql);
        }
        try {
            Statement statement = CCJSqlParserUtil.parse(sql);
            return processParser(statement, 0, sql, obj);
        } catch (JSQLParserException e) {
            throw ExceptionUtils.mpe("Failed to process, Error SQL: %s", e, sql);
        }
    }

    /**
     * Run SQL parsing
     *
     * @param statement JsqlParser Statement
     * @return sql
     */
    protected String processParser(Statement statement, int index, String sql, Object obj) {
        if (logger.isDebugEnabled()) {
            logger.debug("SQL to parse, SQL: " + sql);
        }
        if (statement instanceof Insert) {
            this.processInsert((Insert) statement, index, sql, obj);
        } else if (statement instanceof Select) {
            this.processSelect((Select) statement, index, sql, obj);
        } else if (statement instanceof Update) {
            this.processUpdate((Update) statement, index, sql, obj);
        } else if (statement instanceof Delete) {
            this.processDelete((Delete) statement, index, sql, obj);
        }
        sql = statement.toString();
        if (logger.isDebugEnabled()) {
            logger.debug("parse the finished SQL: " + sql);
        }
        return sql;
    }

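Here this is the IllegalSQLInnerInterceptor object. It overrides the processSelect(), processUpdate() and processDelete() methods of its abstract parent class JsqlParserSupport.

Each of these methods first asserts that a WHERE clause exists. That assertion, Assert.notNull(where, "非法SQL,必须要有where条件"), is the source of the exception above. The methods then check the statement against the performance rules with validWhere(where, table, (Connection) obj); and validJoins(joins, table, (Connection) obj);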

public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {

    // cacheValidResult, indexInfoMap and beforePrepare() are identical to the listing above and omitted here

    @Override
    protected void processSelect(Select select, int index, String sql, Object obj) {
        PlainSelect plainSelect = (PlainSelect) select.getSelectBody();
        Expression where = plainSelect.getWhere();
        Assert.notNull(where, "非法SQL,必须要有where条件");
        Table table = (Table) plainSelect.getFromItem();
        List<Join> joins = plainSelect.getJoins();
        validWhere(where, table, (Connection) obj);
        validJoins(joins, table, (Connection) obj);
    }

    @Override
    protected void processUpdate(Update update, int index, String sql, Object obj) {
        Expression where = update.getWhere();
        Assert.notNull(where, "非法SQL,必须要有where条件");
        Table table = update.getTable();
        List<Join> joins = update.getJoins();
        validWhere(where, table, (Connection) obj);
        validJoins(joins, table, (Connection) obj);
    }

    @Override
    protected void processDelete(Delete delete, int index, String sql, Object obj) {
        Expression where = delete.getWhere();
        Assert.notNull(where, "非法SQL,必须要有where条件");
        Table table = delete.getTable();
        List<Join> joins = delete.getJoins();
        validWhere(where, table, (Connection) obj);
        validJoins(joins, table, (Connection) obj);
    }
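The split above is the template method pattern. JsqlParserSupport dispatches on the statement type, and the subclass overrides only the per-type hooks it cares about. Below is a self-contained sketch of that shape, assuming hypothetical InsertStmt/SelectStmt/UpdateStmt/DeleteStmt records in place of JSqlParser's classes. It implements only the "WHERE is required" rule; the index checks in validWhere()/validJoins() need database metadata and are left out.

```java
/** Hypothetical statement types standing in for JSqlParser's Insert/Select/Update/Delete. */
interface Stmt {}
record InsertStmt(String table) implements Stmt {}
record SelectStmt(String table, String where) implements Stmt {}
record UpdateStmt(String table, String where) implements Stmt {}
record DeleteStmt(String table, String where) implements Stmt {}

/** Template-method base class in the shape of JsqlParserSupport.processParser(). */
abstract class ParserSupportSketch {

    /** Fixed dispatch step: route each statement type to its hook. */
    final void process(Stmt statement) {
        if (statement instanceof InsertStmt ins) {
            processInsert(ins);
        } else if (statement instanceof SelectStmt sel) {
            processSelect(sel);
        } else if (statement instanceof UpdateStmt upd) {
            processUpdate(upd);
        } else if (statement instanceof DeleteStmt del) {
            processDelete(del);
        }
    }

    // Hooks for subclasses; in this sketch, unhandled types fail loudly
    protected void processInsert(InsertStmt insert) { throw new UnsupportedOperationException(); }
    protected void processSelect(SelectStmt select) { throw new UnsupportedOperationException(); }
    protected void processUpdate(UpdateStmt update) { throw new UnsupportedOperationException(); }
    protected void processDelete(DeleteStmt delete) { throw new UnsupportedOperationException(); }
}

/** Subclass in the spirit of IllegalSQLInnerInterceptor, with only the WHERE rule. */
class RequireWhereChecker extends ParserSupportSketch {

    @Override
    protected void processSelect(SelectStmt select) { requireWhere(select.where()); }

    @Override
    protected void processUpdate(UpdateStmt update) { requireWhere(update.where()); }

    @Override
    protected void processDelete(DeleteStmt delete) { requireWhere(delete.where()); }

    private static void requireWhere(String where) {
        if (where == null || where.isBlank()) {
            throw new IllegalArgumentException("Illegal SQL: a WHERE condition is required");
        }
    }
}
```

Because the default hooks here throw UnsupportedOperationException, a subclass has to filter out statement types it does not handle. This is much like beforePrepare() returning early for INSERT.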

5. Optimistic Lock Plugin

Optimistic lock plugin: OptimisticLockerInnerInterceptor

Register the OptimisticLockerInnerInterceptor optimistic lock plugin in the MybatisPlusConfig configuration class

/**
 * @Author Oneby
 * @Date 2021/4/24 18:42
 */
@Configuration
@MapperScan("com.oneby.mapper")
public class MybatisPlusConfig {

    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));
        interceptor.addInnerInterceptor(new BlockAttackInnerInterceptor());
        interceptor.addInnerInterceptor(new IllegalSQLInnerInterceptor());
        interceptor.addInnerInterceptor(new OptimisticLockerInnerInterceptor());
        return interceptor;
    }

}

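Add a version column to the t_user table and a version field annotated with @Version to the User entity. Things to keep in mind:

  1. The only supported types are int, Integer, long, Long, Date, Timestamp and LocalDateTime
  2. For integer types, newVersion = oldVersion + 1
  3. newVersion is written back to the entity
  4. Only the updateById(entity) and update(entity, wrapper) methods are supported
  5. With update(entity, wrapper), the wrapper must not be reused, because the plugin appends a version condition to it on every call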
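Why must the wrapper not be reused? In the plugin source further down, update(entity, wrapper) appends the version condition to the caller's own wrapper with aw.apply(...). The toy below is a hypothetical model: ToyWrapper and beforeUpdate are illustration-only names, not MybatisPlus APIs. It shows what happens when the same wrapper goes through two updates: the conditions stack up.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical model of the wrapper-reuse pitfall; not MybatisPlus classes. */
final class WrapperReuseDemo {

    /** Toy wrapper that only collects WHERE fragments. */
    static final class ToyWrapper {
        private final List<String> conditions = new ArrayList<>();

        ToyWrapper apply(String fragment) {
            conditions.add(fragment);
            return this;
        }

        String where() {
            return String.join(" AND ", conditions);
        }
    }

    /** Mimics aw.apply(versionColumn + " = {0}", originalVersionVal) on each update call. */
    static void beforeUpdate(ToyWrapper wrapper, int originalVersion) {
        wrapper.apply("version = " + originalVersion);
    }
}
```

After one update, the wrapper's condition reads name = 'Oneby' AND version = 1. Passing the same wrapper to a second update, with version 2, yields name = 'Oneby' AND version = 1 AND version = 2. No row can satisfy that, so the second update silently affects 0 rows.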
/**
 * @Author Oneby
 * @Date 2021/4/18 17:53
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
@TableName("t_user")
public class User extends Model<User> {

    @TableId(type = IdType.AUTO)
    private Long id;

    @TableField("username")
    private String name;

    private Integer age;

    private String email;

    @Version
    private Integer version;

}

Test code: pass the version number when updating

@Test
public void OptimisticLockerInnerInterceptor() {
    User user = new User();
    user.setId(1L);
    user.setName("Oneby");
    user.setVersion(1);
    int count = userMapper.updateById(user);
    System.out.println("更新行数:" + count);
}

Generated SQL: UPDATE t_user SET username=?, version=? WHERE id=? AND version=?. The update adds an equality check on the version column to the WHERE clause and sets version to the old value plus 1.

Creating a new SqlSession
SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSession@54d8c20d] was not registered for synchronization because synchronization is not active
2021-04-25 07:36:36.763  INFO 20980 --- [           main] com.zaxxer.hikari.HikariDataSource       : HikariPool-1 - Starting...
2021-04-25 07:36:36.919  INFO 20980 --- [           main] com.zaxxer.hikari.HikariDataSource       : HikariPool-1 - Start completed.
JDBC Connection [HikariProxyConnection@1296670053 wrapping com.mysql.cj.jdbc.ConnectionImpl@313f8301] will not be managed by Spring
==>  Preparing: UPDATE t_user SET username=?, version=? WHERE id=? AND version=?
==> Parameters: Heygo(String), 2(Integer), 1(Long), 1(Integer)
<==    Updates: 1
Closing non transactional SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSession@54d8c20d]
更新行数:1

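The following SQL log shows an optimistic-lock update failing. The row's version is no longer 1, presumably because the previous run already raised it to 2, so no row matches and 0 rows are updated.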

Creating a new SqlSession
SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSession@7100dea] was not registered for synchronization because synchronization is not active
2021-04-25 07:38:10.198  INFO 7624 --- [           main] com.zaxxer.hikari.HikariDataSource       : HikariPool-1 - Starting...
2021-04-25 07:38:10.323  INFO 7624 --- [           main] com.zaxxer.hikari.HikariDataSource       : HikariPool-1 - Start completed.
JDBC Connection [HikariProxyConnection@912440831 wrapping com.mysql.cj.jdbc.ConnectionImpl@1bb15351] will not be managed by Spring
==>  Preparing: UPDATE t_user SET username=?, version=? WHERE id=? AND version=?
==> Parameters: Oneby(String), 2(Integer), 1(Long), 1(Integer)
<==    Updates: 0
Closing non transactional SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSession@7100dea]
更新行数:0

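How OptimisticLockerInnerInterceptor works

OptimisticLockerInnerInterceptor implements optimistic locking: when updating a record, we want to be sure that nobody else has updated it in the meantime. The approach:

  1. When reading the record, fetch its current version
  2. When updating, pass that version along
  3. The update runs as set version = newVersion where version = oldVersion
  4. If the version no longer matches, no row is updated and the update fails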
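The four steps above amount to a compare-and-set on the version column. Below is a minimal in-memory sketch, assuming a hypothetical VersionedStore in place of the t_user table. updateName() succeeds only if the caller's version still matches, and then increments it by 1. This mirrors UPDATE t_user SET username=?, version=? WHERE id=? AND version=?. The synchronized methods stand in for the database evaluating the WHERE check and the SET atomically within one statement.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical in-memory table with an integer version column. */
final class VersionedStore {

    record Row(String name, int version) {}

    private final Map<Long, Row> rows = new HashMap<>();

    /** New rows start at version 1. */
    synchronized void insert(long id, String name) {
        rows.put(id, new Row(name, 1));
    }

    /** Step 1: read the record together with its current version. */
    synchronized Row read(long id) {
        return rows.get(id);
    }

    /**
     * Steps 2-4: update only if the version still matches, then bump it.
     * Returns the number of "rows" updated (0 or 1), like the mapper's return value.
     */
    synchronized int updateName(long id, String newName, int expectedVersion) {
        Row current = rows.get(id);
        if (current == null || current.version() != expectedVersion) {
            return 0; // someone else updated first: version no longer matches
        }
        rows.put(id, new Row(newName, expectedVersion + 1));
        return 1;
    }
}
```

After insert(1L, "Heygo"), calling updateName(1L, "Oneby", 1) returns 1 and raises the version to 2. Repeating the call with the stale version 1 returns 0, like the two logs in this section.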

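In MybatisPlusInterceptor#intercept(), when the intercepted method is Executor#update, each InnerInterceptor's willDoUpdate() is checked and then its beforeUpdate() method is called

OptimisticLockerInnerInterceptor overrides beforeUpdate() from the InnerInterceptor interface. For UPDATE statements whose parameter is a Map, it calls doOptimisticLocker(map, ms.getId());. That method adds the version condition and the new version value to the statement.

doOptimisticLocker() works as follows:

  1. It reads the entity from map (Constants.ENTITY) and returns if there is none, or if the entity's table has no version field. Otherwise it gets the version field with TableFieldInfo fieldInfo = tableInfo.getVersionFieldInfo(); and reads the entity's current value via reflection (Field versionField = fieldInfo.getField(); then versionField.get(et)). If that value is null there is nothing to do, and the method returns.
  2. If the version is not null, Object updatedVersionVal = this.getUpdatedVersionVal(fieldInfo.getPropertyType(), originalVersionVal); computes the new version value.
  3. For the update(entity, wrapper) method:
     - If no wrapper was passed, it builds one with UpdateWrapper<?> uw = new UpdateWrapper<>();, then calls uw.eq(versionColumn, originalVersionVal); and map.put(Constants.WRAPPER, uw);.
     - If a wrapper was passed, it appends the condition to it instead with aw.apply(versionColumn + " = {0}", originalVersionVal);.
  4. For any other method (that is, updateById), it stores the old value with map.put(Constants.MP_OPTLOCK_VERSION_ORIGINAL, originalVersionVal);. This supplies the AND version=? value seen in the updateById log above.
  5. Finally, versionField.set(et, updatedVersionVal); writes the new version into the entity. This provides the SET version=? value and is the write-back to the entity mentioned earlier.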
/**
 * Optimistic Lock Light version
 * <p>Intercept on {@link Executor}.update;</p>
 * <p>Support version types: int/Integer, long/Long, java.util.Date, java.sql.Timestamp</p>
 * <p>For extra types, please define a subclass and override {@code getUpdatedVersionVal}() method.</p>
 * <br>
 * <p>How to use?</p>
 * <p>(1) Define an Entity and add {@link Version} annotation on one entity field.</p>
 * <p>(2) Add {@link OptimisticLockerInnerInterceptor} into mybatis plugin.</p>
 * <br>
 * <p>How to work?</p>
 * <p>if update entity with version column=1:</p>
 * <p>(1) no {@link OptimisticLockerInnerInterceptor}:</p>
 * <p>SQL: update tbl_test set name='abc' where id=100001;</p>
 * <p>(2) add {@link OptimisticLockerInnerInterceptor}:</p>
 * <p>SQL: update tbl_test set name='abc',version=2 where id=100001 and version=1;</p>
 *
 * @author yuxiaobin
 * @since 3.4.0
 */
@SuppressWarnings({"unchecked"})
public class OptimisticLockerInnerInterceptor implements InnerInterceptor {

    private static final String PARAM_UPDATE_METHOD_NAME = "update";

    @Override
    public void beforeUpdate(Executor executor, MappedStatement ms, Object parameter) throws SQLException {
        if (SqlCommandType.UPDATE != ms.getSqlCommandType()) {
            return;
        }
        if (parameter instanceof Map) {
            Map<String, Object> map = (Map<String, Object>) parameter;
            doOptimisticLocker(map, ms.getId());
        }
    }

    protected void doOptimisticLocker(Map<String, Object> map, String msId) {
        //updateById(et), update(et, wrapper);
        Object et = map.getOrDefault(Constants.ENTITY, null);
        if (et != null) {
            // entity
            String methodName = msId.substring(msId.lastIndexOf(StringPool.DOT) + 1);
            TableInfo tableInfo = TableInfoHelper.getTableInfo(et.getClass());
            if (tableInfo == null || !tableInfo.isWithVersion()) {
                return;
            }
            try {
                TableFieldInfo fieldInfo = tableInfo.getVersionFieldInfo();
                Field versionField = fieldInfo.getField();
                // original (old) version value
                Object originalVersionVal = versionField.get(et);
                if (originalVersionVal == null) {
                    return;
                }
                String versionColumn = fieldInfo.getColumn();
                // new version value
                Object updatedVersionVal = this.getUpdatedVersionVal(fieldInfo.getPropertyType(), originalVersionVal);
                if (PARAM_UPDATE_METHOD_NAME.equals(methodName)) {
                    AbstractWrapper<?, ?, ?> aw = (AbstractWrapper<?, ?, ?>) map.getOrDefault(Constants.WRAPPER, null);
                    if (aw == null) {
                        UpdateWrapper<?> uw = new UpdateWrapper<>();
                        uw.eq(versionColumn, originalVersionVal);
                        map.put(Constants.WRAPPER, uw);
                    } else {
                        aw.apply(versionColumn + " = {0}", originalVersionVal);
                    }
                } else {
                    map.put(Constants.MP_OPTLOCK_VERSION_ORIGINAL, originalVersionVal);
                }
                versionField.set(et, updatedVersionVal);
            } catch (IllegalAccessException e) {
                throw ExceptionUtils.mpe(e);
            }
        }
    }

    /**
     * This method provides the control for version value.<BR>
     * Returned value type must be the same as original one.
     *
     * @param originalVersionVal ignore
     * @return updated version val
     */
    protected Object getUpdatedVersionVal(Class<?> clazz, Object originalVersionVal) {
        if (long.class.equals(clazz) || Long.class.equals(clazz)) {
            return ((long) originalVersionVal) + 1;
        } else if (int.class.equals(clazz) || Integer.class.equals(clazz)) {
            return ((int) originalVersionVal) + 1;
        } else if (Date.class.equals(clazz)) {
            return new Date();
        } else if (Timestamp.class.equals(clazz)) {
            return new Timestamp(System.currentTimeMillis());
        } else if (LocalDateTime.class.equals(clazz)) {
            return LocalDateTime.now();
        }
        //not supported type, return original val.
        return originalVersionVal;
    }
}