Mybatis多租户插件

本插件基于数据库字段进行数据的隔离:请求头中的租户参数名为tenantId,数据库中的租户字段名通过插件属性tenantIdField配置(例如tenant_id)。

1.在原有项目的所有接口中逐个加入租户字段,改动工作量巨大;因此考虑采用过滤器拦截请求,读取请求头中的租户参数,再由MyBatis插件把它织入到sql的where条件中。

2.请求过滤器(Servlet Filter)

/**
 * 拦截器配置
 */
@Configuration
public class WebHeadFilter {

    @Bean
    public FilterRegistrationBean modifyParametersFilter() {
        FilterRegistrationBean registration = new FilterRegistrationBean();
        registration.setFilter(new MyHeadFilter());
        registration.addUrlPatterns("/*");              // 拦截路径
        registration.setName("headParametersFilter"); // 拦截器名称
        registration.setOrder(1);                       // 顺序
        return registration;
    }

    /**
     * 自定义拦截器
     */
    class MyHeadFilter extends OncePerRequestFilter {
        @Override
        protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
            // tenantId信息处理
            HeadRequestWrapper headRequestWrapper = new HeadRequestWrapper(request);
            if (StringUtils.isBlank(headRequestWrapper.getHeader("tenantId"))) {
                headRequestWrapper.addHead("tenantId", "01b93ddbb36b4fe69c789595ab686597");
            }
            TenantInfoUtil.set(headRequestWrapper.getHeader("tenantId"));
            // finish
            filterChain.doFilter(headRequestWrapper, response);
        }

        @Override
        public void destroy() {
            TenantInfoUtil.remove();
            super.destroy();
        }
    }

    /**
     * 修改head信息
     */
    class HeadRequestWrapper extends HttpServletRequestWrapper {
        private final Map<String, String> headers;

        HeadRequestWrapper(HttpServletRequest request) {
            super(request);
            this.headers = new HashMap<>();
        }

        @Override
        public String getHeader(String name) {
            String headervalue = super.getHeader(name);
            if (headers.containsKey(name)) {
                headervalue = headers.get(name);
            }
            return headervalue;
        }

        @Override
        public Enumeration<String> getHeaderNames() {
            List<String> values = Collections.list(super.getHeaderNames());
            for (String value : headers.keySet()) {
                values.add(value);
            }
            return Collections.enumeration(values);
        }

        @Override
        public Enumeration<String> getHeaders(String name) {
            List<String> values = Collections.list(super.getHeaders(name));
            if (headers.containsKey(name)) {
                values = Arrays.asList(headers.get(name));
            }
            return Collections.enumeration(values);
        }

        public void addHead(String name, String value) {
            this.headers.put(name, value);
        }
    }

}

3.租户信息工具类

public class TenantInfoUtil {

    /** Holds the tenant identifier bound to the current thread. */
    private static final ThreadLocal<String> TENANT_HOLDER = new ThreadLocal<>();

    /** Static utility; not meant to be instantiated. */
    private TenantInfoUtil() {
    }

    /**
     * Returns the tenant information bound to the current thread.
     *
     * @return the stored value, or {@code null} when nothing is set for this thread
     */
    public static String get() {
        return TENANT_HOLDER.get();
    }

    /**
     * Binds tenant information to the current thread.
     *
     * @param tenantInfo tenant identifier to store
     */
    public static void set(String tenantInfo) {
        TENANT_HOLDER.set(tenantInfo);
    }

    /**
     * Removes the value bound to the current thread so it does not outlive
     * the work that set it.
     */
    public static void remove() {
        TENANT_HOLDER.remove();
    }
}

4.多租户插件

/**
 * MyBatis plugin that rewrites SQL before statement preparation so that it is
 * restricted to the current tenant.
 *
 * <p>The tenant id is read from {@link TenantInfoUtil}; when it is blank the
 * statement is executed unchanged. Configuration properties:
 * <ul>
 *   <li>{@code dialect} - dialect passed to the SQL parser (required)</li>
 *   <li>{@code tenantIdField} - name of the tenant column (required)</li>
 *   <li>{@code tableNames} - comma-separated tables that carry the tenant column (required)</li>
 * </ul>
 */
@Slf4j
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class MultiTenantPlugin implements Interceptor {

    /**
     * Dialect of the current database.
     */
    private String dialect;
    /**
     * Name of the tenant column.
     */
    private String tenantIdField;

    /**
     * Names of the tables that must be filtered by the tenant column.
     */
    private Set<String> tableSet;
    /**
     * Helper that injects the tenant condition into parsed statements.
     */
    private SqlConditionUtil sqlConditionUtil;


    /**
     * Replaces the bound SQL with a tenant-restricted version and proceeds.
     *
     * @param invocation intercepted {@code StatementHandler.prepare} call
     * @return the result of the intercepted call
     * @throws Throwable propagated from parsing or from the intercepted call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        String tenantId = TenantInfoUtil.get();
        // Parameterized DEBUG logging: this runs for every statement.
        log.debug("---tenantId:------->{}", tenantId);
        // Nothing to do without a tenant id.
        if (StringUtils.isBlank(tenantId)) {
            return invocation.proceed();
        }
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = statementHandler.getBoundSql();
        String originalSql = boundSql.getSql();
        log.debug("------boundSql:-----{}", originalSql);
        // Build the tenant-restricted SQL.
        String newSql = addTenantCondition(originalSql, tenantId);
        log.debug("--------newSql:---{}", newSql);
        // Write the new SQL back into the BoundSql.
        // NOTE(review): the "delegate." path presumably assumes the target is the
        // routing handler itself - verify when other plugins also wrap StatementHandler.
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        metaObject.setValue("delegate.boundSql.sql", newSql);
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Reads and validates the plugin configuration.
     *
     * @param properties plugin properties ({@code dialect}, {@code tenantIdField}, {@code tableNames})
     * @throws IllegalArgumentException when a required property is missing or blank
     */
    @Override
    public void setProperties(Properties properties) {
        dialect = properties.getProperty("dialect");
        if (StringUtils.isBlank(dialect)) {
            throw new IllegalArgumentException("MultiTenantPlugin need dialect property value");
        }
        tenantIdField = properties.getProperty("tenantIdField");
        if (StringUtils.isBlank(tenantIdField)) {
            throw new IllegalArgumentException("MultiTenantPlugin need tenantIdField property value");
        }
        String tableNames = properties.getProperty("tableNames");
        Set<String> tables = new HashSet<String>();
        if (!StringUtils.isBlank(tableNames)) {
            for (String name : StringUtils.split(tableNames, ",")) {
                // Trim so that "a, b" matches table "b" and blank entries are dropped.
                if (!StringUtils.isBlank(name)) {
                    tables.add(name.trim());
                }
            }
        }
        if (tables.isEmpty()) {
            throw new IllegalArgumentException("MultiTenantPlugin tableNames must have");
        }
        tableSet = tables;
        // Decides per table whether the tenant condition applies.
        ITableFieldConditionDecision conditionDecision = new ITableFieldConditionDecision() {
            @Override
            public boolean isAllowNullValue() {
                return false;
            }

            @Override
            public boolean adjudge(String tableName, String fieldName) {
                return tableSet != null && tableSet.contains(tableName);
            }
        };
        sqlConditionUtil = new SqlConditionUtil(conditionDecision);
    }


    /**
     * Adds the tenant filter to every statement contained in the SQL text.
     *
     * @param sql      SQL to restrict
     * @param tenantId current tenant id
     * @return the rewritten SQL, or the input when it is blank or yields no statements
     */
    private String addTenantCondition(String sql, String tenantId) {
        if (StringUtils.isBlank(sql) || StringUtils.isBlank(tenantIdField)) {
            return sql;
        }
        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, dialect);
        if (statementList == null || statementList.isEmpty()) {
            return sql;
        }

        // All statements are serialized back, so all of them must be filtered -
        // not only the first one.
        for (SQLStatement sqlStatement : statementList) {
            sqlConditionUtil.addStatementCondition(sqlStatement, tenantIdField, tenantId);
        }
        return SQLUtils.toSQLString(statementList, dialect);
    }
}

5.sql语句的拼接工具类

/**
 * Adds an equality condition (for example a tenant filter) to parsed SQL
 * statements by editing the Druid SQL AST in place.
 *
 * <p>The supplied {@link ITableFieldConditionDecision} decides per table
 * whether the condition is applied.
 */
public class SqlConditionUtil {

    private final ITableFieldConditionDecision conditionDecision;

    public SqlConditionUtil(ITableFieldConditionDecision conditionDecision) {
        this.conditionDecision = conditionDecision;
    }

    /**
     * Adds the condition {@code fieldName = fieldValue} to a statement.
     * SELECT, UPDATE, DELETE and INSERT are handled; other statement types are left untouched.
     *
     * @param sqlStatement parsed statement, modified in place
     * @param fieldName    column to filter on
     * @param fieldValue   value the column must equal
     */
    public void addStatementCondition(SQLStatement sqlStatement, String fieldName, String fieldValue) {
        if (sqlStatement instanceof SQLSelectStatement) {
            addSelectCondition(((SQLSelectStatement) sqlStatement).getSelect(), fieldName, fieldValue);
        } else if (sqlStatement instanceof SQLUpdateStatement) {
            addUpdateStatementCondition((SQLUpdateStatement) sqlStatement, fieldName, fieldValue);
        } else if (sqlStatement instanceof SQLDeleteStatement) {
            addDeleteStatementCondition((SQLDeleteStatement) sqlStatement, fieldName, fieldValue);
        } else if (sqlStatement instanceof SQLInsertStatement) {
            addInsertStatementCondition((SQLInsertStatement) sqlStatement, fieldName, fieldValue);
        }
    }

    /**
     * Dispatches on the query kind of a SELECT: a plain query block or a UNION.
     * Without this check a UNION would fail the cast to {@code SQLSelectQueryBlock}.
     */
    private void addSelectCondition(SQLSelect select, String fieldName, String fieldValue) {
        if (select == null) {
            return;
        }
        if (select.getQuery() instanceof SQLSelectQueryBlock) {
            addQueryBlockCondition((SQLSelectQueryBlock) select.getQuery(), fieldName, fieldValue);
        } else if (select.getQuery() instanceof SQLUnionQuery) {
            addSelectStatementConditionUnion((SQLUnionQuery) select.getQuery(), fieldName, fieldValue);
        }
    }

    /**
     * Filters one query block: first the subqueries in its WHERE clause, then its FROM tables.
     */
    private void addQueryBlockCondition(SQLSelectQueryBlock queryBlock, String fieldName, String fieldValue) {
        if (queryBlock == null) {
            return;
        }
        // Done once per block, before the FROM walk rewrites the WHERE clause.
        addSQLExprCondition(queryBlock.getWhere(), fieldName, fieldValue);
        addSelectStatementCondition(queryBlock, queryBlock.getFrom(), fieldName, fieldValue);
    }

    /**
     * Handles INSERT: filters the source query of INSERT ... SELECT, otherwise
     * appends the column and its value when the column is not listed yet.
     *
     * @param insertStatement statement to modify
     * @param fieldName       column name
     * @param fieldValue      column value
     */
    private void addInsertStatementCondition(SQLInsertStatement insertStatement, String fieldName, String fieldValue) {
        if (insertStatement == null) {
            return;
        }
        SQLSelect sourceQuery = insertStatement.getQuery();
        if (sourceQuery != null) {
            addSelectCondition(sourceQuery, fieldName, fieldValue);
            return;
        }
        // Only touch tables the decision selects; other tables may not have the column.
        String tableName = insertStatement.getTableName() == null
                ? null : insertStatement.getTableName().getSimpleName();
        if (!shouldApply(tableName, fieldName, fieldValue)) {
            return;
        }
        boolean alreadyListed = insertStatement.getColumns().stream()
                .anyMatch(column -> fieldName.equalsIgnoreCase(column.toString()));
        if (alreadyListed) {
            return;
        }
        // NOTE(review): an INSERT without an explicit column list is not treated specially here.
        insertStatement.getColumns().add(new SQLIdentifierExpr(fieldName));
        // Every VALUES row needs the extra value, otherwise multi-row inserts
        // end up with mismatched column/value counts.
        insertStatement.getValuesList().forEach(row -> row.addValue(new SQLCharExpr(fieldValue)));
    }


    /**
     * Adds the condition to a DELETE statement and to subqueries in its WHERE clause.
     *
     * @param deleteStatement statement to modify
     * @param fieldName       column name
     * @param fieldValue      column value
     */
    private void addDeleteStatementCondition(SQLDeleteStatement deleteStatement, String fieldName, String fieldValue) {
        SQLExpr where = deleteStatement.getWhere();
        // Filter subqueries inside the WHERE clause first.
        addSQLExprCondition(where, fieldName, fieldValue);

        SQLExpr newCondition = newEqualityCondition(deleteStatement.getTableName().getSimpleName(),
                deleteStatement.getTableSource().getAlias(), fieldName, fieldValue, where);
        deleteStatement.setWhere(newCondition);
    }

    /**
     * Walks a WHERE expression and filters the subqueries found in it
     * ({@code IN (subquery)}, scalar subqueries, and both sides of binary operators).
     *
     * @param where      expression to walk; may be null
     * @param fieldName  column name
     * @param fieldValue column value
     */
    private void addSQLExprCondition(SQLExpr where, String fieldName, String fieldValue) {
        if (where instanceof SQLInSubQueryExpr) {
            addSelectCondition(((SQLInSubQueryExpr) where).getSubQuery(), fieldName, fieldValue);
        } else if (where instanceof SQLBinaryOpExpr) {
            SQLBinaryOpExpr opExpr = (SQLBinaryOpExpr) where;
            addSQLExprCondition(opExpr.getLeft(), fieldName, fieldValue);
            addSQLExprCondition(opExpr.getRight(), fieldName, fieldValue);
        } else if (where instanceof SQLQueryExpr) {
            addSelectCondition(((SQLQueryExpr) where).getSubQuery(), fieldName, fieldValue);
        }
    }

    /**
     * Adds the condition to an UPDATE statement and to subqueries in its WHERE clause.
     *
     * @param updateStatement statement to modify
     * @param fieldName       column name
     * @param fieldValue      column value
     */
    private void addUpdateStatementCondition(SQLUpdateStatement updateStatement, String fieldName, String fieldValue) {
        SQLExpr where = updateStatement.getWhere();
        // Filter subqueries inside the WHERE clause first.
        addSQLExprCondition(where, fieldName, fieldValue);
        SQLExpr newCondition = newEqualityCondition(updateStatement.getTableName().getSimpleName(),
                updateStatement.getTableSource().getAlias(), fieldName, fieldValue, where);
        updateStatement.setWhere(newCondition);
    }

    /**
     * Adds the condition for one table source of a query block.
     * Joins are walked on both sides; subqueries and unions are filtered internally.
     *
     * @param queryObject query block whose WHERE clause receives conditions for plain tables
     * @param from        table source to process
     * @param fieldName   column name
     * @param fieldValue  column value
     * @throws ServiceException when the table source kind is not supported
     */
    private void addSelectStatementCondition(SQLSelectQueryBlock queryObject, SQLTableSource from, String fieldName, String fieldValue) {
        if (StringUtils.isBlank(fieldName) || from == null || queryObject == null) {
            return;
        }

        if (from instanceof SQLExprTableSource) {
            SQLExpr tableExpr = ((SQLExprTableSource) from).getExpr();
            // Accept any name expression so schema-qualified tables resolve to their
            // simple name; fail closed on anything else rather than skip the filter.
            if (!(tableExpr instanceof com.alibaba.druid.sql.ast.SQLName)) {
                throw new ServiceException("sql增强异常");
            }
            String tableName = ((com.alibaba.druid.sql.ast.SQLName) tableExpr).getSimpleName();
            SQLExpr newCondition = newEqualityCondition(tableName, from.getAlias(), fieldName, fieldValue, queryObject.getWhere());
            queryObject.setWhere(newCondition);
        } else if (from instanceof SQLJoinTableSource) {
            SQLJoinTableSource joinObject = (SQLJoinTableSource) from;
            addSelectStatementCondition(queryObject, joinObject.getLeft(), fieldName, fieldValue);
            addSelectStatementCondition(queryObject, joinObject.getRight(), fieldName, fieldValue);
        } else if (from instanceof SQLSubqueryTableSource) {
            addSelectCondition(((SQLSubqueryTableSource) from).getSelect(), fieldName, fieldValue);
        } else if (from instanceof SQLUnionQueryTableSource) {
            addSelectStatementConditionUnion(((SQLUnionQueryTableSource) from).getUnion(), fieldName, fieldValue);
        } else {
            throw new ServiceException("sql增强异常");
        }
    }

    /**
     * Filters every member of a UNION, descending into nested unions on either side.
     *
     * @param sqlUnionQuery union to process; may be null
     * @param fieldName     column name
     * @param fieldValue    column value
     */
    private void addSelectStatementConditionUnion(SQLUnionQuery sqlUnionQuery, String fieldName, String fieldValue) {
        if (sqlUnionQuery == null) {
            return;
        }
        if (sqlUnionQuery.getLeft() instanceof SQLUnionQuery) {
            addSelectStatementConditionUnion((SQLUnionQuery) sqlUnionQuery.getLeft(), fieldName, fieldValue);
        } else if (sqlUnionQuery.getLeft() instanceof SQLSelectQueryBlock) {
            addQueryBlockCondition((SQLSelectQueryBlock) sqlUnionQuery.getLeft(), fieldName, fieldValue);
        }
        if (sqlUnionQuery.getRight() instanceof SQLUnionQuery) {
            addSelectStatementConditionUnion((SQLUnionQuery) sqlUnionQuery.getRight(), fieldName, fieldValue);
        } else if (sqlUnionQuery.getRight() instanceof SQLSelectQueryBlock) {
            addQueryBlockCondition((SQLSelectQueryBlock) sqlUnionQuery.getRight(), fieldName, fieldValue);
        }
    }

    /**
     * Tells whether the condition applies: the decision must select the table,
     * and a null value is only accepted when the decision allows it.
     */
    private boolean shouldApply(String tableName, String fieldName, String fieldValue) {
        if (!conditionDecision.adjudge(tableName, fieldName)) {
            return false;
        }
        return fieldValue != null || conditionDecision.isAllowNullValue();
    }

    /**
     * Builds a new condition by AND-ing {@code field = value} with the existing one.
     *
     * @param tableName       table name
     * @param tableAlias      table alias; used to qualify the column when present
     * @param fieldName       column name
     * @param fieldValue      column value
     * @param originCondition existing condition; may be null
     * @return the combined condition, or {@code originCondition} when the condition does not apply
     */
    private SQLExpr newEqualityCondition(String tableName, String tableAlias, String fieldName, String fieldValue, SQLExpr originCondition) {
        if (!shouldApply(tableName, fieldName, fieldValue)) {
            return originCondition;
        }

        String qualifiedField = StringUtils.isBlank(tableAlias) ? fieldName : tableAlias + "." + fieldName;
        SQLExpr condition = new SQLBinaryOpExpr(new SQLIdentifierExpr(qualifiedField), new SQLCharExpr(fieldValue), SQLBinaryOperator.Equality);
        return SQLUtils.buildCondition(SQLBinaryOperator.BooleanAnd, condition, false, originCondition);
    }


    /**
     * Manual demo: parses a sample statement, adds {@code tenant_id = 'yay'}
     * and prints the SQL before and after.
     */
    public static void main(String[] args) {
//        String sql = "select * from user s  ";
//        String sql = "select * from user s where s.name='333'";
//        String sql = "select * from (select * from tab t where id = 2 and name = 'wenshao') s where s.name='333'";
//        String sql="select u.*,g.name from user u join user_group g on u.groupId=g.groupId where u.name='123'";

//        String sql = "update user set name=? where id =(select id from user s)";
//        String sql = "delete from user where id = ( select id from user s )";

//        String sql = "INSERT INTO deleted_organization  ( id,origin_id,org_name,org_type,org_level,org_order,super_org_id,org_number,email,company_phone,is_structure,is_deleted,tag,data_source_flag,create_user_name,update_user_name,create_time,update_time,delete_time, tenant_id) VALUES( 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,8) ";
        String sql = "SELECT count(0) FROM ((SELECT DISTINCT data_source_flag FROM t_account) UNION (SELECT DISTINCT data_source_flag FROM t_asset) UNION (SELECT DISTINCT data_source_flag FROM t_business_system) UNION (SELECT DISTINCT data_source_flag FROM t_component) UNION (SELECT DISTINCT data_source_flag FROM t_component_account) UNION (SELECT DISTINCT data_source_flag FROM t_employee) UNION (SELECT DISTINCT data_source_flag FROM t_intranet_ip_network_segment) UNION (SELECT DISTINCT data_source_flag FROM t_ip) UNION (SELECT DISTINCT data_source_flag FROM t_leak) UNION (SELECT DISTINCT data_source_flag FROM t_organization) UNION (SELECT DISTINCT data_source_flag FROM t_security_domain) UNION (SELECT DISTINCT data_source_flag FROM t_virus)) AS c LEFT JOIN d_enum e ON c.data_source_flag = e.code WHERE c.data_source_flag IS NOT NULL ";
//        String sql = "SELECT DISTINCT data_source_flag FROM t_account UNION SELECT DISTINCT data_source_flag FROM t_asset";
//        String sql = "SELECT DISTINCT data_source_flag FROM t_account";


//        String sql = "select u.*,g.name from user u join (select * from user_group g  join user_role r on g.role_code=r.code  ) g on u.groupId=g.groupId where u.name='123'";
        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, JdbcConstants.POSTGRESQL);
        SQLStatement sqlStatement = statementList.get(0);
        // Decision that applies the condition to every table.
        SqlConditionUtil helper = new SqlConditionUtil(new ITableFieldConditionDecision() {
            @Override
            public boolean adjudge(String tableName, String fieldName) {
                return true;
            }

            @Override
            public boolean isAllowNullValue() {
                return false;
            }
        });
        // Add the tenant condition: tenant_id is the column, yay is the value.
        helper.addStatementCondition(sqlStatement, "tenant_id", "yay");
        System.out.println("源sql:" + sql);
        System.out.println("修改后sql:" + SQLUtils.toSQLString(statementList, JdbcConstants.POSTGRESQL));
    }


}

  • 3
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

焱童鞋

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值