Java MyBatis-Plus 实现企业数据隔离方案

Java MyBatis-Plus 数据隔离方案

    <!-- MyBatis-Plus Spring Boot starter. No <version> is given here, so it must be supplied by a
         parent POM / dependencyManagement section (or added explicitly). -->
    <dependency>
        <groupId>com.baomidou</groupId>
        <artifactId>mybatis-plus-boot-starter</artifactId>
    </dependency>
package com.seerbigdata.common.interceptor;

import com.seerbigdata.common.config.EnterpriseConfig;
import com.seerbigdata.common.core.domain.model.LoginUser;
import com.seerbigdata.common.utils.SecurityUtils;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.apache.ibatis.builder.StaticSqlSource;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * MyBatis {@link Executor} plugin that restricts SELECT statements on the configured tables to the
 * rows of the current user's enterprise.
 *
 * <p>For every intercepted {@code Executor.query(...)} call the plugin:
 * <ol>
 *   <li>does nothing unless {@code enterprise.enable} is {@code true};</li>
 *   <li>takes the first table named after the top-level {@code FROM} (quotes and schema prefix
 *       stripped) and compares it case-insensitively with {@code enterprise.tableName};</li>
 *   <li>resolves the enterprise id of the logged-in user ({@link #getEnterpriseIdFromRequest()});</li>
 *   <li>adds {@code enterprise_id = '<id>'} to the top-level WHERE clause, parenthesising the
 *       existing conditions so an {@code OR} cannot bypass the filter, or inserts a WHERE clause in
 *       front of any GROUP BY / ORDER BY / LIMIT ... tail;</li>
 *   <li>runs the 6-argument {@code query} with a cache key computed from the rewritten SQL, so
 *       cached results are never shared between enterprises.</li>
 * </ol>
 *
 * <p>Limitations (text-level rewrite, no SQL parser): only the first top-level FROM source is
 * inspected, so derived-table queries ({@code FROM (SELECT ...)}) are not isolated; the added column
 * is unqualified and may be ambiguous in joins; only the first branch of a UNION is filtered.
 */
@Intercepts({
        // type = intercepted interface, method = intercepted method name, args = that method's parameter types.
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                org.apache.ibatis.session.RowBounds.class, org.apache.ibatis.session.ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                org.apache.ibatis.session.RowBounds.class, org.apache.ibatis.session.ResultHandler.class,
                org.apache.ibatis.cache.CacheKey.class, BoundSql.class})
})
@Data
@AllArgsConstructor
public class EnterpriseIdInterceptor implements Interceptor {

    /** Column compared against the current enterprise id. */
    private static final String ENTERPRISE_COLUMN = "enterprise_id";

    /**
     * The enterprise id is inlined as a quoted SQL literal (keeping the original string comparison),
     * so only characters that cannot terminate or escape the literal are accepted.
     */
    private static final Pattern SAFE_ENTERPRISE_ID = Pattern.compile("[A-Za-z0-9_\\-]*");

    private static final List<String> FROM_KEYWORD = Arrays.asList("FROM");

    private static final List<String> WHERE_KEYWORD = Arrays.asList("WHERE");

    /** Top-level tokens that end the FROM/WHERE part of a SELECT. */
    private static final List<String> CLAUSE_TERMINATORS = Arrays.asList(
            "GROUP", "HAVING", "WINDOW", "ORDER", "LIMIT", "OFFSET", "FETCH",
            "UNION", "INTERSECT", "EXCEPT", "FOR", "LOCK", ";");

    private EnterpriseConfig enterpriseConfig;

    /**
     * Adds the enterprise filter to queries on isolated tables; all other queries proceed unchanged.
     *
     * @param invocation the intercepted {@code Executor.query} call (4- or 6-argument overload)
     * @return the query result
     * @throws IllegalStateException if the enterprise id contains characters unsafe to inline in SQL
     * @throws Throwable anything thrown by the underlying executor
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        if (!isIsolationEnabled()) {
            return invocation.proceed();
        }
        Object[] args = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) args[0];
        Object parameter = args[1];
        // The 6-arg overload already carries the BoundSql to execute; only the 4-arg one must build it.
        BoundSql boundSql = args.length == 6 ? (BoundSql) args[5] : mappedStatement.getBoundSql(parameter);
        String originalSql = boundSql.getSql();

        int fromIndex = indexOfTopLevel(originalSql, 0, originalSql.length(), FROM_KEYWORD);
        if (fromIndex < 0) {
            return invocation.proceed();
        }
        String tableName = tableNameAfter(originalSql, fromIndex);
        if (tableName == null || !isIsolatedTable(tableName)) {
            return invocation.proceed();
        }
        String enterpriseId = getEnterpriseIdFromRequest();
        if (enterpriseId == null) {
            return invocation.proceed();
        }
        if (!SAFE_ENTERPRISE_ID.matcher(enterpriseId).matches()) {
            // Fail closed rather than inline a value that could break out of the SQL literal.
            throw new IllegalStateException("Refusing to inline unsafe enterprise id into SQL: " + enterpriseId);
        }

        String rewrittenSql = addEnterpriseCondition(originalSql, fromIndex,
                ENTERPRISE_COLUMN + " = '" + enterpriseId + "'");
        // BoundSql has no SQL setter; overwrite its private field so parameter mappings and
        // additional parameters (e.g. <foreach> values) are kept intact.
        Field sqlField = BoundSql.class.getDeclaredField("sql");
        sqlField.setAccessible(true);
        sqlField.set(boundSql, rewrittenSql);

        Executor executor = (Executor) invocation.getTarget();
        org.apache.ibatis.session.RowBounds rowBounds = (org.apache.ibatis.session.RowBounds) args[2];
        org.apache.ibatis.session.ResultHandler<?> resultHandler = (org.apache.ibatis.session.ResultHandler<?>) args[3];
        // Recompute the cache key from the rewritten SQL so it differs per enterprise.
        org.apache.ibatis.cache.CacheKey cacheKey =
                executor.createCacheKey(mappedStatement, parameter, rowBounds, boundSql);
        return executor.query(mappedStatement, parameter, rowBounds, resultHandler, cacheKey, boundSql);
    }

    /** Returns true only when the enterprise config exists and {@code enable} is {@code TRUE}. */
    private boolean isIsolationEnabled() {
        return enterpriseConfig != null && Boolean.TRUE.equals(enterpriseConfig.getEnable());
    }

    /** Returns true if {@code tableName} is listed (case-insensitively) in the configured table names. */
    private boolean isIsolatedTable(String tableName) {
        String[] configured = enterpriseConfig.getTableName();
        if (configured == null) {
            return false;
        }
        for (String candidate : configured) {
            if (tableName.equalsIgnoreCase(candidate)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the table token following the FROM at {@code fromIndex}, with quotes/backticks and any
     * schema prefix removed, or null when FROM is followed by a subquery or nothing.
     */
    private static String tableNameAfter(String sql, int fromIndex) {
        int i = fromIndex + "FROM".length();
        while (i < sql.length() && Character.isWhitespace(sql.charAt(i))) {
            i++;
        }
        int start = i;
        while (i < sql.length() && !Character.isWhitespace(sql.charAt(i)) && ",();".indexOf(sql.charAt(i)) < 0) {
            i++;
        }
        String qualified = sql.substring(start, i);
        String name = qualified.substring(qualified.lastIndexOf('.') + 1).replace("`", "").replace("\"", "");
        return name.isEmpty() ? null : name;
    }

    /**
     * Adds {@code condition} to the top-level WHERE clause that follows {@code fromIndex}, or inserts
     * a WHERE clause before the first top-level GROUP BY / ORDER BY / LIMIT / UNION ... token.
     */
    private static String addEnterpriseCondition(String sql, int fromIndex, String condition) {
        int tailStart = indexOfTopLevel(sql, fromIndex, sql.length(), CLAUSE_TERMINATORS);
        if (tailStart < 0) {
            tailStart = sql.length();
        }
        int whereIndex = indexOfTopLevel(sql, fromIndex, tailStart, WHERE_KEYWORD);
        String head;
        if (whereIndex >= 0) {
            int bodyStart = whereIndex + "WHERE".length();
            // Parenthesise the existing predicate so "a OR b" cannot escape the enterprise filter.
            head = sql.substring(0, bodyStart) + " (" + sql.substring(bodyStart, tailStart).trim()
                    + ") AND " + condition;
        } else {
            head = sql.substring(0, tailStart).trim() + " WHERE " + condition;
        }
        String tail = sql.substring(tailStart).trim();
        return tail.isEmpty() ? head : head + " " + tail;
    }

    /**
     * Returns the index of the first occurrence in {@code [start, end)} of any of {@code keywords}
     * that is outside parentheses and quoted text (whole words for alphabetic keywords), or -1.
     */
    private static int indexOfTopLevel(String sql, int start, int end, List<String> keywords) {
        int depth = 0;
        char quote = 0;
        for (int i = start; i < end; i++) {
            char c = sql.charAt(i);
            if (quote != 0) {
                if (c == quote) {
                    quote = 0;
                }
            } else if (c == '\'' || c == '"' || c == '`') {
                quote = c;
            } else if (c == '(') {
                depth++;
            } else if (c == ')') {
                depth--;
            } else if (depth == 0) {
                for (String keyword : keywords) {
                    if (matchesKeyword(sql, i, keyword)) {
                        return i;
                    }
                }
            }
        }
        return -1;
    }

    /** Case-insensitive keyword match at {@code pos}; alphabetic keywords must be whole words. */
    private static boolean matchesKeyword(String sql, int pos, String keyword) {
        if (!sql.regionMatches(true, pos, keyword, 0, keyword.length())) {
            return false;
        }
        if (!Character.isLetter(keyword.charAt(0))) {
            return true;
        }
        int end = pos + keyword.length();
        boolean startsWord = pos == 0 || !isIdentifierChar(sql.charAt(pos - 1));
        boolean endsWord = end == sql.length() || !isIdentifierChar(sql.charAt(end));
        return startsWord && endsWord;
    }

    /** Characters that can be part of an (optionally qualified) SQL identifier. */
    private static boolean isIdentifierChar(char c) {
        return Character.isLetterOrDigit(c) || c == '_' || c == '$' || c == '.';
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // No plugin properties are used; configuration comes from EnterpriseConfig.
    }

    /**
     * Resolves the enterprise id of the current user.
     *
     * @return the id of the first enterprise of the logged-in user, or null when there is no servlet
     *     request, the request opts out via the {@code enterprise_enable: false} header, nobody is
     *     logged in, or the user is not an enterprise user / has no enterprises
     */
    private String getEnterpriseIdFromRequest() {
        ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (requestAttributes == null) {
            return null;
        }
        HttpServletRequest request = requestAttributes.getRequest();
        // NOTE(review): this header is client-controlled, so any caller can switch isolation off and
        // read other enterprises' rows. Kept for compatibility; consider removing or restricting it.
        if ("false".equals(request.getHeader("enterprise_enable"))) {
            return null;
        }
        LoginUser loginUser = SecurityUtils.getLoginUser();
        if (loginUser == null) {
            return null;
        }
        // Boolean.TRUE.equals avoids an NPE when the flag is unset.
        if (Boolean.TRUE.equals(loginUser.getEnterpriseUserFlag()) && !CollectionUtils.isEmpty(loginUser.getEnterprises())) {
            return loginUser.getEnterprises().get(0).getEnterpriseId();
        }
        return null;
    }
}

package com.seerbigdata.common.config;

import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.core.config.GlobalConfig;
import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor;
import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean;
import com.seerbigdata.common.handler.CustomTenantHandler;
import com.seerbigdata.common.interceptor.EnterpriseIdInterceptor;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.mapping.DatabaseIdProvider;
import org.apache.ibatis.mapping.VendorDatabaseIdProvider;
import org.apache.ibatis.session.SqlSessionFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;

import javax.sql.DataSource;
import java.util.Properties;

/**
 * MyBatis-Plus wiring: builds the primary {@link SqlSessionFactory} with the global config, the
 * database-id provider, the MyBatis-Plus inner interceptors (tenant line + pagination) and the
 * enterprise-isolation plugin.
 *
 * @author hzk
 * @create 2022-03-01 16:25
 */
@Configuration
@Slf4j
public class MybatisPlusConfig {

    @Value("${mybatis-plus.type-aliases-package}")
    private String typeAliasesPackage;

    @Value("${mybatis-plus.mapper-locations}")
    private String mapperLocations;

    @Autowired
    private TenantConfig tenantConfig;

    @Autowired
    private CustomTenantHandler customTenantHandler;

    /** Enterprise data-isolation settings (enterprise.* properties). */
    @Autowired
    private EnterpriseConfig enterpriseConfig;

    /**
     * Builds the primary SqlSessionFactory.
     *
     * @param dataSource the data source bean named {@code dataSource}
     * @return the configured SqlSessionFactory
     * @throws Exception if mapper resources cannot be resolved or the factory cannot be built
     */
    @Bean("sqlSessionFactory")
    @Primary
    public SqlSessionFactory sqlSessionFactory(@Autowired @Qualifier("dataSource") DataSource dataSource) throws Exception {
        MybatisSqlSessionFactoryBean sqlSessionFactoryBean = new MybatisSqlSessionFactoryBean();
        sqlSessionFactoryBean.setDataSource(dataSource);
        sqlSessionFactoryBean.setTypeAliasesPackage(typeAliasesPackage);
        sqlSessionFactoryBean.setMapperLocations(new PathMatchingResourcePatternResolver().getResources(mapperLocations));
        // This bean replaces the auto-configured factory, so the DatabaseIdProvider bean below is not
        // applied automatically; without it, mapper statements that declare a databaseId are skipped.
        sqlSessionFactoryBean.setDatabaseIdProvider(getDatabaseIdProvider());
        sqlSessionFactoryBean.setGlobalConfig(globalConfig());
        // Plugins wrap the executor in this order, so the last one is outermost and runs first:
        // the enterprise condition is added before pagination derives its COUNT/LIMIT SQL.
        sqlSessionFactoryBean.setPlugins(mybatisPlusInterceptor(), enterpriseIdInterceptor());
        return sqlSessionFactoryBean.getObject();
    }

    /**
     * Builds the MyBatis-Plus interceptor: tenant-line filtering when tenant mode is enabled, plus
     * pagination for KingbaseES.
     *
     * @return the configured interceptor
     */
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();

        if (Boolean.TRUE.equals(tenantConfig.getEnable())) {
            log.info("开启租户模式");
            interceptor.addInnerInterceptor(new TenantLineInnerInterceptor(customTenantHandler));
        }

        interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.KINGBASE_ES));
        return interceptor;
    }

    /**
     * Creates the enterprise-isolation plugin registered alongside the MyBatis-Plus interceptor.
     *
     * @return a new interceptor bound to the injected EnterpriseConfig
     */
    public EnterpriseIdInterceptor enterpriseIdInterceptor() {
        return new EnterpriseIdInterceptor(enterpriseConfig);
    }

    /**
     * Global MyBatis-Plus configuration.
     *
     * @return global config with the DB config and meta-object handler set
     */
    public GlobalConfig globalConfig() {
        GlobalConfig globalConfig = new GlobalConfig();
        globalConfig.setDbConfig(dbConfig());
        globalConfig.setMetaObjectHandler(seerMetaObjectHandler());
        return globalConfig;
    }

    /**
     * Global DB configuration: logical delete on column {@code deleted} (1 = deleted, 0 = not deleted).
     *
     * @return the DB config
     */
    public GlobalConfig.DbConfig dbConfig() {
        GlobalConfig.DbConfig dbConfig = new GlobalConfig.DbConfig();
        dbConfig.setLogicDeleteField("deleted");
        dbConfig.setLogicDeleteValue("1");
        dbConfig.setLogicNotDeleteValue("0");
        return dbConfig;
    }

    /**
     * MyBatis-Plus automatic field filler.
     *
     * @return the meta-object handler
     */
    public MetaObjectHandler seerMetaObjectHandler() {
        log.info("SeerMetaObjectHandler loaded");
        return new SeerMetaObjectHandler();
    }

    /**
     * Maps JDBC database product names to the databaseId values used in mapper XML; statements
     * without a databaseId apply to every database.
     *
     * @return the vendor database-id provider
     */
    @Bean
    public DatabaseIdProvider getDatabaseIdProvider() {
        DatabaseIdProvider databaseIdProvider = new VendorDatabaseIdProvider();
        Properties properties = new Properties();
        properties.setProperty("Oracle", "oracle");
        properties.setProperty("MySQL", "mysql");
        properties.setProperty("DB2", "db2");
        properties.setProperty("Derby", "derby");
        properties.setProperty("H2", "h2");
        properties.setProperty("HSQL", "hsql");
        properties.setProperty("Informix", "informix");
        properties.setProperty("MS-SQL", "ms-sql");
        properties.setProperty("PostgreSQL", "postgresql");
        properties.setProperty("Sybase", "sybase");
        properties.setProperty("Hana", "hana");
        databaseIdProvider.setProperties(properties);
        return databaseIdProvider;
    }
}

package com.seerbigdata.common.config;

import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;

/**
 * Enterprise data-isolation settings, bound from the properties under the {@code enterprise} prefix.
 *
 * @author cwj
 */
@Configuration
@ConfigurationProperties(prefix = "enterprise")
@Data
public class EnterpriseConfig {

    /** Whether enterprise mode (data isolation) is switched on ({@code enterprise.enable}). */
    private Boolean enable;

    /** Names of the tables whose data must be isolated per enterprise ({@code enterprise.tableName}). */
    private String[] tableName;
}

yml文件(注意缩进,enable 与 tableName 必须位于 enterprise 之下):
enterprise:
  ####### 是否开启企业模式
  enable: true
  ######## 需要做数据隔离的表名
  tableName:
    - sp_special_work_safety_disclosure
    - sp_special_work_person
    - t_event
    - device_camera

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值