今天研究了如何重写 MyBatis-Plus 的“自定义租户拦截器”。用过的人都知道,官方提供的自定义入口(TenantHandler#doTableFilter)看起来只能按表名决定是否拦截,难以扩展更多自定义逻辑。常规写法如下:
/**
 * Tenant handler that injects the current tenant id into SQL and excludes the
 * {@code t_product} table from tenant processing.
 *
 * @author Lux Sun
 * @date 2023/7/18
 */
@Component
public class MyTenantHandler implements TenantHandler {

    /**
     * Supplies the current tenant id as a SQL string literal.
     *
     * @param select not used; the same value is returned for every statement type
     * @return the tenant id expression, or {@code null} when {@code TenantContext.get()} is null
     */
    @Override
    public Expression getTenantId(boolean select) {
        final TenantInfo current = TenantContext.get();
        return current == null ? null : new StringValue(current.getId());
    }

    /**
     * @return the name of the tenant column
     */
    @Override
    public String getTenantIdColumn() {
        return "tenant_id";
    }

    /**
     * Decides whether tenant processing is skipped for a table.
     *
     * @param tableName the table being processed
     * @return {@code true} (skip) only for {@code t_product}
     */
    @Override
    public boolean doTableFilter(String tableName) {
        return StrUtil.equalsAny(tableName, "t_product");
    }
}
于是,我在网上还真找到了有人重写 MyBatis-Plus 租户拦截器的方案,需要重写的类代码如下:
/*
* Copyright (c) 2011-2020, baomidou (jobob@qq.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
* <p>
* https://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.baomidou.mybatisplus.extension.plugins.tenant;
import com.baomidou.mybatisplus.core.parser.AbstractJsqlParser;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.experimental.Accessors;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.ValueListExpression;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import java.util.List;
/**
* Tenant SQL parser (row-level isolation by TenantId).
* <p>
* Rewrites JSqlParser SELECT / INSERT / UPDATE / DELETE statements so that every table
* not excluded by {@link TenantHandler#doTableFilter(String)} is restricted to, or stamped
* with, the value returned by {@link TenantHandler#getTenantId(boolean)} in the column
* named by {@link TenantHandler#getTenantIdColumn()}.
* <p>
* NOTE(review): this is a copy of the MyBatis-Plus class of the same name; the article
* reports compile errors with it (e.g. at {@code withItem.getSelectBody()}) - presumably a
* JSqlParser / MyBatis-Plus version mismatch, confirm before reuse.
*
* @author hubin
* @since 2017-09-01
*/
@Data
@NoArgsConstructor
@AllArgsConstructor
@Accessors(chain = true)
@EqualsAndHashCode(callSuper = true)
public class TenantSqlParser extends AbstractJsqlParser {
// Supplies the tenant column name, the tenant value and the table exclusion rule.
private TenantHandler tenantHandler;
/**
* Handles a SELECT body.
* <p>
* A PlainSelect is rewritten directly; a WithItem and every member of a SetOperationList
* (e.g. UNION) are processed recursively.
* NOTE(review): any other SelectBody subtype falls into the else branch, where the cast to
* SetOperationList would throw ClassCastException.
*/
@Override
public void processSelectBody(SelectBody selectBody) {
if (selectBody instanceof PlainSelect) {
processPlainSelect((PlainSelect) selectBody);
} else if (selectBody instanceof WithItem) {
WithItem withItem = (WithItem) selectBody;
if (withItem.getSelectBody() != null) {
processSelectBody(withItem.getSelectBody());
}
} else {
SetOperationList operationList = (SetOperationList) selectBody;
if (operationList.getSelects() != null && operationList.getSelects().size() > 0) {
operationList.getSelects().forEach(this::processSelectBody);
}
}
}
/**
* Handles an INSERT statement.
* <p>
* Appends the tenant column to the column list, then either hands an INSERT ... SELECT to
* processPlainSelect with addColumn = true, or appends {@code getTenantId(false)} to each
* VALUES row.
* NOTE(review): assumes {@code insert.getColumns()} is non-null - confirm for INSERTs
* without an explicit column list. An INSERT ... SELECT whose body is not a PlainSelect
* (e.g. a UNION) would throw ClassCastException at the cast below. The exception message
* mentions "update" although this is the insert path.
*/
@Override
public void processInsert(Insert insert) {
if (tenantHandler.doTableFilter(insert.getTable().getName())) {
// Table is excluded: leave the statement untouched.
return;
}
insert.getColumns().add(new Column(tenantHandler.getTenantIdColumn()));
if (insert.getSelect() != null) {
processPlainSelect((PlainSelect) insert.getSelect().getSelectBody(), true);
} else if (insert.getItemsList() != null) {
// fixed github pull/295
ItemsList itemsList = insert.getItemsList();
if (itemsList instanceof MultiExpressionList) {
// Multi-row VALUES: append the tenant value to every row.
((MultiExpressionList) itemsList).getExprList().forEach(el -> el.getExpressions().add(tenantHandler.getTenantId(false)));
} else {
((ExpressionList) insert.getItemsList()).getExpressions().add(tenantHandler.getTenantId(false));
}
} else {
throw ExceptionUtils.mpe("Failed to process multiple-table update, please exclude the tableName or statementId");
}
}
/**
* Handles an UPDATE statement by ANDing the tenant condition into its WHERE clause.
*/
@Override
public void processUpdate(Update update) {
final Table table = update.getTable();
if (tenantHandler.doTableFilter(table.getName())) {
// Table is excluded: leave the statement untouched.
return;
}
update.setWhere(this.andExpression(table, update.getWhere()));
}
/**
* Handles a DELETE statement by ANDing the tenant condition into its WHERE clause.
*/
@Override
public void processDelete(Delete delete) {
if (tenantHandler.doTableFilter(delete.getTable().getName())) {
// Table is excluded: leave the statement untouched.
return;
}
delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere()));
}
/**
* Builds the WHERE clause for UPDATE / DELETE:
* {@code [alias.]tenant_column = tenantId AND <where>}.
* <p>
* An OR-rooted where is wrapped in parentheses so the tenant condition constrains the whole
* original clause instead of only its first OR branch.
*
* @param table target table (its alias, if any, qualifies the tenant column)
* @param where existing WHERE condition, may be {@code null}
* @return the tenant condition alone when {@code where} is null, otherwise the AND of both
*/
protected BinaryExpression andExpression(Table table, Expression where) {
// Build the tenant equality condition: tenant column = tenant value.
EqualsTo equalsTo = new EqualsTo();
equalsTo.setLeftExpression(this.getAliasColumn(table));
equalsTo.setRightExpression(tenantHandler.getTenantId(false));
if (null != where) {
if (where instanceof OrExpression) {
return new AndExpression(equalsTo, new Parenthesis(where));
} else {
return new AndExpression(equalsTo, where);
}
}
return equalsTo;
}
/**
* Handles a PlainSelect without appending the tenant column to the select list.
*/
protected void processPlainSelect(PlainSelect plainSelect) {
processPlainSelect(plainSelect, false);
}
/**
* Handles a PlainSelect.
* <p>
* A plain table in FROM gets the tenant condition in WHERE (unless excluded); any other FROM
* item is processed recursively. Joined tables get the condition in their ON clause, and
* each join's right item is processed recursively.
* NOTE(review): the select-list column appended when addColumn is true is not
* alias-qualified, unlike the WHERE condition.
*
* @param plainSelect the select to rewrite in place
* @param addColumn   whether to also append the tenant column to the select list; needed for
*                    INSERT INTO ... SELECT statements
*/
protected void processPlainSelect(PlainSelect plainSelect, boolean addColumn) {
FromItem fromItem = plainSelect.getFromItem();
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
if (!tenantHandler.doTableFilter(fromTable.getName())) {
//#1186 github
plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
if (addColumn) {
plainSelect.getSelectItems().add(new SelectExpressionItem(new Column(tenantHandler.getTenantIdColumn())));
}
}
} else {
processFromItem(fromItem);
}
List<Join> joins = plainSelect.getJoins();
if (joins != null && joins.size() > 0) {
joins.forEach(j -> {
processJoin(j);
processFromItem(j.getRightItem());
});
}
}
/**
* Recurses into non-table FROM items so nested selects are rewritten too: sub-joins (their
* join list and left item), sub-selects and lateral sub-selects. A ValuesList is only logged
* at debug level; other FromItem types, including plain tables, are ignored here.
*/
protected void processFromItem(FromItem fromItem) {
if (fromItem instanceof SubJoin) {
SubJoin subJoin = (SubJoin) fromItem;
if (subJoin.getJoinList() != null) {
subJoin.getJoinList().forEach(this::processJoin);
}
if (subJoin.getLeft() != null) {
processFromItem(subJoin.getLeft());
}
} else if (fromItem instanceof SubSelect) {
SubSelect subSelect = (SubSelect) fromItem;
if (subSelect.getSelectBody() != null) {
processSelectBody(subSelect.getSelectBody());
}
} else if (fromItem instanceof ValuesList) {
logger.debug("Perform a subquery, if you do not give us feedback");
} else if (fromItem instanceof LateralSubSelect) {
LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
if (lateralSubSelect.getSubSelect() != null) {
SubSelect subSelect = lateralSubSelect.getSubSelect();
if (subSelect.getSelectBody() != null) {
processSelectBody(subSelect.getSelectBody());
}
}
}
}
/**
* Handles a join: if the joined item is a table that is not excluded, ANDs the tenant
* condition into the join's ON clause (or uses it as the ON clause when there is none).
*/
protected void processJoin(Join join) {
if (join.getRightItem() instanceof Table) {
Table fromTable = (Table) join.getRightItem();
if (this.tenantHandler.doTableFilter(fromTable.getName())) {
// Table is excluded: leave this join untouched.
return;
}
join.setOnExpression(builderExpression(join.getOnExpression(), fromTable));
}
}
/**
* Appends the tenant condition for {@code table} to {@code currentExpression}.
* <p>
* Supports getTenantHandler().getTenantId() returning a complete expression list, so the
* condition becomes {@code tenant IN (1, 2)}; a single value such as LongValue(1) is still
* supported and becomes {@code tenant = 1}. The tenant value is requested with select = true.
* <p>
* Before appending, IN (sub-select) at the top level, and sub-selects / IN (sub-select) that
* are direct operands of a top-level binary expression, are rewritten as well.
* NOTE(review): doExpression does not descend further, so sub-selects nested deeper in the
* condition tree are not rewritten.
*
* @param currentExpression existing condition, may be {@code null}
* @param table             table whose tenant column is constrained
* @return the tenant condition alone, or the AND of the existing condition and it
*/
protected Expression builderExpression(Expression currentExpression, Table table) {
final Expression tenantExpression = tenantHandler.getTenantId(true);
Expression appendExpression = this.processTableAlias4CustomizedTenantIdExpression(tenantExpression, table);
if (currentExpression == null) {
return appendExpression;
}
if (currentExpression instanceof BinaryExpression) {
BinaryExpression binaryExpression = (BinaryExpression) currentExpression;
doExpression(binaryExpression.getLeftExpression());
doExpression(binaryExpression.getRightExpression());
} else if (currentExpression instanceof InExpression) {
InExpression inExp = (InExpression) currentExpression;
ItemsList rightItems = inExp.getRightItemsList();
if (rightItems instanceof SubSelect) {
processSelectBody(((SubSelect) rightItems).getSelectBody());
}
}
// Parenthesise an OR-rooted condition so the appended AND applies to all of it.
if (currentExpression instanceof OrExpression) {
return new AndExpression(new Parenthesis(currentExpression), appendExpression);
} else {
return new AndExpression(currentExpression, appendExpression);
}
}
/**
* Processes one operand of a condition. Expressions that are also FROM items (presumably
* e.g. a SubSelect) go to processFromItem; an IN (sub-select) has its sub-select rewritten.
* Other expression types are left unchanged.
*/
protected void doExpression(Expression expression) {
if (expression instanceof FromItem) {
processFromItem((FromItem) expression);
} else if (expression instanceof InExpression) {
InExpression inExp = (InExpression) expression;
ItemsList rightItems = inExp.getRightItemsList();
if (rightItems instanceof SubSelect) {
processSelectBody(((SubSelect) rightItems).getSelectBody());
}
}
}
/**
* Builds the tenant condition for {@code table}, qualifying the tenant column with the table
* alias when there is one. A ValueListExpression tenant value becomes
* {@code [alias.]tenant_id IN (...)}; anything else becomes {@code [alias.]tenant_id = value}.
* <p>
* Original note: for a customised tenant condition such as [tenant_id in (1,2,3)], the table
* alias could not be added to the tenant column:
* select a.id, b.name
* from a
* join b on b.aid = a.id and [b.]tenant_id in (1,2) --the alias [b.] cannot be added TODO
* NOTE(review): the code below does call getAliasColumn for the IN case too, so this note
* may be outdated.
*
* @param expression tenant value returned by the tenant handler
* @param table      table whose alias (if any) qualifies the tenant column
* @return the alias-qualified tenant condition
*/
protected Expression processTableAlias4CustomizedTenantIdExpression(Expression expression, Table table) {
Expression target;
if (expression instanceof ValueListExpression) {
InExpression inExpression = new InExpression();
inExpression.setLeftExpression(this.getAliasColumn(table));
inExpression.setRightItemsList(((ValueListExpression) expression).getExpressionList());
target = inExpression;
} else {
EqualsTo equalsTo = new EqualsTo();
equalsTo.setLeftExpression(this.getAliasColumn(table));
equalsTo.setRightExpression(expression);
target = equalsTo;
}
return target;
}
/**
* Builds the tenant column, qualified with the table alias when one is set.
* <p>tenantId or tableAlias.tenantId</p>
*
* @param table the table object
* @return the (possibly alias-qualified) tenant column
*/
protected Column getAliasColumn(Table table) {
StringBuilder column = new StringBuilder();
if (table.getAlias() != null) {
column.append(table.getAlias().getName()).append(StringPool.DOT);
}
column.append(tenantHandler.getTenantIdColumn());
return new Column(column.toString());
}
}
结果一顿操作猛如虎,最后却在 “withItem.getSelectBody()” 等多处出现编译报错(可能与项目中 JSqlParser / MyBatis-Plus 的版本不一致有关)……于是就放弃了。
后来回头再看官方提供的自定义类才发现:doTableFilter 最终只是返回 true 或 false 来决定是否跳过租户拦截,判断依据并不限于表名。一开始误以为这里只能按表名判断,被蒙蔽了~
/**
 * Tenant handler that, besides the table-name exclusion, honours a per-thread
 * bypass switch held in {@link SqlParserContext}.
 *
 * @author Lux Sun
 * @date 2023/7/18
 */
@Component
public class MyTenantHandler implements TenantHandler {

    /**
     * Supplies the current tenant id as a SQL string literal.
     *
     * @param select not used; the same value is returned for every statement type
     * @return the tenant id expression, or {@code null} when {@code TenantContext.get()} is null
     */
    @Override
    public Expression getTenantId(boolean select) {
        final TenantInfo current = TenantContext.get();
        if (current != null) {
            return new StringValue(current.getId());
        }
        return null;
    }

    /**
     * @return the name of the tenant column
     */
    @Override
    public String getTenantIdColumn() {
        return "tenant_id";
    }

    /**
     * Decides whether tenant processing is skipped for a table.
     *
     * @param tableName the table being processed
     * @return {@code true} (skip) when the context flag is set or the table is {@code t_product}
     */
    @Override
    public boolean doTableFilter(String tableName) {
        // The per-thread flag wins; the table-name check only runs when the flag is false.
        return SqlParserContext.get() || StrUtil.equalsAny(tableName, "t_product");
    }
}
/**
 * Per-thread switch that lets the current request bypass tenant SQL filtering.
 * <p>
 * The flag lives in a {@link ThreadLocal}. Threads that are reused across tasks keep
 * their ThreadLocal values, so every {@link #set(Boolean)} must be paired with
 * {@link #clear()} (ideally in a {@code finally} block), or use {@link #skipTenant}.
 *
 * @author Lux Sun
 * @date 2023/7/18
 */
public class SqlParserContext {

    /** Bypass flag of the current thread; absent means "do not bypass". */
    private static final ThreadLocal<Boolean> CONTEXT = new ThreadLocal<>();

    /**
     * Sets the bypass flag for the current thread.
     *
     * @param filter {@code true} to skip tenant filtering; {@code null} is read back as {@code false}
     */
    public static void set(Boolean filter) {
        CONTEXT.set(filter);
    }

    /**
     * Returns the bypass flag of the current thread.
     * <p>
     * Reading does not store anything in the ThreadLocal, so calling this never leaves
     * state behind on the thread.
     *
     * @return the flag, or {@link Boolean#FALSE} when it was never set (never {@code null})
     */
    public static Boolean get() {
        Boolean filter = CONTEXT.get();
        return filter != null ? filter : Boolean.FALSE;
    }

    /**
     * Removes the flag from the current thread.
     */
    public static void clear() {
        CONTEXT.remove();
    }

    /**
     * Runs {@code action} with tenant filtering bypassed. Afterwards it restores whatever
     * flag was set before, even if the action throws, so nested calls compose safely.
     *
     * @param action work to run without tenant filtering
     * @param <T>    result type
     * @return the action's result
     */
    public static <T> T skipTenant(java.util.function.Supplier<T> action) {
        Boolean previous = CONTEXT.get();
        CONTEXT.set(Boolean.TRUE);
        try {
            return action.get();
        } finally {
            if (previous == null) {
                CONTEXT.remove();
            } else {
                CONTEXT.set(previous);
            }
        }
    }
}
我为它自定义了一个基于 ThreadLocal 的上下文类,这样就能按线程(也就是按本次请求)进行控制:只要设置为 true 就跳过租户拦截;否则启用租户拦截——但如果表名命中过滤列表,依然会跳过拦截。
看一个案例:在需要跳过租户拦截的地方,把上下文设置为 true 即可。注意:Web 容器的线程会被复用,用完后务必在 finally 中调用 SqlParserContext.clear(),否则同一线程上的后续请求也会跳过租户过滤。
/**
 * Fetches a finance record by id with tenant filtering bypassed for this request.
 * <p>
 * The bypass flag is a ThreadLocal. Request threads are typically pooled and reused, so
 * the flag is cleared in {@code finally}. Otherwise later requests served by the same
 * thread would silently skip tenant isolation.
 *
 * @param id finance record id taken from the path
 * @return success response wrapping whatever {@code ficeService.getById(id)} returned
 */
@GetMapping("/finance/{id}")
public ResultVO<FinancePO> getFinance(@PathVariable String id) {
    SqlParserContext.set(true);
    try {
        FinancePO finance = ficeService.getById(id);
        return ResultVoUtil.buildSuccess(finance);
    } finally {
        SqlParserContext.clear();
    }
}
最后,别忘了注册 MyBatisConfig 配置类,把自定义的 TenantSqlParser 加入分页拦截器,注册后即可使用啦~
/**
 * MyBatis-Plus configuration: registers a pagination interceptor that also runs the
 * tenant SQL parser backed by {@link MyTenantHandler}.
 *
 * @author Lux Sun
 * @date 2023/7/18
 */
@Configuration
@EnableTransactionManagement
public class MyBatisConfig {

    @Resource
    private MyTenantHandler myTenantHandler;

    /**
     * Creates the pagination interceptor with the tenant parser attached as its only SQL parser.
     *
     * @return the configured interceptor
     */
    @Bean
    public PaginationInterceptor paginationInterceptor() {
        TenantSqlParser tenantParser = new TenantSqlParser();
        tenantParser.setTenantHandler(myTenantHandler);
        PaginationInterceptor interceptor = new PaginationInterceptor();
        interceptor.setSqlParserList(Collections.singletonList(tenantParser));
        return interceptor;
    }
}