MyBatis 自定义拦截器

一、基本使用

1.1、定义枚举值

/**
 * Kinds of data permission that can be requested through {@code @DataPermission}.
 *
 * <p>The constants carry no fields, so no Lombok-generated constructor or getters are needed.
 */
public enum DataPermissionType {

    /** Permission scoped by company. */
    COMPANY,

    /** Permission scoped by supplier (vendor). */
    SUPPLIER

}

1.2、添加注解

/**
 * Declares that the annotated method is subject to data-permission handling.
 *
 * <p>Retained at runtime so that {@code MyBatisInterceptor} can read it reflectively.
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface DataPermission {

    /**
     * Kind of data permission to apply.
     *
     * @return the permission type; defaults to {@link DataPermissionType#SUPPLIER}
     */
    DataPermissionType type() default DataPermissionType.SUPPLIER;

    /**
     * Name of the column that carries the permission value.
     *
     * @return the column name; defaults to {@code vendor_code}
     */
    String permissionCol() default "vendor_code";

    /**
     * Whether column-level permissions should be applied to the query result.
     * {@code MyBatisInterceptor} checks this flag before blanking restricted fields.
     *
     * @return {@code true} to enable column-level handling; defaults to {@code false}
     */
    boolean columnPermission() default false;

}

1.3、自定义拦截器(一)


/**
 * MyBatis plugin that applies data permissions to queries whose mapper method is annotated
 * with {@link DataPermission}.
 *
 * <p>For an annotated query the original SQL is wrapped in a sub-select aliased
 * {@code a_table} and a permission filter is appended; after execution, fields listed by the
 * handler's column permissions are set to {@code null} on each result row.
 *
 * <p>The interceptor keeps no per-request state in fields: the permission annotation and the
 * permission values are passed as method arguments, so concurrent or nested queries (the
 * handler itself may run queries that pass through this interceptor) cannot overwrite each
 * other's data.
 */
@Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})
})
public class MyBatisInterceptor implements Interceptor {

    /**
     * Maximum number of literals placed in a single {@code IN (...)} list.
     * Presumably chosen to stay below database limits on IN-list size — TODO confirm.
     */
    private static final int IN_CHUNK_SIZE = 900;

    /** Select list used when the columns of the original query cannot be resolved by name. */
    private static final String ALL_COLUMNS = "a_table.*";

    private final DataPermissionHandler dataPermissionHandler;

    public MyBatisInterceptor(DataPermissionHandler dataPermissionHandler) {
        this.dataPermissionHandler = dataPermissionHandler;
    }

    /**
     * Entry point called by MyBatis for every intercepted {@code Executor.query} call.
     *
     * @param invocation the intercepted call
     * @return the (possibly post-processed) query result
     * @throws Throwable anything thrown by the target or by SQL rewriting
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object target = invocation.getTarget();
        Method method = invocation.getMethod();
        Object[] args = invocation.getArgs();
        // Only the four-argument Executor.query overload is handled.
        if (target instanceof Executor && "query".equals(method.getName()) && args.length == 4) {
            MappedStatement ms = (MappedStatement) args[0];
            DataPermission permission = getPermission(ms.getId());
            if (permission == null) {
                return invocation.proceed();
            }
            return doQuery(invocation, permission);
        }
        return invocation.proceed();
    }

    /**
     * Builds the permission-filtered SQL around the original statement.
     *
     * @param originalSql     SQL of the intercepted statement
     * @param permission      annotation found on the mapper method
     * @param permissionsData permission values for the current user
     * @return the wrapped SQL, or {@code originalSql} unchanged when there are no values or the
     *         permission type contributes no filter
     * @throws JSQLParserException if the original SQL cannot be parsed
     */
    private String getDataPermissionSql(String originalSql, DataPermission permission,
                                        List<String> permissionsData) throws JSQLParserException {
        if (permissionsData == null || permissionsData.isEmpty()) {
            return originalSql;
        }
        String filter;
        switch (permission.type()) {
            case SUPPLIER:
                filter = getSupplierSql(permissionsData, permission.permissionCol());
                break;
            case COMPANY:
            default:
                // No filter is defined for these types; leave the statement untouched.
                filter = "";
                break;
        }
        if (filter.isEmpty()) {
            return originalSql;
        }
        StringBuilder sb = new StringBuilder(" select ");
        sb.append(getSelectItems(originalSql));
        sb.append("  from ( ");
        sb.append(originalSql);
        sb.append(" ) a_table ");
        sb.append(filter);
        return sb.toString();
    }

    /**
     * Resolves the outer select list from the original query: the alias of each item when
     * present, otherwise its plain column name.
     *
     * <p>Falls back to {@code a_table.*} when the statement is not a plain select or when any
     * item cannot be named (for example {@code *} or an unaliased function call).
     *
     * @param originalSql SQL of the intercepted statement
     * @return comma-separated column names, or {@code a_table.*}
     * @throws JSQLParserException if the original SQL cannot be parsed
     */
    private String getSelectItems(String originalSql) throws JSQLParserException {
        Statement statement = CCJSqlParserUtil.parse(originalSql);
        if (!(statement instanceof Select)) {
            return ALL_COLUMNS;
        }
        SelectBody selectBody = ((Select) statement).getSelectBody();
        if (!(selectBody instanceof PlainSelect)) {
            return ALL_COLUMNS;
        }
        List<String> columns = new ArrayList<>();
        for (SelectItem item : ((PlainSelect) selectBody).getSelectItems()) {
            if (!(item instanceof SelectExpressionItem)) {
                return ALL_COLUMNS;
            }
            SelectExpressionItem expressionItem = (SelectExpressionItem) item;
            if (expressionItem.getAlias() != null) {
                columns.add(expressionItem.getAlias().getName());
            } else if (expressionItem.getExpression() instanceof Column) {
                columns.add(((Column) expressionItem.getExpression()).getColumnName());
            } else {
                return ALL_COLUMNS;
            }
        }
        return columns.isEmpty() ? ALL_COLUMNS : String.join(",", columns);
    }

    /**
     * Builds the supplier filter: an {@code EXISTS} sub-query on {@code VENDOR_MASTER_DATA}
     * restricted to the given vendor codes and correlated with the wrapped query.
     *
     * @param list          permitted vendor codes
     * @param permissionCol column of the wrapped query ({@code a_table}) holding the vendor code
     * @return the {@code WHERE EXISTS (...)} clause
     */
    private String getSupplierSql(List<String> list, String permissionCol) {
        StringBuilder sb = new StringBuilder(" WHERE EXISTS ( SELECT vendor_code FROM VENDOR_MASTER_DATA b_table WHERE ( ");
        sb.append(getWhereInSql("vendor_code", list));
        sb.append(" ) AND b_table.vendor_code = a_table.");
        sb.append(permissionCol);
        sb.append(" ) ");
        return sb.toString();
    }

    /**
     * Builds {@code id in ('a','b',...)}, split into several {@code in} lists joined by
     * {@code or} when there are more than {@link #IN_CHUNK_SIZE} values.
     *
     * <p>Single quotes inside values are doubled so a value cannot terminate its literal, and
     * {@code null} values are skipped. When no usable value remains the always-false predicate
     * {@code 1 = 0} is returned so the surrounding SQL stays valid.
     *
     * @param id   column name
     * @param list literal values
     * @return the predicate text
     */
    private String getWhereInSql(String id, List<String> list) {
        StringBuilder sb = new StringBuilder();
        int count = 0;
        for (String value : list) {
            if (value == null) {
                continue;
            }
            if (count == 0) {
                sb.append(id).append(" in (");
            } else if (count % IN_CHUNK_SIZE == 0) {
                sb.append(") or ").append(id).append(" in (");
            } else {
                sb.append(",");
            }
            sb.append("'").append(value.replace("'", "''")).append("'");
            count++;
        }
        if (count == 0) {
            return "1 = 0";
        }
        return sb.append(")").toString();
    }

    /**
     * Runs an annotated query: rewrites the statement when permission values exist, executes
     * it, then applies column permissions to the result.
     *
     * @param invocation the intercepted call
     * @param permission annotation found on the mapper method
     * @return the query result
     * @throws JSQLParserException       if the original SQL cannot be parsed
     * @throws InvocationTargetException if the target query throws
     * @throws IllegalAccessException    if the target method cannot be invoked
     */
    private Object doQuery(Invocation invocation, DataPermission permission) throws JSQLParserException,
            InvocationTargetException, IllegalAccessException {
        List<String> permissionsData = dataPermissionHandler.getDataPermissionValue(permission.type());
        // NOTE(review): an empty permission list leaves the query unfiltered, i.e. a user with
        // no configured permissions sees every row — confirm this is intended.
        if (permissionsData != null && !permissionsData.isEmpty()) {
            newBoundSql(invocation, permission, permissionsData);
        }
        Object result = invocation.proceed();
        setColumnPermissions(result, permission);
        return result;
    }

    /**
     * Sets restricted fields to {@code null} on every row of a list result.
     *
     * <p>NOTE(review): requires {@code DataPermission#columnPermission()} and
     * {@code DataPermissionHandler#getColumnPermission()} to be declared — verify.
     *
     * @param result         value returned by the query
     * @param dataPermission annotation found on the mapper method
     */
    private void setColumnPermissions(Object result, DataPermission dataPermission) {
        if (!dataPermission.columnPermission()) {
            return;
        }
        List<String> restrictedColumns = dataPermissionHandler.getColumnPermission();
        if (restrictedColumns == null || restrictedColumns.isEmpty()) {
            return;
        }
        if (!(result instanceof List)) {
            return;
        }
        for (Object data : (List<?>) result) {
            if (data == null) {
                continue;
            }
            for (String columnName : restrictedColumns) {
                // Skip names the row type does not have (as reported by ReflectionUtils).
                if (!ReflectionUtils.isExistField(data.getClass(), columnName)) {
                    continue;
                }
                ReflectionUtils.setFieldValue(data, columnName, null);
            }
        }
    }

    /**
     * Replaces the {@link MappedStatement} in the invocation arguments with a copy whose SQL
     * carries the permission filter. Does nothing when the SQL is not changed.
     *
     * @param invocation      the intercepted call; {@code args[0]} is replaced in place
     * @param permission      annotation found on the mapper method
     * @param permissionsData permission values for the current user
     * @throws JSQLParserException if the original SQL cannot be parsed
     */
    private void newBoundSql(Invocation invocation, DataPermission permission,
                             List<String> permissionsData) throws JSQLParserException {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object paramObj = args[1];
        BoundSql boundSql = ms.getBoundSql(paramObj);
        String originalSql = boundSql.getSql();
        String dataPermissionSql = getDataPermissionSql(originalSql, permission, permissionsData);
        if (dataPermissionSql.equals(originalSql)) {
            return;
        }
        BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), dataPermissionSql, boundSql.getParameterMappings(),
                boundSql.getParameterObject());
        // Carry over additional parameters of the original BoundSql.
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)) {
                newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }
        // Put the rebuilt statement back into the argument array used by proceed().
        args[0] = newMappedStatement(ms, new BoundSqlSqlSource(newBoundSql));
    }

    /**
     * {@link SqlSource} that always returns one pre-built {@link BoundSql}.
     */
    static class BoundSqlSqlSource implements SqlSource {
        private final BoundSql boundSql;

        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }

    /**
     * Copies a {@link MappedStatement}, substituting a new {@link SqlSource}.
     *
     * @param ms           the original statement
     * @param newSqlSource the SQL source to use in the copy
     * @return the new statement
     */
    private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(),
                newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
            builder.keyProperty(ms.getKeyProperties()[0]);
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    /**
     * Looks up the {@link DataPermission} annotation for a statement id of the form
     * {@code <class name>.<method name>}.
     *
     * @param fullMethodName the statement id
     * @return the annotation of the first public method with that name that carries one, or
     *         {@code null} when the id has no dot, the class cannot be loaded, or no such
     *         method is annotated
     */
    private DataPermission getPermission(String fullMethodName) {
        int methodIndex = fullMethodName.lastIndexOf(".");
        if (methodIndex < 0) {
            return null;
        }
        String className = fullMethodName.substring(0, methodIndex);
        String methodName = fullMethodName.substring(methodIndex + 1);
        Class<?> targetCls;
        try {
            targetCls = Class.forName(className);
        } catch (ClassNotFoundException ignored) {
            // The namespace is not a loadable class, so there is no annotation to find;
            // the query then runs unmodified instead of failing.
            return null;
        }
        for (Method method : targetCls.getMethods()) {
            if (!methodName.equals(method.getName())) {
                continue;
            }
            DataPermission annotation = method.getAnnotation(DataPermission.class);
            if (annotation != null) {
                return annotation;
            }
        }
        return null;
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // No configurable properties.
    }
}

1.5、数据处理器接口及实现


/**
 * Supplies the permission data consumed by {@code MyBatisInterceptor}.
 */
public interface DataPermissionHandler {
    /**
     * Returns the permission values for the given permission type.
     *
     * @param dataPermissionType the kind of permission requested
     * @return the permission values
     */
    List<String> getDataPermissionValue(DataPermissionType dataPermissionType);

    /**
     * Returns the names of the result fields that the interceptor should set to {@code null}.
     *
     * <p>The default implementation returns an empty list, so implementations that do not
     * override it restrict no fields.
     *
     * @return the restricted field names; never {@code null} in the default implementation
     */
    default List<String> getColumnPermission() {
        return java.util.Collections.emptyList();
    }

}


/**
 * {@link DataPermissionHandler} implementation that resolves permission values for the
 * currently logged-in user.
 */
@Slf4j
@Service
@AllArgsConstructor
public class DataPermissionServiceImpl implements DataPermissionHandler {

    private final SysPermissionConfigService sysPermissionConfigService;

    private final VendorMasterCompanyInfoService masterCompanyInfoService;

    /**
     * Returns the permission values of the current user for the given type.
     *
     * @param dataPermissionType the kind of permission requested
     * @return the permission values; empty when the {@code defaultCurrentService} bean is
     *         unavailable or nothing is configured
     */
    @Override
    public List<String> getDataPermissionValue(DataPermissionType dataPermissionType) {
        // Resolve the login context through the bean lookup helper.
        CurrentService defaultCurrentService = (CurrentService) BeanTools.getBean("defaultCurrentService");
        if (Objects.isNull(defaultCurrentService)) {
            return Collections.emptyList();
        }
        User user = defaultCurrentService.getUser();
        switch (dataPermissionType) {
            case COMPANY:
                return getCompanyPermission();
            case SUPPLIER:
                return getSupplierPermission(user);
            default:
                return Collections.emptyList();
        }
    }

    /**
     * Returns the company data permissions.
     *
     * <p>NOTE(review): returns the fixed value {@code "test"} — looks like a placeholder;
     * confirm before relying on it.
     *
     * @return the company permission values
     */
    private List<String> getCompanyPermission() {
        return CollectionUtil.toList("test");
    }

    /**
     * Returns the supplier data permissions of a user: the vendor codes of all vendors that
     * belong to the companies configured for the user's account.
     *
     * @param user the logged-in user; must not be {@code null}
     * @return distinct, non-null vendor codes; empty when nothing is configured
     * @throws NullPointerException if {@code user} is {@code null}
     */
    private List<String> getSupplierPermission(User user) {
        // Fail with a clear message rather than returning an empty list for a missing user.
        Objects.requireNonNull(user, "current user is required to resolve supplier permissions");
        // Companies configured for the account.
        Query query = QueryFactory.createQuery(false);
        query.addOpEqualFilter("account", user.getUsername());
        List<SysPermissionConfigDTO> companyList = sysPermissionConfigService.queryByParams(query);
        if (CollectionUtils.isEmpty(companyList)) {
            return Collections.emptyList();
        }
        List<String> companyCodes = companyList.stream()
                .map(SysPermissionConfigDTO::getCompanyCode)
                .filter(Objects::nonNull)
                .distinct()
                .collect(Collectors.toList());
        if (companyCodes.isEmpty()) {
            return Collections.emptyList();
        }
        // Vendors belonging to those companies.
        List<VendorMasterCompanyInfoDTO> vendorCompanyList = masterCompanyInfoService.queryVendorByCompanyCodes(companyCodes);
        if (CollectionUtils.isEmpty(vendorCompanyList)) {
            return Collections.emptyList();
        }
        // Drop nulls and duplicates: the codes end up as literals in an IN list.
        return vendorCompanyList.stream()
                .map(VendorMasterCompanyInfoDTO::getVendorCode)
                .filter(Objects::nonNull)
                .distinct()
                .collect(Collectors.toList());
    }
}

1.6、注册拦截器

/**
 * Spring configuration that registers the custom MyBatis plugin.
 */
@Configuration
@AllArgsConstructor
public class MyBatisConfig {

    private SqlSessionFactory sqlSessionFactory;

    /**
     * Adds a {@link MyBatisInterceptor} to the configuration of the injected
     * {@link SqlSessionFactory}.
     *
     * @param dataPermissionHandler handler passed to the interceptor
     * @return an empty string; the bean exists only for the registration side effect
     */
    @Bean
    public String addSqlInterceptor(DataPermissionHandler dataPermissionHandler) {
        MyBatisInterceptor dataPermissionInterceptor = new MyBatisInterceptor(dataPermissionHandler);
        sqlSessionFactory.getConfiguration().addInterceptor(dataPermissionInterceptor);
        return StringUtils.EMPTY;
    }

}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

笑谈子云亭

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值