数据权限-SpringJpa拦截示例

需求:已知系统有很多表然后在每个表都有个公共字段,比如叫租户。需要从租户来隔离数据权限

比较直观的方法可能是重写Dao接口的查询,但我们大多场景可能用的JPA或mybatis,写的纯SQL语句来查询业务数据

这时就需要想到使用过滤器,拦截所有查询SQL加入数据权限的过滤,这时要考虑的SQL写法就多了

常见SQL写法

select * from table1 t1 where t1.field = xxx and t1.field2=bbb

select * from table12 tsg  left join table3 tw on tsg.field = tw.id where 1=1

select * from table1 xx1 where field = (select fie from table2 where a=1)

select * FROM table AS xxxx where 1=1

/**还有带函数的写法*/
select id,myfunction(field2) FROM table AS xxxx where 1=1

我们的目标应该是在where条件里加入公共字段过滤达到数据权限控制

此代码示例是在SpringJPA环境下

先配置JPA
spring:
    jpa:
        properties:
          hibernate:
            session_factory:
              statement_inspector: com.xxxJpaInterceptor

JPA拦截代码,可供参考

package com.xxxx;

import com.xxxx.UserVO;
import lombok.extern.slf4j.Slf4j;
import org.hibernate.resource.jdbc.spi.StatementInspector;

import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * 数据权限拦截
 *
 */
@Slf4j
public class JpaInterceptor implements StatementInspector {

    @Override
    public String inspect(String sql) {
        UserVO userVO = ThreadLocalUtils.get();
        if(null == userVO){
            log.info("无法获取到登录用户,数据权限不拦截{}",sql);
            return sql;
        }
        //管理员
        if(userVO.isAdmin()){
            log.info("无法获取到登录用户,数据权限不拦截{}",sql);
            return sql;
        }
        /**
         * 进入拦截器select tbluserrol0_.id as id1_5_, tbluserrol0_.role_id as role_id2_5_,
         * tbluserrol0_.user_id as user_id3_5_
         * from tbl_user_role tbluserrol0_ where tbluserrol0_.user_id=28
         * 进入拦截器SELECT rf.function_id,f.function_name,f.level,f.functionurl,f.parent_function_id,f.id,f.ranks
         * FROM tbl_role_function
         * rf LEFT JOIN tbl_function f ON f.id=rf.function_id WHERE rf.role_id=? AND f.`level` IN(1,2) ORDER BY f.ranks ASC
         *
         * select * from table xx wehre a = (select xx from table where aa = aa)
         */
        //如果SQL有租户条件则不入侵
        if(sql.indexOf("tenant_id") == -1){
            Pattern compile = Pattern.compile("\\swhere\\s",Pattern.CASE_INSENSITIVE);
            Matcher matcher = compile.matcher(sql);
            int lastStart = 0;
            String tenantId = userVO.getTenantId();
            StringBuilder builder = new StringBuilder(sql);
            while (matcher.find(lastStart)){
                String sqlTemp = builder.substring(lastStart,matcher.end());
                String tableAs = getTableAs(sqlTemp);
                String text = String.format("%stenant_id = '%s' and ", tableAs, tenantId);
                builder.insert(matcher.end(),text);
                lastStart = matcher.end() + text.length();
                matcher = compile.matcher(builder.toString());
            }
            //没有where
            if(lastStart == 0){
                String tableAs = getTableAs(sql);
                builder.append(String.format("%stenant_id = '%s'",tableAs,tenantId));
            }
            sql = builder.toString();
        }
        return sql;
    }

    public static void main(String[] args) {
        String sql = "select tbluserrol0_.id as id1_5_, tbluserrol0_.role_id as role_id2_5_, tbluserrol0_.user_id as user_id3_5_ from tbl_user_role tbluserrol0_ where tbluserrol0_.user_id=28";
        sql = "SELECT rf.function_id,f.function_name,f.level,f.functionurl,f.parent_function_id,f.id,f.ranks FROM tbl_role_function  rf LEFT JOIN tbl_function f ON f.id=rf.function_id WHERE rf.role_id=? AND f.`level` IN(1,2) ORDER BY f.ranks ASC";
        sql = "select channelent0_.id as id1_2_, channelent0_.access_token as access_t2_2_, channelent0_.app_key as app_key3_2_, channelent0_.app_secret as app_secr4_2_, channelent0_.channel_name as channel_5_2_, channelent0_.channel_type as channel_6_2_, channelent0_.created_time as created_7_2_, channelent0_.customer_number as customer8_2_, channelent0_.expire_time as expire_t9_2_, channelent0_.jd_customer_number as jd_cust10_2_, channelent0_.modify_time as modify_11_2_, channelent0_.msg_secret as msg_sec12_2_, channelent0_.open_id as open_id13_2_, channelent0_.org_number as org_num14_2_, channelent0_.person_order_refund as person_15_2_, channelent0_.refresh_token as refresh16_2_, channelent0_.server_url as server_17_2_, channelent0_.shop_id as shop_id18_2_, channelent0_.shop_name as shop_na19_2_, channelent0_.sign_secret as sign_se20_2_ from tbl_channel channelent0_ where channelent0_.shop_id=? or channelent0_.shop_id=? or channelent0_.shop_id=? or channelent0_.shop_id=? or channelent0_.shop_id=? or channelent0_.shop_id=? or channelent0_.shop_id=? or channelent0_.shop_id=? or channelent0_.shop_id=? or channelent0_.shop_id=? or channelent0_.shop_id=? or channelent0_.shop_id=? or channelent0_.shop_id=? or channelent0_.shop_id=? or channelent0_.shop_id=? or channelent0_.shop_id=? or channelent0_.shop_id=?";
//        sql = "select supplyinfo0_.id as id1_42_, supplyinfo0_.access_token as access_t2_42_, supplyinfo0_.app_key as app_key3_42_, supplyinfo0_.app_secret as app_secr4_42_, supplyinfo0_.channel_id as channel_5_42_, supplyinfo0_.customer_id as customer6_42_, supplyinfo0_.expire_time as expire_t7_42_, supplyinfo0_.op_name as op_name8_42_, supplyinfo0_.pin as pin9_42_, supplyinfo0_.prefix as prefix10_42_, supplyinfo0_.refresh_token as refresh11_42_, supplyinfo0_.server_url as server_12_42_, supplyinfo0_.supply_id as supply_13_42_, supplyinfo0_.supply_name as supply_14_42_ from tbl_supply supplyinfo0_ where supplyinfo0_.supply_id=123 or supplyinfo0_.supply_id=2121 or supplyinfo0_.supply_id=22222 or supplyinfo0_.supply_id=123";
//        sql = "select * from table xxxx where 1=1";
//        sql = "select * FROM            table      AS xxxx        where 1=1";
//        sql = "select *     FROM         table    AS    xxxx            ";
        sql = "select * FROM table";
//        sql = "select * from tbl_sync_goods tsg  left join tbl_warehouse tw on tsg.warehouse_id = tw.id where 1=1 ";
        sql = "select * from table1 xx1,table2 as xx2 left join tbl_warehouse tw on tsg.warehouse_id = tw.id where 1=1 ";
//        sql = "select * from table1 xx1 where field = (select fie from table2 where a=1)";
        sql = "select * from table1 xx1 where field = (select fie from table2 where a=1) and  field2 = (select fie from table3 as abc where abc.a=1)";
//        String asName = getTableAs(sql);
//        System.out.println("别名:"+asName);

        Pattern compile = Pattern.compile("\\swhere\\s",Pattern.CASE_INSENSITIVE);
        Matcher matcher = compile.matcher(sql);
        int lastStart = 0;
        String tenantId = "000001";
        StringBuilder builder = new StringBuilder(sql);
        while (matcher.find(lastStart)){
            String sqlTemp = builder.substring(lastStart,matcher.end());
            String tableAs = getTableAs(sqlTemp);
            String text = String.format("%stenant_id = %s and ", tableAs, tenantId);
            builder.insert(matcher.end(),text);
            lastStart = matcher.end() + text.length();
            matcher = compile.matcher(builder.toString());
        }
        //没有where
        if(lastStart == 0){
            String tableAs = getTableAs(sql);
            builder.append(String.format("%stenant_id = %s",tableAs,tenantId));
        }
        sql = builder.toString();
        System.out.println("新SQL:"+sql);
    }

    /**
     * 取form后面的别名
     */
    private static String getTableAs(String sql) {
        Pattern compile = Pattern.compile("\\sfrom\\s",Pattern.CASE_INSENSITIVE);
        Matcher matcher = compile.matcher(sql);
        if(!matcher.find()){
            log.info("没检测到SQL中from关键字,数据权限不拦截{}",sql);
            throw new IllegalArgumentException("无法检测到from!");
        }
        String tableStart = sql.substring(matcher.end()).trim();
        Pattern spaceCompile = Pattern.compile("\\s");
        matcher = spaceCompile.matcher(tableStart);
        if(matcher.find()) {
            tableStart = tableStart.substring(matcher.end()).trim();
        }else{
            return " ";//无别名
        }
        //判断是否有AS
        Pattern asCompile = Pattern.compile("^as",Pattern.CASE_INSENSITIVE);
        matcher = asCompile.matcher(tableStart);
        if(matcher.find()) {
            tableStart = tableStart.substring(matcher.end()).trim();
        }
        //table as1
        Pattern compile2 = Pattern.compile("\\s*(,|where|left|rigth|inner)\\s*",Pattern.CASE_INSENSITIVE);
        matcher = compile2.matcher(tableStart);
        String asName = "";
        if(matcher.find()) {
            asName = tableStart.substring(0, matcher.start()).trim();
        }else{
            //取没有条件的别名
            asName = tableStart.trim();
        }
        return asName.length() > 0 ? asName.concat(".") : asName;
    }
}

  • 3
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
MyBatis-Plus是一个基于MyBatis的增强工具,提供了许多便捷的功能,其中包括数据拦截器(Data Interceptor)。 数据拦截器是MyBatis-Plus提供的一个特性,用于在SQL语句执行前后对数据进行拦截和处理。通过数据拦截器,可以在执行SQL之前对参数进行修改,或者在执行SQL之后对结果进行处理。 使用数据拦截器可以实现一些常见的需求,比如对敏感字段进行加密解密、对某些特定条件进行数据过滤等。 要使用数据拦截器,首先需要创建一个实现了`com.baomidou.mybatisplus.core.plugins.Interceptor`接口的拦截器类。然后,在MyBatis的配置文件中配置该拦截器: ```xml <configuration> <plugins> <plugin interceptor="com.example.MyInterceptor"/> </plugins> </configuration> ``` 其中`com.example.MyInterceptor`是你自定义的拦截器类的全限定名。 在自定义的拦截器类中,你可以通过重写`intercept`方法来实现对SQL执行前后的处理逻辑。`intercept`方法接收一个`Invocation`对象作为参数,通过该对象可以获取到SQL语句、参数等相关信息。 ```java public class MyInterceptor implements Interceptor { @Override public Object intercept(Invocation invocation) throws Throwable { // 在SQL执行前的处理逻辑 // ... // 执行SQL Object result = invocation.proceed(); // 在SQL执行后的处理逻辑 // ... return result; } } ``` 需要注意的是,如果你使用的是Spring Boot,可以通过`@Bean`注解将拦截器类注入到Spring容器中。如果是非Spring Boot项目,则需要在MyBatis的配置文件中显式配置拦截器。 以上就是使用MyBatis-Plus数据拦截器的基本介绍,希望对你有所帮助。如有更多问题,请继续提问。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值