本插件基于数据库字段实现多租户数据隔离：租户标识由请求头 tenantId 传入，用于过滤的数据库字段名通过插件属性 tenantIdField 配置。
1. 背景：若在原有项目的所有接口中逐个加入租户字段，改动工作量巨大。因此采用过滤器拦截请求，读取请求头中的租户参数，再由 MyBatis 插件将其织入 SQL 的 WHERE 条件中。
2. 请求过滤器（Servlet Filter）
/**
* 拦截器配置
*/
@Configuration
public class WebHeadFilter {
@Bean
public FilterRegistrationBean modifyParametersFilter() {
FilterRegistrationBean registration = new FilterRegistrationBean();
registration.setFilter(new MyHeadFilter());
registration.addUrlPatterns("/*"); // 拦截路径
registration.setName("headParametersFilter"); // 拦截器名称
registration.setOrder(1); // 顺序
return registration;
}
/**
* 自定义拦截器
*/
class MyHeadFilter extends OncePerRequestFilter {
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
// tenantId信息处理
HeadRequestWrapper headRequestWrapper = new HeadRequestWrapper(request);
if (StringUtils.isBlank(headRequestWrapper.getHeader("tenantId"))) {
headRequestWrapper.addHead("tenantId", "01b93ddbb36b4fe69c789595ab686597");
}
TenantInfoUtil.set(headRequestWrapper.getHeader("tenantId"));
// finish
filterChain.doFilter(headRequestWrapper, response);
}
@Override
public void destroy() {
TenantInfoUtil.remove();
super.destroy();
}
}
/**
* 修改head信息
*/
class HeadRequestWrapper extends HttpServletRequestWrapper {
private final Map<String, String> headers;
HeadRequestWrapper(HttpServletRequest request) {
super(request);
this.headers = new HashMap<>();
}
@Override
public String getHeader(String name) {
String headervalue = super.getHeader(name);
if (headers.containsKey(name)) {
headervalue = headers.get(name);
}
return headervalue;
}
@Override
public Enumeration<String> getHeaderNames() {
List<String> values = Collections.list(super.getHeaderNames());
for (String value : headers.keySet()) {
values.add(value);
}
return Collections.enumeration(values);
}
@Override
public Enumeration<String> getHeaders(String name) {
List<String> values = Collections.list(super.getHeaders(name));
if (headers.containsKey(name)) {
values = Arrays.asList(headers.get(name));
}
return Collections.enumeration(values);
}
public void addHead(String name, String value) {
this.headers.put(name, value);
}
}
}
3. 租户信息工具类（基于 ThreadLocal 保存当前线程的租户 ID）
/**
 * Static holder for the tenant id of the current thread, backed by a {@link ThreadLocal}.
 * Callers must {@link #remove()} the value when the unit of work finishes, because pooled
 * threads would otherwise keep it.
 */
public class TenantInfoUtil {

    private static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    /** Utility class: no instances. */
    private TenantInfoUtil() {
    }

    /**
     * Binds a tenant id to the current thread.
     *
     * @param tenantInfo tenant id to store (may be null)
     */
    public static void set(String tenantInfo) {
        CONTEXT.set(tenantInfo);
    }

    /**
     * Returns the tenant id bound to the current thread.
     *
     * @return the stored tenant id, or null if none was set on this thread
     */
    public static String get() {
        return CONTEXT.get();
    }

    /** Clears the current thread's value so a reused thread does not keep a stale tenant. */
    public static void remove() {
        CONTEXT.remove();
    }
}
4. 多租户插件（MyBatis 拦截器，在 StatementHandler.prepare 前改写 SQL）
@Slf4j
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class MultiTenantPlugin implements Interceptor {
    /**
     * Dialect name handed to SQLUtils for parsing and printing (property "dialect").
     */
    private String dialect;
    /**
     * Tenant column name (property "tenantIdField").
     */
    private String tenantIdField;
    /**
     * Tables that carry the tenant column (property "tableNames", comma separated).
     */
    private Set<String> tableSet;
    /**
     * Rewriter that adds the tenant condition to parsed statements.
     */
    private SqlConditionUtil sqlConditionUtil;

    /**
     * Rewrites the SQL about to be prepared so that it is restricted to the current thread's tenant.
     * When no tenant is bound to the thread the statement runs unchanged.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        String tenantId = TenantInfoUtil.get();
        // No tenant id: leave the SQL untouched.
        if (StringUtils.isBlank(tenantId)) {
            return invocation.proceed();
        }
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = statementHandler.getBoundSql();
        String originalSql = boundSql.getSql();
        String newSql = addTenantCondition(originalSql, tenantId);
        log.debug("tenantId={}, originalSql={}, newSql={}", tenantId, originalSql, newSql);
        // Write through the BoundSql object itself: the "delegate.boundSql.sql" path only exists on a
        // bare RoutingStatementHandler and fails when another plugin has already proxied the handler.
        SystemMetaObject.forObject(boundSql).setValue("sql", newSql);
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Reads the plugin configuration.
     *
     * @param properties must contain "dialect", "tenantIdField" and a non-empty "tableNames"
     * @throws IllegalArgumentException if any of the required properties is missing or blank
     */
    @Override
    public void setProperties(Properties properties) {
        dialect = properties.getProperty("dialect");
        if (StringUtils.isBlank(dialect)) {
            throw new IllegalArgumentException("MultiTenantPlugin need dialect property value");
        }
        tenantIdField = properties.getProperty("tenantIdField");
        if (StringUtils.isBlank(tenantIdField)) {
            throw new IllegalArgumentException("MultiTenantPlugin need tenantIdField property value");
        }
        tableSet = parseTableNames(properties.getProperty("tableNames"));
        if (tableSet.isEmpty()) {
            throw new IllegalArgumentException("MultiTenantPlugin tableNames must have");
        }
        // Decision: only the configured tables receive the tenant condition; null values are rejected.
        ITableFieldConditionDecision conditionDecision = new ITableFieldConditionDecision() {
            @Override
            public boolean isAllowNullValue() {
                return false;
            }

            @Override
            public boolean adjudge(String tableName, String fieldName) {
                return tableSet != null && tableSet.contains(tableName);
            }
        };
        sqlConditionUtil = new SqlConditionUtil(conditionDecision);
    }

    /**
     * Splits a comma-separated table list, trimming entries and dropping empty ones.
     *
     * @param tableNames raw property value, may be null
     * @return the table names (never null, possibly empty)
     */
    private static Set<String> parseTableNames(String tableNames) {
        Set<String> names = new HashSet<>();
        if (tableNames == null) {
            return names;
        }
        for (String name : tableNames.split(",")) {
            String trimmed = name.trim();
            if (!trimmed.isEmpty()) {
                names.add(trimmed);
            }
        }
        return names;
    }

    /**
     * Adds the tenant id filter to every statement in the SQL text.
     *
     * @param sql      the SQL to rewrite
     * @param tenantId the current tenant id
     * @return the rewritten SQL, or {@code sql} unchanged if it is blank or parses to nothing
     */
    private String addTenantCondition(String sql, String tenantId) {
        if (StringUtils.isBlank(sql) || StringUtils.isBlank(tenantIdField)) {
            return sql;
        }
        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, dialect);
        if (statementList == null || statementList.isEmpty()) {
            return sql;
        }
        // Every statement is printed back, so every statement must be rewritten; touching only the
        // first would let "SELECT ...; DELETE FROM t" escape tenant isolation.
        for (SQLStatement sqlStatement : statementList) {
            sqlConditionUtil.addStatementCondition(sqlStatement, tenantIdField, tenantId);
        }
        return SQLUtils.toSQLString(statementList, dialect);
    }
}
5. SQL 语句条件拼接工具类（基于 Druid SQL 解析器改写语法树）
/**
 * Rewrites parsed SQL statements in place so that tables approved by an
 * {@link ITableFieldConditionDecision} are restricted to {@code fieldName = 'fieldValue'}
 * (SELECT / UPDATE / DELETE), or get that column populated (INSERT ... VALUES).
 *
 * <p>Unsupported table sources or query shapes raise {@code ServiceException} instead of silently
 * skipping the filter. Sub-queries inside EXISTS / NOT / CASE expressions or in a select list are
 * not visited.
 */
public class SqlConditionUtil {
    /** Decides which tables receive the condition and whether a null value may be used. */
    private final ITableFieldConditionDecision conditionDecision;

    public SqlConditionUtil(ITableFieldConditionDecision conditionDecision) {
        this.conditionDecision = conditionDecision;
    }

    /**
     * Adds the field condition to a SELECT, UPDATE, DELETE or INSERT statement; other statement
     * types are left untouched.
     *
     * @param sqlStatement parsed statement, modified in place
     * @param fieldName    column name, e.g. tenant_id
     * @param fieldValue   value emitted as a string literal
     */
    public void addStatementCondition(SQLStatement sqlStatement, String fieldName, String fieldValue) {
        if (sqlStatement instanceof SQLSelectStatement) {
            addSelectCondition(((SQLSelectStatement) sqlStatement).getSelect(), fieldName, fieldValue);
        } else if (sqlStatement instanceof SQLUpdateStatement) {
            addUpdateStatementCondition((SQLUpdateStatement) sqlStatement, fieldName, fieldValue);
        } else if (sqlStatement instanceof SQLDeleteStatement) {
            addDeleteStatementCondition((SQLDeleteStatement) sqlStatement, fieldName, fieldValue);
        } else if (sqlStatement instanceof SQLInsertStatement) {
            addInsertStatementCondition((SQLInsertStatement) sqlStatement, fieldName, fieldValue);
        }
    }

    /**
     * INSERT ... SELECT: filters the source query. INSERT ... VALUES: appends the column and its
     * value to every row, when the table is approved and the column is not already listed.
     */
    private void addInsertStatementCondition(SQLInsertStatement insertStatement, String fieldName, String fieldValue) {
        if (insertStatement == null) {
            return;
        }
        SQLSelect sqlSelect = insertStatement.getQuery();
        if (sqlSelect != null) {
            addSelectCondition(sqlSelect, fieldName, fieldValue);
            return;
        }
        String tableName = insertStatement.getTableName() == null ? null : insertStatement.getTableName().getSimpleName();
        // Same gatekeeping as newEqualityCondition: approved tables only, no null unless allowed.
        if (!conditionDecision.adjudge(tableName, fieldName)) {
            return;
        }
        if (fieldValue == null && !conditionDecision.isAllowNullValue()) {
            return;
        }
        boolean alreadyPresent = insertStatement.getColumns().stream()
                .anyMatch(column -> fieldName.equalsIgnoreCase(column.toString()));
        if (alreadyPresent) {
            return;
        }
        // Without an explicit column list the value's position is unknown; appending a column would
        // only produce a column/value count mismatch, so fail clearly instead.
        if (insertStatement.getColumns().isEmpty()) {
            throw new ServiceException("sql增强异常");
        }
        insertStatement.getColumns().add(new SQLIdentifierExpr(fieldName));
        // Every row of a multi-row VALUES list needs the extra value, not just the first one.
        insertStatement.getValuesList().forEach(row -> row.addValue(new SQLCharExpr(fieldValue)));
    }

    /**
     * Filters the DELETE target table and any IN / scalar sub-queries in its WHERE clause.
     */
    private void addDeleteStatementCondition(SQLDeleteStatement deleteStatement, String fieldName, String fieldValue) {
        SQLExpr where = deleteStatement.getWhere();
        // Sub-queries inside the WHERE clause
        addSQLExprCondition(where, fieldName, fieldValue);
        SQLExpr newCondition = newEqualityCondition(deleteStatement.getTableName().getSimpleName(),
                deleteStatement.getTableSource().getAlias(), fieldName, fieldValue, where);
        deleteStatement.setWhere(newCondition);
    }

    /**
     * Walks a WHERE expression and adds the condition inside IN (sub-query) and scalar sub-queries,
     * descending through binary operators (AND / OR / comparisons).
     * NOTE(review): EXISTS, NOT and CASE expressions are not descended into.
     *
     * @param where source WHERE expression, may be null
     */
    private void addSQLExprCondition(SQLExpr where, String fieldName, String fieldValue) {
        if (where instanceof SQLInSubQueryExpr) {
            addSelectCondition(((SQLInSubQueryExpr) where).getSubQuery(), fieldName, fieldValue);
        } else if (where instanceof SQLBinaryOpExpr) {
            SQLBinaryOpExpr opExpr = (SQLBinaryOpExpr) where;
            addSQLExprCondition(opExpr.getLeft(), fieldName, fieldValue);
            addSQLExprCondition(opExpr.getRight(), fieldName, fieldValue);
        } else if (where instanceof SQLQueryExpr) {
            addSelectCondition(((SQLQueryExpr) where).getSubQuery(), fieldName, fieldValue);
        }
    }

    /**
     * Filters the UPDATE target table and any IN / scalar sub-queries in its WHERE clause.
     */
    private void addUpdateStatementCondition(SQLUpdateStatement updateStatement, String fieldName, String fieldValue) {
        SQLExpr where = updateStatement.getWhere();
        // Sub-queries inside the WHERE clause
        addSQLExprCondition(where, fieldName, fieldValue);
        SQLExpr newCondition = newEqualityCondition(updateStatement.getTableName().getSimpleName(),
                updateStatement.getTableSource().getAlias(), fieldName, fieldValue, where);
        updateStatement.setWhere(newCondition);
    }

    /**
     * Adds the condition to whatever query a SQLSelect wraps (plain query block or UNION).
     *
     * @param select the select, may be null
     */
    private void addSelectCondition(SQLSelect select, String fieldName, String fieldValue) {
        if (select != null) {
            addQueryCondition(select.getQuery(), fieldName, fieldValue);
        }
    }

    /**
     * Dispatches on the query kind. Declared as Object only to accept every query node Druid returns.
     *
     * @param query a query block, a UNION, or null (ignored)
     * @throws ServiceException for any other query kind, so the tenant filter is never skipped silently
     */
    private void addQueryCondition(Object query, String fieldName, String fieldValue) {
        if (query instanceof SQLSelectQueryBlock) {
            addQueryBlockCondition((SQLSelectQueryBlock) query, fieldName, fieldValue);
        } else if (query instanceof SQLUnionQuery) {
            addSelectStatementConditionUnion((SQLUnionQuery) query, fieldName, fieldValue);
        } else if (query != null) {
            throw new ServiceException("sql增强异常");
        }
    }

    /**
     * Processes one query block exactly once: sub-queries in its WHERE clause, then every table in
     * its FROM clause.
     */
    private void addQueryBlockCondition(SQLSelectQueryBlock queryObject, String fieldName, String fieldValue) {
        if (queryObject == null) {
            return;
        }
        addSQLExprCondition(queryObject.getWhere(), fieldName, fieldValue);
        addSelectStatementCondition(queryObject, queryObject.getFrom(), null, false, fieldName, fieldValue);
    }

    /**
     * Adds the condition for every table reachable from {@code from}.
     *
     * @param queryObject block whose WHERE receives conditions for plain / inner-joined tables
     * @param from        table source being processed
     * @param outerJoin   non-null when {@code from} lies on the nullable side of this outer join; its
     *                    condition then goes into that join's ON clause, because filtering in WHERE
     *                    would drop the rows the outer join is meant to keep
     * @param inJoin      whether {@code from} takes part in a join; an unaliased table is then
     *                    qualified with its table name to avoid an ambiguous column
     */
    private void addSelectStatementCondition(SQLSelectQueryBlock queryObject, SQLTableSource from, SQLJoinTableSource outerJoin,
                                             boolean inJoin, String fieldName, String fieldValue) {
        if (StringUtils.isBlank(fieldName) || from == null || queryObject == null) {
            return;
        }
        if (from instanceof SQLExprTableSource) {
            // NOTE(review): schema-qualified names are not SQLIdentifierExpr and fail this cast
            // (ClassCastException), i.e. the statement is rejected rather than left unfiltered.
            String tableName = ((SQLIdentifierExpr) ((SQLExprTableSource) from).getExpr()).getName();
            String alias = from.getAlias();
            String qualifier = inJoin && StringUtils.isBlank(alias) ? tableName : alias;
            if (outerJoin != null) {
                outerJoin.setCondition(newEqualityCondition(tableName, qualifier, fieldName, fieldValue, outerJoin.getCondition()));
            } else {
                queryObject.setWhere(newEqualityCondition(tableName, qualifier, fieldName, fieldValue, queryObject.getWhere()));
            }
        } else if (from instanceof SQLJoinTableSource) {
            SQLJoinTableSource joinObject = (SQLJoinTableSource) from;
            SQLJoinTableSource.JoinType joinType = joinObject.getJoinType();
            // ON cannot be combined with USING; such joins keep the WHERE placement.
            boolean onUsable = joinObject.getUsing() == null || joinObject.getUsing().isEmpty();
            boolean leftNullable = onUsable && (joinType == SQLJoinTableSource.JoinType.RIGHT_OUTER_JOIN
                    || joinType == SQLJoinTableSource.JoinType.FULL_OUTER_JOIN);
            boolean rightNullable = onUsable && (joinType == SQLJoinTableSource.JoinType.LEFT_OUTER_JOIN
                    || joinType == SQLJoinTableSource.JoinType.FULL_OUTER_JOIN);
            addSelectStatementCondition(queryObject, joinObject.getLeft(), leftNullable ? joinObject : outerJoin,
                    true, fieldName, fieldValue);
            addSelectStatementCondition(queryObject, joinObject.getRight(), rightNullable ? joinObject : outerJoin,
                    true, fieldName, fieldValue);
        } else if (from instanceof SQLSubqueryTableSource) {
            // Derived table: filter inside it (its own WHERE), whichever join side it is on.
            addSelectCondition(((SQLSubqueryTableSource) from).getSelect(), fieldName, fieldValue);
        } else if (from instanceof SQLUnionQueryTableSource) {
            addSelectStatementConditionUnion(((SQLUnionQueryTableSource) from).getUnion(), fieldName, fieldValue);
        } else {
            throw new ServiceException("sql增强异常");
        }
    }

    /**
     * Adds the condition to every branch of a UNION, recursing into nested unions on either side.
     *
     * @param sqlUnionQuery the union, may be null
     */
    private void addSelectStatementConditionUnion(SQLUnionQuery sqlUnionQuery, String fieldName, String fieldValue) {
        if (sqlUnionQuery == null) {
            return;
        }
        addQueryCondition(sqlUnionQuery.getLeft(), fieldName, fieldValue);
        addQueryCondition(sqlUnionQuery.getRight(), fieldName, fieldValue);
    }

    /**
     * Builds {@code originCondition AND qualifier.fieldName = 'fieldValue'}, or returns
     * {@code originCondition} unchanged when the decision rejects the table or the value is a
     * disallowed null.
     *
     * @param tableName       table name passed to the decision
     * @param tableAlias      qualifier for the column; blank means an unqualified column
     * @param fieldName       column name
     * @param fieldValue      value emitted as a string literal
     * @param originCondition existing condition, may be null
     * @return the combined condition
     */
    private SQLExpr newEqualityCondition(String tableName, String tableAlias, String fieldName, String fieldValue, SQLExpr originCondition) {
        // Table not selected for the condition
        if (!conditionDecision.adjudge(tableName, fieldName)) {
            return originCondition;
        }
        // Null value not allowed
        if (fieldValue == null && !conditionDecision.isAllowNullValue()) {
            return originCondition;
        }
        String filedName = StringUtils.isBlank(tableAlias) ? fieldName : tableAlias + "." + fieldName;
        SQLExpr condition = new SQLBinaryOpExpr(new SQLIdentifierExpr(filedName), new SQLCharExpr(fieldValue), SQLBinaryOperator.Equality);
        return SQLUtils.buildCondition(SQLBinaryOperator.BooleanAnd, condition, false, originCondition);
    }

    public static void main(String[] args) {
        // String sql = "select * from user s ";
        // String sql = "select * from user s where s.name='333'";
        // String sql = "select * from (select * from tab t where id = 2 and name = 'wenshao') s where s.name='333'";
        // String sql="select u.*,g.name from user u join user_group g on u.groupId=g.groupId where u.name='123'";
        // String sql = "update user set name=? where id =(select id from user s)";
        // String sql = "delete from user where id = ( select id from user s )";
        // String sql = "INSERT INTO deleted_organization ( id,origin_id,org_name,org_type,org_level,org_order,super_org_id,org_number,email,company_phone,is_structure,is_deleted,tag,data_source_flag,create_user_name,update_user_name,create_time,update_time,delete_time, tenant_id) VALUES( 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,8) ";
        String sql = "SELECT count(0) FROM ((SELECT DISTINCT data_source_flag FROM t_account) UNION (SELECT DISTINCT data_source_flag FROM t_asset) UNION (SELECT DISTINCT data_source_flag FROM t_business_system) UNION (SELECT DISTINCT data_source_flag FROM t_component) UNION (SELECT DISTINCT data_source_flag FROM t_component_account) UNION (SELECT DISTINCT data_source_flag FROM t_employee) UNION (SELECT DISTINCT data_source_flag FROM t_intranet_ip_network_segment) UNION (SELECT DISTINCT data_source_flag FROM t_ip) UNION (SELECT DISTINCT data_source_flag FROM t_leak) UNION (SELECT DISTINCT data_source_flag FROM t_organization) UNION (SELECT DISTINCT data_source_flag FROM t_security_domain) UNION (SELECT DISTINCT data_source_flag FROM t_virus)) AS c LEFT JOIN d_enum e ON c.data_source_flag = e.code WHERE c.data_source_flag IS NOT NULL ";
        // String sql = "SELECT DISTINCT data_source_flag FROM t_account UNION SELECT DISTINCT data_source_flag FROM t_asset";
        // String sql = "SELECT DISTINCT data_source_flag FROM t_account";
        // String sql = "select u.*,g.name from user u join (select * from user_group g join user_role r on g.role_code=r.code ) g on u.groupId=g.groupId where u.name='123'";
        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, JdbcConstants.POSTGRESQL);
        SQLStatement sqlStatement = statementList.get(0);
        // Decision used for the demo: every table gets the condition
        SqlConditionUtil helper = new SqlConditionUtil(new ITableFieldConditionDecision() {
            @Override
            public boolean adjudge(String tableName, String fieldName) {
                return true;
            }

            @Override
            public boolean isAllowNullValue() {
                return false;
            }
        });
        // Add the tenant condition: tenant_id is the column, yay is the value
        helper.addStatementCondition(sqlStatement, "tenant_id", "yay");
        System.out.println("源sql:" + sql);
        System.out.println("修改后sql:" + SQLUtils.toSQLString(statementList, JdbcConstants.POSTGRESQL));
    }
}