Mybatis自定义拦截器开发问题解决记录

开发背景说明

对 Mapper 中添加了 @DataPermission(resource = "xxx") 注解的接口方法进行权限 SQL 拼接:根据不同的 resource 解析出不同的数据权限,并在原有 SQL 的基础上拼接权限条件。

自定义interceptor开发

// 简化后的代码

@Slf4j
@Intercepts({@Signature(type = Executor.class, method = "query",args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})})
public class DataPermissionInterceptor extends implements Interceptor {
  @Override
  public Object intercept(Invocation invocation) throws Throwable {
    try {
      DataPermission dataPermission = getDataPermissionAnno(invocation);
      if (dataPermission != null) {
        doHandler(sqlDataPermission, invocation);
      }
    } finally {
      return invocation.proceed();
    }
  }
}

 /**
  * Reads the SQL of the intercepted statement, appends the permission conditions and writes
  * the rewritten SQL back onto the invocation.
  *
  * <p>Does nothing when the original SQL or {@code mapperId} is null or empty.
  *
  * @param invocation     the intercepted {@code Executor.query} call
  * @param dataPermission the annotation found on the mapper method (not read in this body)
  */
 public void doHandler(Invocation invocation, DataPermission dataPermission) {
    String sqlBeforeRewrite = this.getOriginalSql(invocation);
    if (sqlBeforeRewrite == null || sqlBeforeRewrite.isEmpty()) {
      return;
    }
    // NOTE(review): mapperId is not declared anywhere in the visible snippet - presumably a
    // field of the enclosing class; confirm.
    if (mapperId == null || mapperId.isEmpty()) {
      return;
    }
    AccessDto access = this.extractAccessDto();
    String sqlAfterRewrite = this.handlerOriginalSql(sqlBeforeRewrite, access);
    this.resetCurrentSql(invocation, sqlAfterRewrite);
 }

/**
 * Parses {@code originalSql}, AND-combines the permission conditions produced by
 * {@code expressions(...)} with the existing WHERE clause of a plain SELECT, and returns the
 * resulting SQL text.
 *
 * <p>The statement is returned unchanged (apart from re-serialisation by the parser) when it
 * is not a plain SELECT, has no FROM item, or when there is nothing to add.
 *
 * @param originalSql SQL text to rewrite
 * @param accessDto   permission data forwarded to {@code expressions(...)}
 * @return the SQL text after rewriting
 */
@SneakyThrows
private String handlerOriginalSql(String originalSql, AccessDto accessDto) {
    Statement stmt = CCJSqlParserUtil.parse(originalSql);
    stmt.accept(new StatementVisitorAdapter() {

      @SneakyThrows
      @Override
      public void visit(Select select) {
        // Set operations (UNION etc.) are not PlainSelect; the former blind cast threw
        // ClassCastException for them. Leave such statements untouched.
        if (!(select.getSelectBody() instanceof PlainSelect)) {
          return;
        }
        PlainSelect selectBody = (PlainSelect) select.getSelectBody();
        // e.g. "SELECT 1" has no FROM item to qualify the permission column with.
        if (selectBody.getFromItem() == null) {
          return;
        }
        // Single-element array so the anonymous visitor below can write to it.
        final String[] tableName = {null};
        selectBody.getFromItem().accept(new FromItemVisitorAdapter() {
          @Override
          public void visit(Table table) {
            super.visit(table);
            tableName[0] = table.getName();
          }
        });
        Alias fromAlias = selectBody.getFromItem().getAlias();
        List<Expression> rightExpressions = expressions(fromAlias, tableName[0], accessDto);
        Expression where = selectBody.getWhere();
        if (where != null) {
          // Wrap the existing condition in parentheses so a top-level OR in it cannot
          // bypass the permission conditions that get AND-ed on.
          where = CCJSqlParserUtil.parseCondExpression("(" + where + ")");
          rightExpressions.add(0, where);
        }
        // No permission condition and no existing WHERE: nothing to combine. Previously
        // get(0) on the empty list threw IndexOutOfBoundsException here.
        if (rightExpressions.isEmpty()) {
          return;
        }
        Expression andExpression = rightExpressions.get(0);
        for (int i = 1; i < rightExpressions.size(); i++) {
          andExpression = new AndExpression(andExpression, rightExpressions.get(i));
        }
        selectBody.setWhere(andExpression);
      }
    });
    return stmt.toString();
  }
  
/**
 * Builds the list of permission conditions for one table.
 *
 * <p>Currently produces at most one condition: {@code <qualifier>model_code IN (...)}, built
 * from {@code accessDto.getModelCodes()} when that collection is non-empty.
 *
 * @param tableAlias alias of the FROM item, may be null
 * @param tableName  name of the table, used as qualifier when there is no alias
 * @param accessDto  permission data, may be null
 * @return a mutable list of conditions, empty when there is nothing to restrict
 */
private List<Expression> expressions(Alias tableAlias, String tableName,AccessDto accessDto) {
    // Column qualifier: the alias plus "." when an alias exists, otherwise the table name
    // plus "." when it is not blank, otherwise no qualifier at all.
    final String qualifier;
    if (tableAlias != null) {
      qualifier = tableAlias.getName() + ".";
    } else if (StringUtils.isNotBlank(tableName)) {
      qualifier = tableName + ".";
    } else {
      qualifier = "";
    }

    List<Expression> conditions = new ArrayList<>();
    boolean hasModelCodes = accessDto != null && accessDto.getModelCodes() != null
        && CollectionUtils.isNotEmpty(accessDto.getModelCodes());
    if (!hasModelCodes) {
      return conditions;
    }

    List<Expression> codeValues = accessDto.getModelCodes().stream().map(LongValue::new)
        .collect(Collectors.toList());
    ExpressionList inValues = new ExpressionList(codeValues);
    Column modelCodeColumn = new Column(qualifier + "model_code");
    conditions.add(new InExpression(modelCodeColumn, inValues));
    return conditions;
  }

/**
 * Looks up the {@code DataPermission} annotation on the mapper method that the intercepted
 * statement belongs to.
 *
 * <p>The statement id is split at its last '.' into a class name and a method name; the
 * first public method of that class with a matching name and the annotation wins.
 *
 * @param invocation the intercepted call; its first argument must be the MappedStatement
 * @return the annotation, or null when none is found or the lookup fails
 */
private DataPermission getDataPermissionAnno(Invocation invocation) {
    MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
    try {
      String id = mappedStatement.getId();
      int separator = id.lastIndexOf('.');
      if (separator < 0) {
        // No "class.method" shape, so there is no mapper method to inspect.
        return null;
      }
      String className = id.substring(0, separator);
      String methodName = id.substring(separator + 1);
      for (Method candidate : Class.forName(className).getMethods()) {
        if (!candidate.getName().equals(methodName)) {
          continue;
        }
        // The lookup previously used SqlDataPermission, which does not match this method's
        // declared return type; use DataPermission consistently.
        DataPermission annotation = candidate.getAnnotation(DataPermission.class);
        if (annotation != null) {
          return annotation;
        }
      }
    } catch (Exception ex) {
      // Best effort: a failed lookup is treated as "not annotated".
      log.error("SqlDataPermission get error", ex);
    }
    return null;
  }
    
/**
 * Returns the SQL text of the intercepted statement, bound against the invocation's
 * parameter object.
 *
 * @param invocation the intercepted call; argument 0 is the MappedStatement, argument 1 the
 *                   parameter object
 * @return the SQL string held by the resulting BoundSql
 */
protected String getOriginalSql(Invocation invocation) {
    Object[] queryArgs = invocation.getArgs();
    MappedStatement mappedStatement = (MappedStatement) queryArgs[0];
    return mappedStatement.getBoundSql(queryArgs[1]).getSql();
}

/**
 * Replaces the MappedStatement in the invocation's argument array with a copy whose SQL
 * text is {@code sql}.
 *
 * @param invocation the intercepted call; argument 0 is the MappedStatement, argument 1 the
 *                   parameter object
 * @param sql        the SQL text the copy should carry
 */
protected void resetCurrentSql(Invocation invocation, String sql) {
    Object[] queryArgs = invocation.getArgs();
    MappedStatement original = (MappedStatement) queryArgs[0];
    BoundSql boundSql = original.getBoundSql(queryArgs[1]);

    // Copy the statement around a SqlSource that always hands back this BoundSql.
    MappedStatement rewritten = newMappedStatement(original, new BoundSqlSqlSource(boundSql));

    // Overwrite the sql property of that BoundSql through MyBatis' reflective MetaObject.
    MetaObject
        .forObject(rewritten, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(),
            new DefaultReflectorFactory())
        .setValue("sqlSource.boundSql.sql", sql);

    queryArgs[0] = rewritten;
  }

/**
 * Builds a copy of {@code ms} that uses {@code newSqlSource} instead of the original
 * SqlSource.
 *
 * <p>Besides the settings copied before (resource, fetch size, statement type, key
 * generator, key properties, timeout, parameter map, result maps, result set type, cache
 * flags), the copy now also carries the key columns, database id, language driver and
 * result-ordered flag, so it does not silently fall back to builder defaults for them.
 *
 * @param ms           the statement to copy
 * @param newSqlSource the SqlSource the copy should use
 * @return the newly built MappedStatement
 */
private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
    MappedStatement.Builder builder =
        new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource,
            ms.getSqlCommandType());
    builder.resource(ms.getResource());
    builder.fetchSize(ms.getFetchSize());
    builder.statementType(ms.getStatementType());
    builder.keyGenerator(ms.getKeyGenerator());
    // The builder takes key properties / key columns as one comma-separated string.
    String[] keyProperties = ms.getKeyProperties();
    if (keyProperties != null && keyProperties.length != 0) {
      builder.keyProperty(String.join(",", keyProperties));
    }
    String[] keyColumns = ms.getKeyColumns();
    if (keyColumns != null && keyColumns.length != 0) {
      builder.keyColumn(String.join(",", keyColumns));
    }
    builder.databaseId(ms.getDatabaseId());
    if (ms.getLang() != null) {
      // Keep the statement's own language driver rather than the configuration default.
      builder.lang(ms.getLang());
    }
    builder.resultOrdered(ms.isResultOrdered());
    builder.timeout(ms.getTimeout());
    builder.parameterMap(ms.getParameterMap());
    builder.resultMaps(ms.getResultMaps());
    builder.resultSetType(ms.getResultSetType());
    builder.cache(ms.getCache());
    builder.flushCacheRequired(ms.isFlushCacheRequired());
    builder.useCache(ms.isUseCache());
    return builder.build();
  }

  /**
   * SqlSource that always returns the one BoundSql it was constructed with, regardless of
   * the parameter object passed in.
   */
  class BoundSqlSqlSource implements SqlSource {

    // NOTE(review): resetCurrentSql sets the property path "sqlSource.boundSql.sql";
    // presumably that path resolves through this field name, so do not rename it without
    // checking that call.
    private BoundSql boundSql;

    /**
     * @param boundSql the BoundSql to hand back from every {@link #getBoundSql(Object)} call
     */
    public BoundSqlSqlSource(BoundSql boundSql) {
      this.boundSql = boundSql;
    }

    /**
     * Returns the wrapped BoundSql.
     *
     * @param parameterObject ignored
     * @return the BoundSql given to the constructor
     */
    @Override
    public BoundSql getBoundSql(Object parameterObject) {
      return boundSql;
    }
  }

问题1:拦截器无效

排查原因如下:
拦截方法不正确
在这里插入图片描述
项目引入了pageHelper拦截器,query在原有基础上被处理过
在这里插入图片描述
添加@Signature
在这里插入图片描述
成功拦截

问题2:BoundSql重置后查询不生效

排查原因如下:
对invocation的BoundSql采用下面方式reset但是不生效
在这里插入图片描述
对invocation的args进行分析发现拦截器中还有一个BoundSql的参数还是与重置前相同
在这里插入图片描述
对BoundSql也进行重置
在这里插入图片描述
成功解决,Sql拦截重置生效

问题3:将自定义拦截器设置在Page拦截器之前

回到问题1,可以发现PageHelper的拦截器先于自定义拦截器执行。
调试代码验证如下:
在这里插入图片描述
在这里插入图片描述
由此验证:在拦截器链中,最后添加的拦截器最先执行。因此,如果希望自定义拦截器先于PageHelper的拦截器执行,就需要将自定义拦截器放在最后添加。


/**
 * Configuration, ordered after PageHelper's auto-configuration, that adds a
 * {@link DataPermissionInterceptor} to the injected SqlSessionFactory's MyBatis
 * configuration from its {@link CommandLineRunner} callback.
 */
@AutoConfigureAfter({PageHelperAutoConfiguration.class})
@Configuration
public class MybatisConfig implements CommandLineRunner {

  @Autowired
  private SqlSessionFactory sqlSessionFactory;

  /**
   * Adds the data-permission interceptor to the MyBatis configuration.
   *
   * @param args incoming main method arguments (unused)
   * @throws Exception on error
   */
  @Override
  public void run(String... args) throws Exception {
    sqlSessionFactory.getConfiguration().addInterceptor(new DataPermissionInterceptor());
  }
}

使用该方式则原始代码的BoundSql重置拦截生效

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值