Druid拦截SQL语句,实现自动添加一个查询条件

Druid拦截SQL语句,实现自动添加一个查询条件

这里就不详细描述原理了。

首先需要重写一下FilterEventAdapter里的connection_prepareStatement方法,然后对sql进行解析,根据不同情况添加where查询条件。

package com.spek.base.filter;

import com.alibaba.druid.filter.FilterChain;
import com.alibaba.druid.filter.FilterEventAdapter;
import com.alibaba.druid.proxy.jdbc.ConnectionProxy;
import com.alibaba.druid.proxy.jdbc.PreparedStatementProxy;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.*;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.parser.SQLParserUtils;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.alibaba.fastjson.JSON;
import com.spek.base.utils.SessionUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ObjectUtils;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * druidDataSource 拦截器,拦截sql,添加租户ID条件
 *
 * @author wuya
 *
 */
public class MysqlFilter extends FilterEventAdapter {

    private static final Logger LOG = LoggerFactory.getLogger(MysqlFilter.class);

    private static final String TABLE_FIELD_TENANT_ID = "tenant_id";

    private static final String MYSQL_STRING = "mysql";

    private static final List<String> NOT_HAVE_TENANT_ID_TABLE_LIST = Arrays.asList("bi_bank",
        "bi_support_charge_detail");

    @Override
    public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql)
        throws SQLException {
        LOG.info("original sql = {}", sql);
        try {
            // 解析sql
            MySqlStatementParser parser = new MySqlStatementParser(sql);
            SQLStatement stmt = parser.parseStatement();
            if (stmt instanceof SQLSelectStatement) {
                SQLSelect sqlSelect = ((SQLSelectStatement)stmt).getSelect();
                if (sqlSelect.getQuery() instanceof SQLUnionQuery) {
                    SQLUnionQuery unionQuery = (SQLUnionQuery)sqlSelect.getQuery();
                    sql = doUnionSelect(unionQuery);
                } else {
                    sql = doSelectSql(sql, (MySqlSelectQueryBlock)sqlSelect.getQueryBlock());
                }
            } else if (stmt instanceof MySqlUpdateStatement) {
                MySqlUpdateStatement update = (MySqlUpdateStatement)stmt;
                sql = doUpdateSql(sql, update);
            } else if (stmt instanceof MySqlDeleteStatement) {
                MySqlDeleteStatement delete = (MySqlDeleteStatement)stmt;
                sql = doDeleteSql(sql, delete);
            }
        } catch (Exception e) {
            LOG.error("deal self filter sql error {}", e);
        }
        LOG.info("new sql = {}", sql);
        return super.connection_prepareStatement(chain, connection, sql);
    }

    /**
     * 处理union查询语句
     * @param unionQuery 语句
     * @return 处理结果
     */
    private String doUnionSelect(SQLUnionQuery unionQuery) {
        SQLSelectQuery left = unionQuery.getLeft();
        SQLSelectQuery right = unionQuery.getRight();
        if (left instanceof SQLUnionQuery) {
            doUnionSelect((SQLUnionQuery)left);
        } else {
            doSelectSql(String.valueOf(left), (MySqlSelectQueryBlock)left);
        }
        if (right instanceof SQLUnionQuery) {
            doUnionSelect((SQLUnionQuery)right);
        } else {
            doSelectSql(String.valueOf(right), (MySqlSelectQueryBlock)right);
        }
        return String.valueOf(unionQuery);
    }

    /**
     * 处理查询语句
     *
     * @param sql SQL
     * @return 处理后的SQL
     */
    private String doSelectSql(String sql, MySqlSelectQueryBlock select) {
        // 获取where对象
        SQLExpr where = select.getWhere();
        List<SQLSelectItem> selectList = select.getSelectList();
        // 遍历查询的字段,如果查询字段中有子查询 则加上租户ID查询条件
        selectList.forEach(e -> {
            if (e.getExpr() instanceof SQLQueryExpr) {
                SQLQueryExpr expr = (SQLQueryExpr)e.getExpr();
                String newFieldSql = doSelectSql(String.valueOf(expr), (MySqlSelectQueryBlock)expr.getSubQuery().getQueryBlock());
                SQLExpr subSelect = SQLUtils.toMySqlExpr(newFieldSql);
                e.setExpr(subSelect);
                LOG.info("sql select field have subQuery = {}", newFieldSql);
            }
        });
        // 获取所查询的表
        SQLTableSource from = select.getFrom();
        // 如果from语句是子查询
        if (from instanceof SQLSubqueryTableSource) {
            String fromString = String.valueOf(from);
            SQLSubqueryTableSource subqueryTableSource = (SQLSubqueryTableSource)from;
            String subQuery = doSelectSql(fromString, (MySqlSelectQueryBlock)subqueryTableSource.getSelect().getQueryBlock());
            LOG.info("sql from have subQuery = {}", subQuery);
            SQLSelect sqlSelectBySql = getSqlSelectBySql(subQuery);
            ((SQLSubqueryTableSource)from).setSelect(sqlSelectBySql);
            select.setWhere(getNewWhereCondition(select, where, sql, from));
        }
        // 如果from语句是关联查询
        if (from instanceof SQLJoinTableSource) {
            SQLJoinTableSource joinFrom = (SQLJoinTableSource)from;
            SQLTableSource left = joinFrom.getLeft();
            SQLTableSource right = joinFrom.getRight();
            setTableSourceNewSql(left);
            setTableSourceNewSql(right);
        }
        select.setWhere(getNewWhereCondition(select, where, sql, from));
        return select.toString();
    }

    /**
     * 处理更新语句
     *
     * @param sql sql语句
     * @param stmt 解析的语句
     * @return 修改的后的sql
     */
    private String doUpdateSql(String sql, SQLStatement stmt) {
        MySqlUpdateStatement update = (MySqlUpdateStatement)stmt;
        SQLExpr where = update.getWhere();
        // 拼接where条件
        update.setWhere(getNewWhereCondition(null, where, sql, update.getTableSource()));
        return update.toString();
    }

    /**
     * 处理delete语句
     *
     * @param sql sql语句
     * @param stmt 解析的语句
     * @return 修改的后的sql
     */
    private String doDeleteSql(String sql, SQLStatement stmt) {
        MySqlDeleteStatement delete = (MySqlDeleteStatement)stmt;
        SQLExpr where = delete.getWhere();
        // 拼接where条件
        delete.setWhere(getNewWhereCondition(null, where, sql, delete.getTableSource()));
        return delete.toString();
    }

    /**
     * 添加where条件
     *
     * @param where where语句
     * @return 修改后的where条件
     */
    private SQLExpr getNewWhereCondition(MySqlSelectQueryBlock select, SQLExpr where, String sql,
        SQLTableSource tableSource) {
        // 如果where中包含子查询
        if (where instanceof SQLInSubQueryExpr) {
            SQLSelect subSelect = ((SQLInSubQueryExpr)where).subQuery;
            // 获取子查询语句
            String subQuery = String.valueOf(subSelect);
            // 处理子查询语句
            String newSubQuery = doSelectSql(subQuery, (MySqlSelectQueryBlock)subSelect.getQueryBlock());
            SQLSelect sqlSelectBySql = getSqlSelectBySql(newSubQuery);
            ((SQLInSubQueryExpr)where).setSubQuery(sqlSelectBySql);
        }
        SQLBinaryOpExpr binaryOpExprWhere = new SQLBinaryOpExpr(MYSQL_STRING);
        List<SourceFromInfo> tableNameList = new ArrayList<>();
        getTableNames(select, tableSource, tableNameList);
        if (CollectionUtils.isEmpty(tableNameList)) {
            return where;
        }
        // 根据多个表名获取拼接条件
        SQLBinaryOpExpr conditionByTableName = getWhereConditionByTableList(tableNameList);
        LOG.info("get tableInfos = {}", JSON.toJSONString(tableNameList));
        // 没有需要添加的条件,直接返回
        if (ObjectUtils.isEmpty(conditionByTableName)) {
            return where;
        }
        // 没有where条件时 则返回需要添加的条件
        if (where == null) {
            return conditionByTableName;
        }
        binaryOpExprWhere.setLeft(conditionByTableName);
        binaryOpExprWhere.setOperator(SQLBinaryOperator.BooleanAnd);
        binaryOpExprWhere.setRight(where.clone());
        if (isTenantIdAndOrCondition(where)) {
            LOG.info("the sql contains or condition by tenant_id, sql = {}", sql);
        }
        return binaryOpExprWhere;
    }

    /**
     * 根据from语句得到的表名拼接条件
     *
     * @param tableNameList 表名列表
     * @return 拼接后的条件
     */
    private SQLBinaryOpExpr getWhereConditionByTableList(List<SourceFromInfo> tableNameList) {
        // 先过滤掉不需要添加条件的
        tableNameList =
            tableNameList.stream().filter(fromInfo -> fromInfo.isNeedAddCondition()).collect(Collectors.toList());
        if (CollectionUtils.isEmpty(tableNameList)) {
            return null;
        }
        SQLBinaryOpExpr allCondition = new SQLBinaryOpExpr(MYSQL_STRING);
        for (int i = 0; i < tableNameList.size(); i++) {
            SourceFromInfo tableNameInfo = tableNameList.get(i);
            SQLBinaryOpExpr thisTenantIdWhere = getTenantIdCondition(tableNameInfo);
            // 如果是最后一个且不是第一个则将当期table条件设置为右侧条件
            if (i > 0 && i == tableNameList.size() - 1) {
                allCondition.setOperator(SQLBinaryOperator.BooleanAnd);
                allCondition.setRight(thisTenantIdWhere);
                break;
            }
            // 如果是只有一个table 则直接设置最终条件为当期table条件
            if (tableNameList.size() == 1) {
                allCondition = thisTenantIdWhere;
                break;
            }
            if (allCondition.getLeft() == null) {
                allCondition.setLeft(thisTenantIdWhere);
            } else {
                SQLBinaryOpExpr condition = getAndCondition((SQLBinaryOpExpr)allCondition.getLeft(), thisTenantIdWhere);
                allCondition.setLeft(condition);
            }
        }
        return allCondition;
    }

    /**
     * 拼接and条件
     *
     * @param left 左侧条件
     * @param right 右侧条件
     * @return 拼接后的条件
     */
    private SQLBinaryOpExpr getAndCondition(SQLBinaryOpExpr left, SQLBinaryOpExpr right) {
        SQLBinaryOpExpr condition = new SQLBinaryOpExpr(MYSQL_STRING);
        condition.setLeft(left);
        condition.setOperator(SQLBinaryOperator.BooleanAnd);
        condition.setRight(right);
        return condition;
    }

    /**
     * 根据表信息拼接tenantId 条件
     *
     * @param tableNameInfo 表信息
     * @return 拼接后的条件
     */
    private SQLBinaryOpExpr getTenantIdCondition(SourceFromInfo tableNameInfo) {
        SQLBinaryOpExpr tenantIdWhere = new SQLBinaryOpExpr(MYSQL_STRING);
        int tenantId = SessionUtils.getCurrentTenantId();
        if (StringUtils.isEmpty(tableNameInfo.getAlias())) {
            // 拼接新的条件
            tenantIdWhere.setOperator(SQLBinaryOperator.Equality);
            tenantIdWhere.setLeft(new SQLIdentifierExpr(TABLE_FIELD_TENANT_ID));
            // 设置当前租户ID条件
            tenantIdWhere.setRight(new SQLIntegerExpr(tenantId));
        } else {
            // 拼接别名条件
            tenantIdWhere.setLeft(new SQLPropertyExpr(tableNameInfo.getAlias(), TABLE_FIELD_TENANT_ID));
            tenantIdWhere.setOperator(SQLBinaryOperator.Equality);
            tenantIdWhere.setRight(new SQLIntegerExpr(tenantId));
        }
        return tenantIdWhere;
    }

    /**
     * 查询所有的表信息
     *
     * @param select from语句对应的select语句
     * @param tableSource from语句
     * @param tableNameList sql中from语句中所有表信息
     */
    private void getTableNames(MySqlSelectQueryBlock select, SQLTableSource tableSource,
        List<SourceFromInfo> tableNameList) {
        // 子查询
        if (tableSource instanceof SQLSubqueryTableSource) {
            SourceFromInfo fromInfo = new SourceFromInfo();
            fromInfo.setSubQuery(true);
            SQLSubqueryTableSource subqueryTableSource = (SQLSubqueryTableSource)tableSource;
            // 设置别名
            fromInfo.setAlias(subqueryTableSource.getAlias());
            List<SQLSelectItem> selectList = select.getSelectList();
            Optional.ofNullable(selectList).filter(list -> !CollectionUtils.isEmpty(selectList)).map(list -> {
                list.forEach(item -> {
                    String itemString = String.valueOf(item);
                    // 如果查询字段中有tenant_id 字段则需要加条件 否则不用加
                    if (StringUtils.contains(itemString, TABLE_FIELD_TENANT_ID)) {
                        fromInfo.setNeedAddCondition(true);
                        return;
                    }
                });
                return list;
            });
            tableNameList.add(fromInfo);
        }
        // 连接查询
        if (tableSource instanceof SQLJoinTableSource) {
            SQLJoinTableSource joinSource = (SQLJoinTableSource)tableSource;
            SQLTableSource left = joinSource.getLeft();
            SQLTableSource right = joinSource.getRight();
            // 子查询则递归获取
            if (left instanceof SQLSubqueryTableSource) {
                getTableNames((MySqlSelectQueryBlock)((SQLSubqueryTableSource)left).getSelect().getQuery(), left,
                    tableNameList);
            }
            // 子查询则递归获取
            if (right instanceof SQLSubqueryTableSource) {
                getTableNames((MySqlSelectQueryBlock)((SQLSubqueryTableSource)right).getSelect().getQuery(), right,
                    tableNameList);
            }
            // 连接查询 左边是单表
            if (left instanceof SQLExprTableSource) {
                addOnlyTable(left, tableNameList);
            }
            // 连接查询 右边是单表
            if (right instanceof SQLExprTableSource) {
                addOnlyTable(right, tableNameList);
            }
            // 连接查询 左边还是连接查询 则递归继续获取表名
            if (left instanceof SQLJoinTableSource) {
                getTableNames(null, left, tableNameList);
            }
            // 连接查询 右边还是连接查询 则递归继续获取表名
            if (right instanceof SQLJoinTableSource) {
                getTableNames(null, right, tableNameList);
            }
        }
        // 普通表查询
        if (tableSource instanceof SQLExprTableSource) {
            addOnlyTable(tableSource, tableNameList);
        }
    }

    /**
     * 如果当前from语句只有单表,则添加到list中
     *
     * @param tableSource from语句
     * @param tableNameList 表信息list
     */
    private void addOnlyTable(SQLTableSource tableSource, List<SourceFromInfo> tableNameList) {
        SourceFromInfo fromInfo = new SourceFromInfo();
        // 普通表查询
        String tableName = String.valueOf(tableSource);
        fromInfo.setTableName(tableName);
        fromInfo.setAlias(tableSource.getAlias());
        if (!NOT_HAVE_TENANT_ID_TABLE_LIST.contains(tableName)) {
            fromInfo.setNeedAddCondition(true);
        }
        tableNameList.add(fromInfo);
    }

    /**
     * 条件中是否为 and or 表达式
     *
     * @param where sql中where条件语句
     * @return 判断结果
     */
    private boolean isContainsTenantIdCondition(SQLExpr where) {
        if (!(where instanceof SQLBinaryOpExpr)) {
            return false;
        }
        SQLBinaryOpExpr binaryOpExpr = (SQLBinaryOpExpr)where;
        SQLExpr left = binaryOpExpr.getLeft();
        SQLExpr right = binaryOpExpr.getRight();
        // 是否包含tenant_id 为查询条件
        if (!(left instanceof SQLBinaryOpExpr) && !(right instanceof SQLBinaryOpExpr)
            && (TABLE_FIELD_TENANT_ID.equals(String.valueOf(left))
            || TABLE_FIELD_TENANT_ID.equals(String.valueOf(right)))) {
            return true;
        }
        return false;
    }

    /**
     * 是否包括 or tenant_id = xx的条件
     *
     * @param where sql中where条件语句
     * @return 判断结果
     */
    private boolean isTenantIdAndOrCondition(SQLExpr where) {
        if (!(where instanceof SQLBinaryOpExpr)) {
            return false;
        }
        SQLBinaryOpExpr binaryOpExpr = (SQLBinaryOpExpr)where;
        if ((isContainsTenantIdCondition(binaryOpExpr.getLeft())
            || isContainsTenantIdCondition(binaryOpExpr.getRight()))
            && "BooleanOr".equals(String.valueOf(binaryOpExpr.getOperator()))) {
            return true;
        }
        return isTenantIdAndOrCondition(binaryOpExpr.getLeft()) || isTenantIdAndOrCondition(binaryOpExpr.getRight());
    }

    /**
     * from语句是子查询的 处理子查询 并更新from语句
     *
     * @param tableSource from语句
     */
    private void setTableSourceNewSql(SQLTableSource tableSource) {
        if (!(tableSource instanceof SQLSubqueryTableSource)) {
            return;
        }
        SQLSubqueryTableSource subqueryTableSource = (SQLSubqueryTableSource)tableSource;
        String leftSubQueryString = String.valueOf(subqueryTableSource.getSelect());
        String newLeftSubQueryString = doSelectSql(leftSubQueryString, (MySqlSelectQueryBlock)subqueryTableSource.getSelect().getQueryBlock());
        SQLSelect sqlselect = getSqlSelectBySql(newLeftSubQueryString);
        subqueryTableSource.setSelect(sqlselect);
    }

    /**
     * 将String类型select sql语句转化为SQLSelect对象
     *
     * @param sql 查询SQL语句
     * @return 转化后的对象实体
     */
    private SQLSelect getSqlSelectBySql(String sql) {
        SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(sql, MYSQL_STRING);
        List<SQLStatement> parseStatementList = parser.parseStatementList();
        if (CollectionUtils.isEmpty(parseStatementList)) {
            return null;
        }
        SQLSelectStatement sstmt = (SQLSelectStatement)parseStatementList.get(0);
        SQLSelect sqlselect = sstmt.getSelect();
        return sqlselect;
    }
}

接着需要一个数据源配置项:

package com.spek.base.filter;

import java.util.ArrayList;
import java.util.List;

import javax.sql.DataSource;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import com.alibaba.druid.filter.Filter;
import com.alibaba.druid.filter.stat.StatFilter;
import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.druid.wall.WallFilter;

/**
 * Spring configuration that registers a Druid data source with the tenant filter installed.
 */
@Configuration
public class DynamicDataSourceConfig {

    /**
     * Exposes the Druid data source as the application's {@link DataSource} bean; its
     * properties are bound from the {@code spring.datasource} prefix.
     *
     * @return the configured data source
     */
    @Bean
    @ConfigurationProperties(prefix = "spring.datasource")
    public DataSource dataSource() {
        return createCustomSource();
    }

    /**
     * Creates a Druid data source whose proxy filters are, in order: {@link WallFilter},
     * {@link MysqlFilter} and {@link StatFilter}.
     *
     * @return a new, not yet property-bound data source
     */
    private DruidDataSource createCustomSource() {
        DruidDataSource druid = new DruidDataSource();
        List<Filter> proxyFilters = new ArrayList<>();
        proxyFilters.add(new WallFilter());
        // The SQL-rewriting tenant filter.
        proxyFilters.add(new MysqlFilter());
        proxyFilters.add(new StatFilter());
        druid.setProxyFilters(proxyFilters);
        return druid;
    }
}

最后需要在启动类里面加上@Import({DynamicDataSourceConfig.class}) 注解

  • 2
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 10
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值