MyBatis-Plus 拦截器

package com.dyna.common.config.mybatisplus;

import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import com.dyna.common.bean.AuthSystemLoginVo;
import com.dyna.common.context.CurrentUserContext;
import com.dyna.common.enums.YnEnum;
import com.google.common.collect.Lists;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.util.TablesNamesFinder;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.util.List;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * MyBatis-Plus {@link InnerInterceptor} that appends an {@code is_test} condition to single-table SELECT statements.
 * <p>
 * The class is also a Spring {@code @Configuration}: it registers the {@link MybatisPlusInterceptor} bean and
 * receives the {@link ApplicationContext}, which the interceptor uses to look up {@code IsTestFlagService} on each
 * query.
 */
@Slf4j
@Configuration
public class IbmpInnerInterceptor implements InnerInterceptor, ApplicationContextAware {

    /** Column that marks a row as test data. */
    private static final String TEST_FLAG_FIELD = "is_test";

    /** Mapper statement ids ({@code MappedStatement#getId()}) that are never rewritten. */
    private static final List<String> IS_TEST_IGNORE_MAPPER_METHOD_LIST = Lists.newArrayList("com.dyna.common.mapper.ConfigMapper.selectList");

    /**
     * Whole-word, case-insensitive WHERE keyword. A plain substring search would also hit identifiers such as
     * {@code somewhere_id} and splice the condition into the middle of a column name.
     */
    private static final Pattern WHERE_KEYWORD = Pattern.compile("\\b" + Constants.WHERE + "\\b", Pattern.CASE_INSENSITIVE);

    /**
     * Whole-word, case-insensitive FROM keyword. Case-insensitive so hand-written lower-case SQL is handled;
     * whole-word so identifiers such as {@code from_date} are not mistaken for the keyword.
     */
    private static final Pattern FROM_KEYWORD = Pattern.compile("\\bFROM\\b", Pattern.CASE_INSENSITIVE);

    /**
     * Spring context captured by {@link #setApplicationContext}. Static because the interceptor registered in
     * {@link #mybatisPlusInterceptor()} is created with {@code new} and therefore never receives that callback itself.
     */
    private static ApplicationContext applicationContext;

    /**
     * Registers the MyBatis-Plus interceptor chain containing this inner interceptor.
     *
     * @return the configured {@link MybatisPlusInterceptor}
     */
    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        interceptor.addInnerInterceptor(new IbmpInnerInterceptor());
        return interceptor;
    }

    /**
     * Hook invoked before
     * {@link Executor#query(MappedStatement, Object, RowBounds, ResultHandler, CacheKey, BoundSql)}.
     * <p>
     * Rewrites the SQL in place. Editing SQL here carries SQL-injection risk, so only trusted, constant fragments
     * are appended.
     *
     * @param executor Executor (may be a proxy)
     * @param ms MappedStatement
     * @param parameter parameter
     * @param rowBounds rowBounds
     * @param resultHandler resultHandler
     * @param boundSql boundSql
     */
    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds,
        ResultHandler resultHandler, BoundSql boundSql) {
        String buildSql = boundSql.getSql();

        // Decide whether the is_test condition applies to this statement at all
        if (!isNeedAddIsTestCondition(buildSql, ms, executor)) {
            return;
        }
        String newBuildSql = this.addTestFlag(buildSql);

        // Only log and rebind when the SQL actually changed
        if (!StringUtils.equals(newBuildSql, buildSql)) {
            log.debug("修改前sql=【{}】", buildSql);
            log.debug("修改后sql=【{}】", newBuildSql);
            PluginUtils.mpBoundSql(boundSql).sql(newBuildSql);
        }
    }

    /**
     * Decides whether the {@code is_test} condition should be appended to this query.
     *
     * @param boundSql the SQL text about to be executed
     * @param ms       the mapped statement; its id is matched against the ignore list and method-name suffixes
     * @param executor used to read column metadata for the target table
     * @return {@code true} only when every check passes and the table has an {@code is_test} column
     */
    private boolean isNeedAddIsTestCondition(String boundSql, MappedStatement ms, Executor executor) {
        String id = ms.getId();
        if (StringUtils.isBlank(id)) {
            return false;
        }
        // Explicitly excluded mapper methods
        if (IS_TEST_IGNORE_MAPPER_METHOD_LIST.contains(id)) {
            return false;
        }
        IsTestFlagService isTestFlagService = this.getBean(IsTestFlagService.class);
        // Global switch: when it is off, the original SQL runs unchanged
        if (!isTestFlagService.testIgnoreOpenFlag()) {
            return false;
        }
        // Joined queries are skipped (coarse check: any "join" text, including inside identifiers)
        if (StringUtils.containsIgnoreCase(boundSql, "join")) {
            return false;
        }
        // Skip when the text from WHERE to the very end of the SQL already mentions is_test
        Matcher where = WHERE_KEYWORD.matcher(boundSql);
        if (where.find() && StringUtils.containsIgnoreCase(boundSql.substring(where.start()), TEST_FLAG_FIELD)) {
            return false;
        }
        // Lookups by id / single-row lookups are not filtered
        if (id.endsWith("selectById") || id.endsWith("selectBatchIds") || id.endsWith("selectOne")) {
            return false;
        }
        String tableName = getTableName(boundSql);
        if (StringUtils.isBlank(tableName)) {
            return false;
        }
        // Tables whose name starts with a configured prefix are ignored
        List<String> ignoreTableList = isTestFlagService.getTestIgnoreFieldTableNameList();
        if (CollectionUtils.isNotEmpty(ignoreTableList) && ignoreTableList.stream().anyMatch(tableName::startsWith)) {
            return false;
        }
        // Only tables that actually have an is_test column get the condition
        if (!isContainTestColumn(executor, tableName)) {
            log.warn("table not contains is_test field,table_name:{}", tableName);
            return false;
        }
        return true;
    }

    /**
     * Appends the test-flag condition for the current user.
     * <p>
     * A user whose {@code isTest} equals {@code YnEnum.YES.getCode()} gets {@code is_test = 1}; everyone else,
     * including when no current user is set, gets {@code is_test = 0}.
     * NOTE(review): the original comment said test users "can view all data", but this code restricts them to
     * test rows only — confirm which behavior is intended.
     *
     * @param buildSql original SQL
     * @return SQL with the condition inserted
     */
    private String addTestFlag(String buildSql) {
        AuthSystemLoginVo currentUser = CurrentUserContext.getCurrentUser();
        // Objects.equals: null-safe and compares values, not boxed identities
        boolean isTestUser = currentUser != null && Objects.equals(currentUser.getIsTest(), YnEnum.YES.getCode());
        return this.insertSqlCondition(buildSql, TEST_FLAG_FIELD + (isTestUser ? " = 1" : " = 0"));
    }

    /**
     * Inserts a SQL predicate directly after the WHERE keyword, or adds a WHERE clause when none exists.
     *
     * @param buildSql original SQL
     * @param extraSqlCondition predicate to insert; must be trusted text, it is concatenated verbatim
     * @return the rewritten SQL, or {@code buildSql} unchanged if the condition is blank or no insertion
     *         point is found
     */
    private String insertSqlCondition(String buildSql, String extraSqlCondition) {
        // Nothing to insert
        if (StringUtils.isBlank(extraSqlCondition)) {
            return buildSql;
        }
        Matcher where = WHERE_KEYWORD.matcher(buildSql);
        if (where.find()) {
            // Existing WHERE: prepend "(condition) and" to the current predicate.
            // NOTE(review): the existing predicate is not parenthesized, so a top-level OR in it
            // ("WHERE a = 1 OR b = 2") binds looser than the inserted AND and lets rows bypass the filter.
            int afterWhere = where.end();
            return buildSql.substring(0, afterWhere) + " (" + extraSqlCondition + ") and "
                + buildSql.substring(afterWhere);
        }
        // No WHERE: add one right after the table name
        int index = this.indexOfInsertIndex(buildSql);
        if (index == -1) {
            return buildSql;
        }
        return buildSql.substring(0, index) + " where (" + extraSqlCondition + ") " + buildSql.substring(index);
    }

    /**
     * Finds the index just after the first token that follows FROM (normally the table name), where a WHERE
     * clause can be inserted. When that token is the last one in the SQL, the end of the string is returned.
     * NOTE(review): a table alias ("FROM t_user u") would end up after the inserted WHERE and break the SQL.
     *
     * @param buildSql SQL about to be executed
     * @return the insertion index, or -1 if there is no FROM keyword or no token after it
     */
    private int indexOfInsertIndex(String buildSql) {
        Matcher from = FROM_KEYWORD.matcher(buildSql);
        if (!from.find()) {
            log.warn("执行sql无法定位到FROM语句,sql=【{}】", buildSql);
            return -1;
        }
        // tableRead becomes true once a non-whitespace character after FROM has been seen;
        // the next whitespace after that marks the end of the table name
        boolean tableRead = false;
        for (int i = from.end(); i < buildSql.length(); i++) {
            if (Character.isWhitespace(buildSql.charAt(i))) {
                if (tableRead) {
                    return i;
                }
            } else {
                tableRead = true;
            }
        }
        // Table name runs to the end of the SQL (e.g. "SELECT id FROM t_user"): insert at the very end
        return tableRead ? buildSql.length() : -1;
    }

    /**
     * Parses the SQL and returns its only table name.
     *
     * @param buildSql SQL about to be executed
     * @return the table name for a single-table SELECT; {@code ""} for non-SELECT statements, multi-table
     *         queries, or unparseable SQL
     */
    private String getTableName(String buildSql) {
        try {
            Statement statement = CCJSqlParserUtil.parse(buildSql);
            if (statement instanceof Select) {
                List<String> tableList = new TablesNamesFinder().getTableList((Select) statement);
                // Only single-table queries are rewritten
                if (CollectionUtils.isNotEmpty(tableList) && tableList.size() == 1) {
                    return tableList.get(0);
                }
            }
        } catch (Exception e) {
            log.warn("failed to parse table name from sql, sql={}", buildSql, e);
        }
        return "";
    }

    /**
     * Checks via JDBC metadata whether {@code tableName} has an {@code is_test} column.
     * Note that {@code getColumns} takes LIKE patterns, so {@code _} in the names acts as a single-char wildcard.
     *
     * @param executor executor whose transaction connection is used
     * @param tableName table to inspect
     * @return {@code true} if the column exists; {@code false} if it does not or the lookup fails
     */
    private boolean isContainTestColumn(Executor executor, String tableName) {
        try {
            // The connection belongs to the current transaction, so it is deliberately not closed here
            Connection connection = executor.getTransaction().getConnection();
            DatabaseMetaData metaData = connection.getMetaData();
            try (ResultSet columns = metaData.getColumns(null, null, tableName, TEST_FLAG_FIELD)) {
                return columns.next();
            }
        } catch (Exception e) {
            // Best effort: treat lookup failures as "no column" so the query runs unmodified
            log.warn("failed to read column metadata, table_name:{}", tableName, e);
            return false;
        }
    }

    /**
     * Looks up a bean from the captured Spring context.
     *
     * @param clazz bean type
     * @param <T> bean type
     * @return the bean instance
     */
    public <T> T getBean(Class<T> clazz) {
        return IbmpInnerInterceptor.applicationContext.getBean(clazz);
    }

    /**
     * Stores the Spring context in a static field so interceptor instances created outside Spring can use it.
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        IbmpInnerInterceptor.applicationContext = applicationContext;
    }
}

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值