MyBatis-Plus - 自定义租户拦截器,一定要这样吗?!

今天看到如何重写“自定义租户拦截器”,因为用过的人都知道,官方提供的自定义入口只是针对表名进行拦截,所以无法扩展更多自定义做的事情。

/**
 * @author Lux Sun
 * @date 2023/7/18
 */
@Component
public class MyTenantHandler implements TenantHandler {

	@Override
	public Expression getTenantId(boolean select) {
		TenantInfo tenantInfo = TenantContext.get();
		if (tenantInfo == null) {
			return null;
		}
		String tenantId = tenantInfo.getId();
		return new StringValue(tenantId);
	}

	@Override
	public String getTenantIdColumn() {
		return "tenant_id";
	}

	@Override
	public boolean doTableFilter(String tableName) {
		if (StrUtil.equalsAny(tableName, "t_product")) {
			return true;
		}
		return false;
	}
}

于是,网上居然有人重写 Mybatis-Plus 租户拦截器,代码如下(需要重写的类)

/*
 * Copyright (c) 2011-2020, baomidou (jobob@qq.com).
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 * <p>
 * https://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package com.baomidou.mybatisplus.extension.plugins.tenant;

import com.baomidou.mybatisplus.core.parser.AbstractJsqlParser;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.experimental.Accessors;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.ValueListExpression;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;

import java.util.List;

/**
 * 租户 SQL 解析器( TenantId 行级 )
 *
 * @author hubin
 * @since 2017-09-01
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@Accessors(chain = true)
@EqualsAndHashCode(callSuper = true)
public class TenantSqlParser extends AbstractJsqlParser {

    private TenantHandler tenantHandler;

    /**
     * select 语句处理
     */
    @Override
    public void processSelectBody(SelectBody selectBody) {
        if (selectBody instanceof PlainSelect) {
            processPlainSelect((PlainSelect) selectBody);
        } else if (selectBody instanceof WithItem) {
            WithItem withItem = (WithItem) selectBody;
            if (withItem.getSelectBody() != null) {
                processSelectBody(withItem.getSelectBody());
            }
        } else {
            SetOperationList operationList = (SetOperationList) selectBody;
            if (operationList.getSelects() != null && operationList.getSelects().size() > 0) {
                operationList.getSelects().forEach(this::processSelectBody);
            }
        }
    }

    /**
     * insert 语句处理
     */
    @Override
    public void processInsert(Insert insert) {
        if (tenantHandler.doTableFilter(insert.getTable().getName())) {
            // 过滤退出执行
            return;
        }
        insert.getColumns().add(new Column(tenantHandler.getTenantIdColumn()));
        if (insert.getSelect() != null) {
            processPlainSelect((PlainSelect) insert.getSelect().getSelectBody(), true);
        } else if (insert.getItemsList() != null) {
            // fixed github pull/295
            ItemsList itemsList = insert.getItemsList();
            if (itemsList instanceof MultiExpressionList) {
                ((MultiExpressionList) itemsList).getExprList().forEach(el -> el.getExpressions().add(tenantHandler.getTenantId(false)));
            } else {
                ((ExpressionList) insert.getItemsList()).getExpressions().add(tenantHandler.getTenantId(false));
            }
        } else {
            throw ExceptionUtils.mpe("Failed to process multiple-table update, please exclude the tableName or statementId");
        }
    }

    /**
     * update 语句处理
     */
    @Override
    public void processUpdate(Update update) {
        final Table table = update.getTable();
        if (tenantHandler.doTableFilter(table.getName())) {
            // 过滤退出执行
            return;
        }
        update.setWhere(this.andExpression(table, update.getWhere()));
    }

    /**
     * delete 语句处理
     */
    @Override
    public void processDelete(Delete delete) {
        if (tenantHandler.doTableFilter(delete.getTable().getName())) {
            // 过滤退出执行
            return;
        }
        delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere()));
    }

    /**
     * delete update 语句 where 处理
     */
    protected BinaryExpression andExpression(Table table, Expression where) {
        //获得where条件表达式
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(this.getAliasColumn(table));
        equalsTo.setRightExpression(tenantHandler.getTenantId(false));
        if (null != where) {
            if (where instanceof OrExpression) {
                return new AndExpression(equalsTo, new Parenthesis(where));
            } else {
                return new AndExpression(equalsTo, where);
            }
        }
        return equalsTo;
    }

    /**
     * 处理 PlainSelect
     */
    protected void processPlainSelect(PlainSelect plainSelect) {
        processPlainSelect(plainSelect, false);
    }

    /**
     * 处理 PlainSelect
     *
     * @param plainSelect ignore
     * @param addColumn   是否添加租户列,insert into select语句中需要
     */
    protected void processPlainSelect(PlainSelect plainSelect, boolean addColumn) {
        FromItem fromItem = plainSelect.getFromItem();
        if (fromItem instanceof Table) {
            Table fromTable = (Table) fromItem;
            if (!tenantHandler.doTableFilter(fromTable.getName())) {
                //#1186 github
                plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
                if (addColumn) {
                    plainSelect.getSelectItems().add(new SelectExpressionItem(new Column(tenantHandler.getTenantIdColumn())));
                }
            }
        } else {
            processFromItem(fromItem);
        }
        List<Join> joins = plainSelect.getJoins();
        if (joins != null && joins.size() > 0) {
            joins.forEach(j -> {
                processJoin(j);
                processFromItem(j.getRightItem());
            });
        }
    }

    /**
     * 处理子查询等
     */
    protected void processFromItem(FromItem fromItem) {
        if (fromItem instanceof SubJoin) {
            SubJoin subJoin = (SubJoin) fromItem;
            if (subJoin.getJoinList() != null) {
                subJoin.getJoinList().forEach(this::processJoin);
            }
            if (subJoin.getLeft() != null) {
                processFromItem(subJoin.getLeft());
            }
        } else if (fromItem instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) fromItem;
            if (subSelect.getSelectBody() != null) {
                processSelectBody(subSelect.getSelectBody());
            }
        } else if (fromItem instanceof ValuesList) {
            logger.debug("Perform a subquery, if you do not give us feedback");
        } else if (fromItem instanceof LateralSubSelect) {
            LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
            if (lateralSubSelect.getSubSelect() != null) {
                SubSelect subSelect = lateralSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelectBody(subSelect.getSelectBody());
                }
            }
        }
    }

    /**
     * 处理联接语句
     */
    protected void processJoin(Join join) {
        if (join.getRightItem() instanceof Table) {
            Table fromTable = (Table) join.getRightItem();
            if (this.tenantHandler.doTableFilter(fromTable.getName())) {
                // 过滤退出执行
                return;
            }
            join.setOnExpression(builderExpression(join.getOnExpression(), fromTable));
        }
    }

    /**
     * 处理条件:
     * 支持 getTenantHandler().getTenantId()是一个完整的表达式:tenant in (1,2)
     * 默认tenantId的表达式: LongValue(1)这种依旧支持
     */
    protected Expression builderExpression(Expression currentExpression, Table table) {
        final Expression tenantExpression = tenantHandler.getTenantId(true);
        Expression appendExpression = this.processTableAlias4CustomizedTenantIdExpression(tenantExpression, table);
        if (currentExpression == null) {
            return appendExpression;
        }
        if (currentExpression instanceof BinaryExpression) {
            BinaryExpression binaryExpression = (BinaryExpression) currentExpression;
            doExpression(binaryExpression.getLeftExpression());
            doExpression(binaryExpression.getRightExpression());
        } else if (currentExpression instanceof InExpression) {
            InExpression inExp = (InExpression) currentExpression;
            ItemsList rightItems = inExp.getRightItemsList();
            if (rightItems instanceof SubSelect) {
                processSelectBody(((SubSelect) rightItems).getSelectBody());
            }
        }
        if (currentExpression instanceof OrExpression) {
            return new AndExpression(new Parenthesis(currentExpression), appendExpression);
        } else {
            return new AndExpression(currentExpression, appendExpression);
        }
    }

    protected void doExpression(Expression expression) {
        if (expression instanceof FromItem) {
            processFromItem((FromItem) expression);
        } else if (expression instanceof InExpression) {
            InExpression inExp = (InExpression) expression;
            ItemsList rightItems = inExp.getRightItemsList();
            if (rightItems instanceof SubSelect) {
                processSelectBody(((SubSelect) rightItems).getSelectBody());
            }
        }
    }

    /**
     * 目前: 针对自定义的tenantId的条件表达式[tenant_id in (1,2,3)],无法处理多租户的字段加上表别名
     * select a.id, b.name
     * from a
     * join b on b.aid = a.id and [b.]tenant_id in (1,2) --别名[b.]无法加上 TODO
     *
     * @param expression
     * @param table
     * @return 加上别名的多租户字段表达式
     */
    protected Expression processTableAlias4CustomizedTenantIdExpression(Expression expression, Table table) {
        Expression target;
        if (expression instanceof ValueListExpression) {
            InExpression inExpression = new InExpression();
            inExpression.setLeftExpression(this.getAliasColumn(table));
            inExpression.setRightItemsList(((ValueListExpression) expression).getExpressionList());
            target = inExpression;
        } else {
            EqualsTo equalsTo = new EqualsTo();
            equalsTo.setLeftExpression(this.getAliasColumn(table));
            equalsTo.setRightExpression(expression);
            target = equalsTo;
        }
        return target;
    }

    /**
     * 租户字段别名设置
     * <p>tenantId 或 tableAlias.tenantId</p>
     *
     * @param table 表对象
     * @return 字段
     */
    protected Column getAliasColumn(Table table) {
        StringBuilder column = new StringBuilder();
        if (table.getAlias() != null) {
            column.append(table.getAlias().getName()).append(StringPool.DOT);
        }
        column.append(tenantHandler.getTenantIdColumn());
        return new Column(column.toString());
    }
}

发现一顿操作猛如虎下来还各种编译报错“withItem.getSelectBody()”……于是就放弃了。

后来回过头去看之前那个官方提供的自定义类发现,其实最终还是靠 true or false 来控制是否使用租户拦截器功能,一开始被以为只能填写表名给蒙蔽了~

/**
 * @author Lux Sun
 * @date 2023/7/18
 */
@Component
public class MyTenantHandler implements TenantHandler {

	@Override
	public Expression getTenantId(boolean select) {
		TenantInfo tenantInfo = TenantContext.get();
		if (tenantInfo == null) {
			return null;
		}
		String tenantId = tenantInfo.getId();
		return new StringValue(tenantId);
	}

	@Override
	public String getTenantIdColumn() {
		return "tenant_id";
	}

	@Override
	public boolean doTableFilter(String tableName) {
		Boolean filter = SqlParserContext.get();
		if (filter) {
			return true;
		}

		if (StrUtil.equalsAny(tableName, "t_product")) {
			return true;
		}
		return false;
	}
}
/**
 * @author Lux Sun
 * @date 2023/7/18
 */
public class SqlParserContext {

	private static final ThreadLocal<Boolean> CONTEXT = new ThreadLocal<>();

	public static void set(Boolean filter) {
		CONTEXT.set(filter);
	}

	public static Boolean get() {
		if (CONTEXT.get() == null) {
			set(false);
		}
		return CONTEXT.get();
	}

	public static void clear() {
		CONTEXT.remove();
	}
}

我给它自定义了上下文类,这样可以通过每个线程进行对本次请求进行操作,只要设置为 true 就可以跳过拦截功能,否则启动拦截功能,除非表名能命中,那么还是会进行跳过拦截操作,反之。

看一个案例,在需要用到的地方使用上下文设置为 true 即可

/**
 * 获取财务记录
 * @param id
 * @return
 */
@GetMapping("/finance/{id}")
public ResultVO<FinancePO> getFinance(@PathVariable String id) {
	SqlParserContext.set(true);
	FinancePO fice = ficeService.getById(id);
	return ResultVoUtil.buildSuccess(fice);
}

最后,别忘了还需要注册 MybatisConfig 类,注册即可使用啦~

/**
 * @author Lux Sun
 * @date 2023/7/18
 */
@EnableTransactionManagement
@Configuration
public class MyBatisConfig {

	@Resource
	private MyTenantHandler myTenantHandler;

	@Bean
	public PaginationInterceptor paginationInterceptor() {
		PaginationInterceptor paginationInterceptor = new PaginationInterceptor();
		paginationInterceptor.setSqlParserList(Collections.singletonList(new TenantSqlParser().setTenantHandler(myTenantHandler)));
		return paginationInterceptor;
	}
}
Mybatis-Plus 是一个基于 Mybatis 的增强工具,它简化了 Mybatis 的开发流程,提供了许多实用的功能,如自动生成代码、分页查询、条件构造器、性能分析等。在 2021 年的 Mybatis-Plus 面试中,可能会涉及到以下问题: 1. Mybatis-Plus 的优点是什么? Mybatis-Plus 的优点包括:简化开发流程、提高开发效率、提供实用的功能、易于集成、支持多种数据库、性能优秀等。 2. Mybatis-Plus 的核心功能是什么? Mybatis-Plus 的核心功能包括:自动生成代码、分页查询、条件构造器、性能分析、多租户支持、逻辑删除等。 3. Mybatis-Plus 的代码生成器是什么?有什么作用? Mybatis-Plus 的代码生成器是一个可视化工具,可以根据数据库表结构自动生成实体类、Mapper 接口、Mapper XML 文件等代码,大大提高了开发效率。 4. Mybatis-Plus 的分页查询是如何实现的? Mybatis-Plus 的分页查询是通过 PageHelper 类实现的,它可以自动拦截 SQL 语句,将查询结果封装成 Page 对象,提供了丰富的分页查询方法。 5. Mybatis-Plus 的条件构造器是什么?有什么作用? Mybatis-Plus 的条件构造器是一个灵活的查询条件构造工具,可以根据不同的查询需求,动态生成 SQL 语句,支持多种查询条件,如等于、不等于、大于、小于、模糊查询等。 6. Mybatis-Plus 的性能分析是如何实现的? Mybatis-Plus 的性能分析是通过 PerformanceInterceptor 类实现的,它可以拦截 SQL 语句,统计 SQL 执行时间、执行次数、执行的 SQL 语句等信息,帮助开发者优化 SQL 语句的性能。 7. Mybatis-Plus 的多租户支持是什么?有什么作用? Mybatis-Plus 的多租户支持是指可以根据不同的租户,动态切换数据源,实现数据隔离的功能。它可以帮助开发者在多租户场景下,简化数据访问的流程,提高开发效率。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

陆氪和他的那些代码

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值