Java MyBatis-Plus 数据隔离方案
<dependency>
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus-boot-starter</artifactId>
</dependency>
package com.seerbigdata.common.interceptor;
import com.seerbigdata.common.config.EnterpriseConfig;
import com.seerbigdata.common.core.domain.model.LoginUser;
import com.seerbigdata.common.utils.SecurityUtils;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.apache.ibatis.builder.StaticSqlSource;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * MyBatis executor interceptor that restricts SELECT queries on configured tables to the current
 * user's enterprise by adding an {@code enterprise_id = '<id>'} predicate to the SQL.
 *
 * <p>Both {@code Executor#query} overloads are intercepted; {@code update} is not, so writes are
 * never filtered. The SQL is inspected lexically (regular expressions plus a parenthesis-aware
 * scan), not with a SQL parser: only the first table name after {@code FROM} decides whether the
 * query is isolated, and the predicate uses the unqualified column {@code enterprise_id}.
 */
@Intercepts({
        // type: the intercepted MyBatis component; method: the intercepted method;
        // args: the exact parameter types of that Executor method overload.
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                org.apache.ibatis.session.RowBounds.class, org.apache.ibatis.session.ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                org.apache.ibatis.session.RowBounds.class, org.apache.ibatis.session.ResultHandler.class,
                org.apache.ibatis.cache.CacheKey.class, BoundSql.class})
})
@Data
@AllArgsConstructor
public class EnterpriseIdInterceptor implements Interceptor {

    /** First table reference after FROM; group 1 is the (possibly back-quoted) name. */
    private static final Pattern TABLE_PATTERN =
            Pattern.compile("\\bFROM\\s+([`,\\w]+)", Pattern.CASE_INSENSITIVE);

    /** Keyword patterns used with {@link Matcher#lookingAt()} at candidate positions. */
    private static final Pattern FROM_KEYWORD = Pattern.compile("FROM\\b", Pattern.CASE_INSENSITIVE);
    private static final Pattern WHERE_KEYWORD = Pattern.compile("WHERE\\b", Pattern.CASE_INSENSITIVE);
    /** Top-level clauses that must stay after the WHERE predicate. */
    private static final Pattern TRAILING_CLAUSE = Pattern.compile(
            "(GROUP\\s+BY|HAVING|ORDER\\s+BY|LIMIT|FOR\\s+UPDATE|UNION)\\b", Pattern.CASE_INSENSITIVE);

    /** Column that holds the enterprise id on every isolated table. */
    private static final String ENTERPRISE_COLUMN = "enterprise_id";

    private EnterpriseConfig enterpriseConfig;

    /**
     * Rewrites the query SQL with the enterprise filter when isolation is enabled, the queried
     * table is configured for isolation and an enterprise id is available; otherwise the query
     * runs unchanged.
     *
     * @param invocation the intercepted {@code Executor#query} call (4 or 6 arguments)
     * @return the result of the (possibly rewritten) query
     * @throws Throwable whatever the underlying query or the reflective SQL update throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        if (!isIsolationEnabled()) {
            return invocation.proceed();
        }
        Object[] args = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) args[0];
        Object parameter = args[1];
        // The 6-argument overload already carries a BoundSql (possibly rewritten by another
        // plugin); the 4-argument overload has to build it from the statement.
        boolean hasBoundSqlArg = args.length == 6;
        BoundSql boundSql = hasBoundSqlArg ? (BoundSql) args[5] : mappedStatement.getBoundSql(parameter);
        String originalSql = boundSql.getSql();

        if (!isIsolatedTable(extractTableName(originalSql))) {
            return invocation.proceed();
        }
        // Resolved only now, so queries on non-isolated tables never touch the security context.
        String enterpriseId = getEnterpriseIdFromRequest();
        if (enterpriseId == null) {
            return invocation.proceed();
        }

        // Defence in depth: the id is embedded as a literal, so double any single quote.
        String condition = ENTERPRISE_COLUMN + " = '" + enterpriseId.replace("'", "''") + "'";
        String modifiedSql = appendCondition(originalSql, condition);

        // BoundSql has no setter for its SQL, so overwrite the private field. Reusing this same
        // BoundSql keeps its parameter mappings and additional (e.g. foreach) parameters.
        Field sqlField = BoundSql.class.getDeclaredField("sql");
        sqlField.setAccessible(true);
        sqlField.set(boundSql, modifiedSql);

        // Downstream code calling ms.getBoundSql(...) receives the rewritten BoundSql.
        MappedStatement modifiedStatement = copyFromMappedStatement(mappedStatement, parameterObject -> boundSql);
        args[0] = modifiedStatement;
        if (hasBoundSqlArg) {
            args[5] = boundSql;
            // The caller computed this cache key from the unfiltered SQL; rebuild it so cached
            // results are keyed by the filtered SQL (which contains the enterprise id).
            Executor executor = (Executor) invocation.getTarget();
            args[4] = executor.createCacheKey(modifiedStatement, parameter,
                    (org.apache.ibatis.session.RowBounds) args[2], boundSql);
        }
        return invocation.proceed();
    }

    /** Returns true only when a config is present and {@code enable} is {@code Boolean.TRUE}. */
    private boolean isIsolationEnabled() {
        EnterpriseConfig config = this.getEnterpriseConfig();
        return config != null && Boolean.TRUE.equals(config.getEnable());
    }

    /** Returns the first table name after FROM with back-quotes removed, or null if none. */
    private static String extractTableName(String sql) {
        Matcher matcher = TABLE_PATTERN.matcher(sql);
        return matcher.find() ? matcher.group(1).replace("`", "") : null;
    }

    /** Returns true if {@code tableName} is listed (case-insensitively) in the configured tables. */
    private boolean isIsolatedTable(String tableName) {
        String[] isolatedTables = this.getEnterpriseConfig().getTableName();
        if (tableName == null || isolatedTables == null) {
            return false;
        }
        for (String candidate : isolatedTables) {
            if (tableName.equalsIgnoreCase(candidate)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Adds {@code condition} to the top-level WHERE clause of {@code sql}.
     *
     * <p>An existing top-level predicate is wrapped in parentheses so an {@code OR} inside it cannot
     * bypass the new condition. The condition is inserted before the first top-level
     * GROUP BY / HAVING / ORDER BY / LIMIT / FOR UPDATE / UNION instead of after it.
     * "Top-level" means after the first top-level FROM, outside parentheses and outside
     * single-quoted literals.
     *
     * @param sql       the original SQL
     * @param condition the predicate to AND into the WHERE clause
     * @return the rewritten SQL
     */
    private static String appendCondition(String sql, String condition) {
        int length = sql.length();
        Matcher fromMatcher = FROM_KEYWORD.matcher(sql);
        Matcher whereMatcher = WHERE_KEYWORD.matcher(sql);
        Matcher tailMatcher = TRAILING_CLAUSE.matcher(sql);
        int depth = 0;
        boolean inLiteral = false;
        boolean fromSeen = false;
        int whereEnd = -1;
        int tailStart = length;
        for (int i = 0; i < length; i++) {
            char c = sql.charAt(i);
            if (c == '\'') {
                // A doubled quote ('') toggles twice, so escaped quotes are handled too.
                inLiteral = !inLiteral;
                continue;
            }
            if (inLiteral) {
                continue;
            }
            if (c == '(') {
                depth++;
                continue;
            }
            if (c == ')') {
                depth--;
                continue;
            }
            // Keywords only count outside parentheses and at the start of a word.
            if (depth != 0 || (i > 0 && isIdentifierPart(sql.charAt(i - 1)))) {
                continue;
            }
            if (!fromSeen) {
                fromSeen = fromMatcher.region(i, length).lookingAt();
            } else if (whereEnd < 0 && whereMatcher.region(i, length).lookingAt()) {
                whereEnd = whereMatcher.end();
            } else if (tailMatcher.region(i, length).lookingAt()) {
                tailStart = i;
                break;
            }
        }

        StringBuilder rewritten = new StringBuilder(length + condition.length() + 16);
        if (whereEnd >= 0) {
            rewritten.append(sql, 0, whereEnd)
                    .append(" (")
                    .append(sql.substring(whereEnd, tailStart).trim())
                    .append(") AND ");
        } else {
            rewritten.append(sql.substring(0, tailStart).trim()).append(" WHERE ");
        }
        rewritten.append(condition);
        if (tailStart < length) {
            rewritten.append(' ').append(sql.substring(tailStart));
        }
        return rewritten.toString();
    }

    /** Characters that, when preceding a position, mean it is not the start of a keyword. */
    private static boolean isIdentifierPart(char c) {
        return Character.isLetterOrDigit(c) || c == '_' || c == '$' || c == '.' || c == '`' || c == '"';
    }

    /**
     * Copies {@code ms} with a different SQL source, keeping the other statement settings.
     *
     * @param ms           the statement to copy
     * @param newSqlSource the SQL source for the copy
     * @return the new MappedStatement
     */
    private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        // The builder takes comma-separated strings; join all entries instead of keeping only the first.
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
            builder.keyProperty(String.join(",", ms.getKeyProperties()));
        }
        if (ms.getKeyColumns() != null && ms.getKeyColumns().length > 0) {
            builder.keyColumn(String.join(",", ms.getKeyColumns()));
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        builder.resultOrdered(ms.isResultOrdered());
        builder.databaseId(ms.getDatabaseId());
        builder.lang(ms.getLang());
        if (ms.getResultSets() != null && ms.getResultSets().length > 0) {
            builder.resultSets(String.join(",", ms.getResultSets()));
        }
        return builder.build();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // No plugin properties are used by this interceptor.
    }

    /**
     * Resolves the enterprise id used to filter queries for the current request.
     *
     * @return the first enterprise id of the logged-in user, or {@code null} in any of these cases:
     *         no request is bound to the current thread; the request sends header
     *         {@code enterprise_enable: false}; there is no login user; or the user is not flagged
     *         as an enterprise user or has no enterprises
     */
    private String getEnterpriseIdFromRequest() {
        ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (Objects.isNull(requestAttributes)) {
            return null;
        }
        HttpServletRequest request = requestAttributes.getRequest();
        // NOTE(review): this header is supplied by the client, so any caller can switch isolation
        // off for its own request by sending "enterprise_enable: false" — confirm this is intended.
        if ("false".equals(request.getHeader("enterprise_enable"))) {
            return null;
        }
        LoginUser loginUser = SecurityUtils.getLoginUser();
        if (Objects.isNull(loginUser)) {
            return null;
        }
        // Boolean.TRUE.equals avoids an NPE if the flag getter returns a null Boolean wrapper.
        if (Boolean.TRUE.equals(loginUser.getEnterpriseUserFlag()) && !CollectionUtils.isEmpty(loginUser.getEnterprises())) {
            return loginUser.getEnterprises().get(0).getEnterpriseId();
        }
        return null;
    }
}
package com.seerbigdata.common.config;
import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.core.config.GlobalConfig;
import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor;
import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean;
import com.seerbigdata.common.handler.CustomTenantHandler;
import com.seerbigdata.common.interceptor.EnterpriseIdInterceptor;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.mapping.DatabaseIdProvider;
import org.apache.ibatis.mapping.VendorDatabaseIdProvider;
import org.apache.ibatis.session.SqlSessionFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import javax.sql.DataSource;
import java.util.Properties;
/**
 * MyBatis-Plus configuration: builds the primary {@link SqlSessionFactory} with the global config,
 * the DatabaseIdProvider, the MyBatis-Plus inner interceptors (tenant line, pagination) and the
 * enterprise isolation interceptor.
 *
 * @author hzk
 * @Classname config
 * @create 2022-03-01 16:25
 */
@Configuration
@Slf4j
public class MybatisPlusConfig {

    @Value("${mybatis-plus.type-aliases-package}")
    private String typeAliasesPackage;

    @Value("${mybatis-plus.mapper-locations}")
    private String mapperLocations;

    @Autowired
    private TenantConfig tenantConfig;

    @Autowired
    private CustomTenantHandler customTenantHandler;

    /** Enterprise isolation settings handed to {@link EnterpriseIdInterceptor}. */
    @Autowired
    private EnterpriseConfig enterpriseConfig;

    /**
     * Builds the primary SqlSessionFactory on the {@code dataSource} bean.
     *
     * @param dataSource the data source qualified as {@code dataSource}
     * @return the configured SqlSessionFactory
     * @throws Exception if mapper resources cannot be resolved or the factory cannot be built
     */
    @Bean("sqlSessionFactory")
    @Primary
    public SqlSessionFactory sqlSessionFactory(@Autowired @Qualifier("dataSource") DataSource dataSource) throws Exception {
        MybatisSqlSessionFactoryBean sqlSessionFactoryBean = new MybatisSqlSessionFactoryBean();
        sqlSessionFactoryBean.setDataSource(dataSource);
        sqlSessionFactoryBean.setTypeAliasesPackage(typeAliasesPackage);
        sqlSessionFactoryBean.setMapperLocations(new PathMatchingResourcePatternResolver().getResources(mapperLocations));
        sqlSessionFactoryBean.setGlobalConfig(globalConfig());
        // The factory is built by hand here, so the DatabaseIdProvider bean has to be passed in
        // explicitly; without it the configuration's databaseId stays unset.
        sqlSessionFactoryBean.setDatabaseIdProvider(getDatabaseIdProvider());
        // Pagination/tenant interceptor plus the enterprise isolation interceptor.
        // NOTE: MyBatis wraps plugins in registration order, so the last one registered is
        // outermost and sees the query first. That means the enterprise filter is in the SQL
        // before pagination builds its COUNT/LIMIT queries.
        sqlSessionFactoryBean.setPlugins(mybatisPlusInterceptor(), enterpriseIdInterceptor());
        return sqlSessionFactoryBean.getObject();
    }

    /**
     * MyBatis-Plus interceptor chain: the tenant line interceptor when tenant mode is enabled,
     * followed by pagination for KingbaseES.
     *
     * @return the configured MybatisPlusInterceptor
     */
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        if (Boolean.TRUE.equals(tenantConfig.getEnable())) {
            log.info("开启租户模式");
            interceptor.addInnerInterceptor(new TenantLineInnerInterceptor(customTenantHandler));
        }
        interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.KINGBASE_ES));
        return interceptor;
    }

    /**
     * Creates the enterprise isolation interceptor that is registered next to the MyBatis-Plus
     * interceptor.
     *
     * @return a new EnterpriseIdInterceptor bound to {@link #enterpriseConfig}
     */
    public EnterpriseIdInterceptor enterpriseIdInterceptor() {
        return new EnterpriseIdInterceptor(enterpriseConfig);
    }

    /**
     * Global MyBatis-Plus configuration.
     *
     * @return a GlobalConfig with the logical-delete settings and the meta-object handler
     */
    public GlobalConfig globalConfig() {
        GlobalConfig globalConfig = new GlobalConfig();
        globalConfig.setDbConfig(dbConfig());
        globalConfig.setMetaObjectHandler(seerMetaObjectHandler());
        return globalConfig;
    }

    /**
     * Global DB settings: logical delete on column {@code deleted}, where 1 means deleted and
     * 0 means not deleted.
     *
     * @return the DbConfig
     */
    public GlobalConfig.DbConfig dbConfig() {
        GlobalConfig.DbConfig dbConfig = new GlobalConfig.DbConfig();
        dbConfig.setLogicDeleteField("deleted");
        dbConfig.setLogicDeleteValue("1");
        dbConfig.setLogicNotDeleteValue("0");
        return dbConfig;
    }

    /**
     * MyBatis-Plus automatic field-fill handler.
     *
     * @return a new SeerMetaObjectHandler
     */
    public MetaObjectHandler seerMetaObjectHandler() {
        log.info("SeerMetaObjectHandler loaded");
        return new SeerMetaObjectHandler();
    }

    /**
     * Maps JDBC database product names to the databaseId values used in mapper XML.
     *
     * <p>Statements without a databaseId apply to every database. A product name missing from
     * this map yields no databaseId.
     *
     * @return the vendor DatabaseIdProvider
     */
    @Bean
    public DatabaseIdProvider getDatabaseIdProvider() {
        DatabaseIdProvider databaseIdProvider = new VendorDatabaseIdProvider();
        Properties properties = new Properties();
        properties.setProperty("Oracle", "oracle");
        properties.setProperty("MySQL", "mysql");
        properties.setProperty("DB2", "db2");
        properties.setProperty("Derby", "derby");
        properties.setProperty("H2", "h2");
        properties.setProperty("HSQL", "hsql");
        properties.setProperty("Informix", "informix");
        properties.setProperty("MS-SQL", "ms-sql");
        properties.setProperty("PostgreSQL", "postgresql");
        properties.setProperty("Sybase", "sybase");
        properties.setProperty("Hana", "hana");
        databaseIdProvider.setProperties(properties);
        return databaseIdProvider;
    }
}
package com.seerbigdata.common.config;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;
/**
 * Enterprise data isolation settings, bound from the {@code enterprise.*} configuration properties.
 *
 * @author cwj
 */
@Data
@Configuration
@ConfigurationProperties(prefix = "enterprise")
public class EnterpriseConfig {

    /**
     * Whether enterprise mode (query isolation by enterprise) is enabled.
     */
    private Boolean enable;

    /**
     * Names of the tables whose queries are isolated by enterprise. Defaults to an empty array, so
     * enabling isolation without listing any table isolates nothing instead of yielding null.
     */
    private String[] tableName = new String[0];
}
yml文件:
enterprise:
####### 是否开启企业模式
enable: true
######## 需要做数据隔离的表名
tableName:
- sp_special_work_safety_disclosure
- sp_special_work_person
- t_event
- device_camera