一、基本使用
1.1、定义枚举值
/**
 * Kinds of data permission that can be requested through {@code @DataPermission}.
 *
 * <p>The enum has no fields, so it needs neither a generated constructor nor generated getters.
 */
public enum DataPermissionType {
    /** Company-scoped permission. */
    COMPANY,
    /** Supplier (vendor) scoped permission. */
    SUPPLIER
}
1.2、添加注解
/**
 * Marks a mapper method whose query is subject to data-permission handling by the
 * MyBatis interceptor. The annotation is retained at runtime so it can be read reflectively.
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface DataPermission {

    /** Kind of permission to apply; defaults to {@link DataPermissionType#SUPPLIER}. */
    DataPermissionType type() default DataPermissionType.SUPPLIER;

    /**
     * Name of the column the permission applies to; defaults to {@code "vendor_code"}.
     * NOTE(review): the interceptor code visible alongside this annotation hard-codes
     * {@code vendor_code} and never reads this element - confirm the intended use.
     */
    String permissionCol() default "vendor_code";

    /**
     * Whether column-level permission handling is applied to the query result.
     * The interceptor reads this flag before post-processing result rows; it defaults to
     * {@code false} so methods that do not set it keep row-level handling only.
     */
    boolean columnPermission() default false;
}
1.3、自定义拦截器(一)
/**
 * MyBatis plugin that applies {@link DataPermission} rules to {@code Executor.query} calls.
 *
 * <p>For a statement whose mapper method carries {@link DataPermission}, the interceptor
 * <ol>
 *   <li>asks the {@link DataPermissionHandler} for the permitted values and, when there are any,
 *       wraps the original SQL in an outer select restricted to those values, and</li>
 *   <li>after the query ran, sets to {@code null} every result field named by
 *       {@link DataPermissionHandler#getColumnPermission()} when
 *       {@link DataPermission#columnPermission()} is enabled.</li>
 * </ol>
 *
 * <p>All per-query state is passed through method arguments. Nothing query-specific is kept in
 * instance fields, because a single interceptor instance is registered on the MyBatis
 * configuration and is therefore used by every query.
 *
 * <p>NOTE(review): when the handler returns no permitted values the SQL is left untouched, i.e.
 * the query is not restricted at all - confirm that this is the intended policy.
 */
@Intercepts({
        @Signature(type = Executor.class, method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})
})
public class MyBatisInterceptor implements Interceptor {

    /** Select list used when the column names of the original query cannot be determined. */
    private static final String ALL_COLUMNS = "*";

    /**
     * Maximum number of literals per {@code IN (...)} list; longer lists are split into several
     * lists joined with {@code or}. Presumably chosen to stay below a database limit on the
     * size of an IN list - confirm against the target database.
     */
    private static final int IN_CLAUSE_CHUNK_SIZE = 900;

    /** Predicate emitted when there is no value to put into an IN list; matches no row. */
    private static final String MATCH_NOTHING = "1 = 0";

    private final DataPermissionHandler dataPermissionHandler;

    /**
     * @param dataPermissionHandler source of the permitted values and of the column list
     */
    public MyBatisInterceptor(DataPermissionHandler dataPermissionHandler) {
        this.dataPermissionHandler = dataPermissionHandler;
    }

    /**
     * Entry point called by MyBatis. Only four-argument {@code Executor.query} calls for
     * statements annotated with {@link DataPermission} are processed; everything else proceeds
     * unchanged.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        boolean isQuery = invocation.getTarget() instanceof Executor
                && "query".equals(invocation.getMethod().getName())
                && args.length == 4;
        if (!isQuery) {
            return invocation.proceed();
        }
        MappedStatement ms = (MappedStatement) args[0];
        DataPermission permission = getPermission(ms.getId());
        if (permission == null) {
            return invocation.proceed();
        }
        return doQuery(invocation, permission);
    }

    /**
     * Builds the permission-restricted SQL.
     *
     * @param originalSql      SQL of the intercepted statement
     * @param permission       annotation of the mapper method
     * @param permissionValues values the current caller is allowed to see
     * @return the wrapped SQL, or {@code originalSql} when there is nothing to restrict
     * @throws JSQLParserException if {@code originalSql} cannot be parsed
     */
    private String getDataPermissionSql(String originalSql, DataPermission permission,
                                        List<String> permissionValues) throws JSQLParserException {
        if (CollectionUtils.isEmpty(permissionValues)) {
            return originalSql;
        }
        String filter;
        switch (permission.type()) {
            case SUPPLIER:
                filter = getSupplierSql(permissionValues);
                break;
            case COMPANY:
            default:
                // No row filter exists for this type, so wrapping the query would only add a
                // pointless sub-select (and a parse step that may fail).
                return originalSql;
        }
        return new StringBuilder(" select ")
                .append(getSelectItems(originalSql))
                .append(" from ( ")
                .append(originalSql)
                .append(" ) a_table ")
                .append(filter)
                .toString();
    }

    /**
     * Derives the select list of the outer query from the columns of the original query.
     *
     * <p>Aliased items contribute their alias, plain columns their column name. Whenever a name
     * cannot be determined ({@code *}, {@code t.*}, an un-aliased expression, or a statement
     * that is not a plain select) the method falls back to {@code *}, which selects every column
     * of the wrapped query.
     *
     * @throws JSQLParserException if {@code originalSql} cannot be parsed
     */
    private String getSelectItems(String originalSql) throws JSQLParserException {
        Statement statement = CCJSqlParserUtil.parse(originalSql);
        if (!(statement instanceof Select)) {
            return ALL_COLUMNS;
        }
        SelectBody selectBody = ((Select) statement).getSelectBody();
        if (!(selectBody instanceof PlainSelect)) {
            return ALL_COLUMNS;
        }
        List<String> columns = new ArrayList<>();
        for (SelectItem item : ((PlainSelect) selectBody).getSelectItems()) {
            if (!(item instanceof SelectExpressionItem)) {
                // "*" or "t.*": there is no single column name to repeat.
                return ALL_COLUMNS;
            }
            SelectExpressionItem expressionItem = (SelectExpressionItem) item;
            if (expressionItem.getAlias() != null) {
                columns.add(expressionItem.getAlias().getName());
            } else if (expressionItem.getExpression() instanceof Column) {
                columns.add(((Column) expressionItem.getExpression()).getColumnName());
            } else {
                // Un-aliased function or expression: its result column has no stable name.
                return ALL_COLUMNS;
            }
        }
        return columns.isEmpty() ? ALL_COLUMNS : StringUtils.join(columns, ",");
    }

    /**
     * Builds the {@code WHERE EXISTS} clause that keeps only rows of {@code a_table} whose
     * {@code vendor_code} is one of the given values.
     */
    private String getSupplierSql(List<String> list) {
        return new StringBuilder(" WHERE EXISTS ( SELECT vendor_code FROM VENDOR_MASTER_DATA b_table WHERE ( ")
                .append(getWhereInSql("vendor_code", list))
                .append(" ) AND b_table.vendor_code = a_table.vendor_code ) ")
                .toString();
    }

    /**
     * Renders {@code id in ('a','b',...)}, split into several lists joined with {@code or} when
     * there are more than {@link #IN_CLAUSE_CHUNK_SIZE} values.
     *
     * <p>{@code null} values are skipped. If no value remains, a predicate that matches nothing
     * is returned so that the surrounding SQL stays valid.
     *
     * @param id   column name placed in front of each IN list
     * @param list values to match
     */
    private String getWhereInSql(String id, List<String> list) {
        List<String> values = new ArrayList<>();
        if (list != null) {
            for (String value : list) {
                if (value != null) {
                    values.add(value);
                }
            }
        }
        if (values.isEmpty()) {
            return MATCH_NOTHING;
        }
        StringBuilder sql = new StringBuilder();
        for (int start = 0; start < values.size(); start += IN_CLAUSE_CHUNK_SIZE) {
            if (start > 0) {
                sql.append(" or ");
            }
            int end = Math.min(start + IN_CLAUSE_CHUNK_SIZE, values.size());
            sql.append(id).append(" in (");
            for (int i = start; i < end; i++) {
                if (i > start) {
                    sql.append(",");
                }
                sql.append("'").append(escapeSqlLiteral(values.get(i))).append("'");
            }
            sql.append(")");
        }
        return sql.toString();
    }

    /**
     * Doubles single quotes so a value cannot terminate the string literal it is embedded in.
     * NOTE(review): binding the values as statement parameters would be the complete fix;
     * databases that also treat backslash as an escape character need additional care.
     */
    private String escapeSqlLiteral(String value) {
        return value.replace("'", "''");
    }

    /**
     * Runs the intercepted query with row-level and column-level permissions applied.
     *
     * @param invocation the intercepted {@code Executor.query} call
     * @param permission annotation of the mapper method, never {@code null}
     * @return the query result after column permissions were applied
     */
    private Object doQuery(Invocation invocation, DataPermission permission)
            throws JSQLParserException, InvocationTargetException, IllegalAccessException {
        List<String> permissionValues = dataPermissionHandler.getDataPermissionValue(permission.type());
        if (CollectionUtil.isNotEmpty(permissionValues)) {
            newBoundSql(invocation, permission, permissionValues);
        }
        Object result = invocation.proceed();
        setColumnPermissions(result, permission);
        return result;
    }

    /**
     * Sets to {@code null} every field of every result row that is named by
     * {@link DataPermissionHandler#getColumnPermission()}. Does nothing unless
     * {@link DataPermission#columnPermission()} is enabled and the result is a list.
     *
     * <p>NOTE(review): rows are modified in place; if the executor hands out cached result
     * objects, the cached instances are modified too - verify against the cache configuration.
     */
    private void setColumnPermissions(Object result, DataPermission dataPermission) {
        if (!dataPermission.columnPermission()) {
            return;
        }
        List<String> columnNames = dataPermissionHandler.getColumnPermission();
        if (CollectionUtil.isEmpty(columnNames)) {
            return;
        }
        if (!(result instanceof List)) {
            return;
        }
        for (Object row : (List<?>) result) {
            if (row == null) {
                continue;
            }
            for (String columnName : columnNames) {
                if (ReflectionUtils.isExistField(row.getClass(), columnName)) {
                    ReflectionUtils.setFieldValue(row, columnName, null);
                }
            }
        }
    }

    /**
     * Replaces the {@link MappedStatement} argument of the invocation with a copy whose SQL is
     * the permission-restricted SQL. Parameter mappings and additional parameters are carried
     * over unchanged, which is valid because the rewritten SQL embeds the original SQL as-is and
     * adds no placeholders.
     */
    private void newBoundSql(Invocation invocation, DataPermission permission,
                             List<String> permissionValues) throws JSQLParserException {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        BoundSql boundSql = ms.getBoundSql(args[1]);
        String dataPermissionSql = getDataPermissionSql(boundSql.getSql(), permission, permissionValues);
        BoundSql rewritten = new BoundSql(ms.getConfiguration(), dataPermissionSql,
                boundSql.getParameterMappings(), boundSql.getParameterObject());
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)) {
                rewritten.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }
        args[0] = newMappedStatement(ms, new BoundSqlSqlSource(rewritten));
    }

    /** {@link SqlSource} that always returns one prepared {@link BoundSql}. */
    static class BoundSqlSqlSource implements SqlSource {
        private final BoundSql boundSql;

        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }

    /**
     * Copies {@code ms} into a new statement that uses {@code newSqlSource}.
     * Multi-valued key properties and key columns are carried over completely.
     */
    private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(),
                newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
            // The builder takes a comma-separated list; keep every property, not just the first.
            builder.keyProperty(String.join(",", ms.getKeyProperties()));
        }
        if (ms.getKeyColumns() != null && ms.getKeyColumns().length > 0) {
            builder.keyColumn(String.join(",", ms.getKeyColumns()));
        }
        builder.databaseId(ms.getDatabaseId());
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.resultOrdered(ms.isResultOrdered());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    /**
     * Looks up the {@link DataPermission} annotation of the mapper method behind a statement id.
     *
     * @param fullMethodName statement id of the form {@code <class name>.<method name>}
     * @return the annotation of the first public method with that name that carries one, or
     *         {@code null} when the id has no class part, the class cannot be loaded, or no
     *         such method is annotated
     */
    private DataPermission getPermission(String fullMethodName) {
        int methodIndex = fullMethodName.lastIndexOf('.');
        if (methodIndex <= 0) {
            return null;
        }
        String className = fullMethodName.substring(0, methodIndex);
        String methodName = fullMethodName.substring(methodIndex + 1);
        Class<?> targetCls;
        try {
            targetCls = Class.forName(className);
        } catch (ClassNotFoundException e) {
            // The statement's namespace is not a loadable class, so there is no method that
            // could carry the annotation; such statements must keep working unrestricted.
            return null;
        }
        for (Method method : targetCls.getMethods()) {
            if (!methodName.equals(method.getName())) {
                continue;
            }
            DataPermission annotation = method.getAnnotation(DataPermission.class);
            if (annotation != null) {
                return annotation;
            }
        }
        return null;
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /** No configurable properties. */
    @Override
    public void setProperties(Properties properties) {
    }
}
1.5、数据处理器接口及实现
/**
 * Supplies the permission data used by the MyBatis data-permission interceptor.
 */
public interface DataPermissionHandler {

    /**
     * Returns the values the current caller is permitted to see for the given permission type.
     *
     * @param dataPermissionType kind of permission being resolved
     * @return the permitted values; the interceptor leaves the query unrestricted when the
     *         list is empty
     */
    List<String> getDataPermissionValue(DataPermissionType dataPermissionType);

    /**
     * Returns the names of the result fields the interceptor sets to {@code null} when
     * column-level permission handling is enabled for a query.
     *
     * <p>The default implementation returns an empty list, which disables that step, so
     * implementations that only provide row-level values need not override it.
     *
     * @return field names to blank out, never {@code null}
     */
    default List<String> getColumnPermission() {
        return java.util.Collections.emptyList();
    }
}
/**
 * {@link DataPermissionHandler} that resolves permission values for the user reported by the
 * {@code defaultCurrentService} bean.
 */
@Slf4j
@Service
@AllArgsConstructor
public class DataPermissionServiceImpl implements DataPermissionHandler {
    private final SysPermissionConfigService sysPermissionConfigService;
    private final VendorMasterCompanyInfoService masterCompanyInfoService;

    /**
     * Resolves the permitted values for {@code dataPermissionType}.
     *
     * @param dataPermissionType kind of permission being resolved
     * @return the permitted values; an empty list when the {@code defaultCurrentService} bean
     *         is unavailable, the type is not handled, or nothing is configured for the user
     */
    @Override
    public List<String> getDataPermissionValue(DataPermissionType dataPermissionType) {
        CurrentService defaultCurrentService = (CurrentService) BeanTools.getBean("defaultCurrentService");
        if (Objects.isNull(defaultCurrentService)) {
            return Collections.emptyList();
        }
        User user = defaultCurrentService.getUser();
        switch (dataPermissionType) {
            case COMPANY:
                return getCompanyPermission();
            case SUPPLIER:
                return getSupplierPermission(user);
            default:
                return Collections.emptyList();
        }
    }

    /**
     * Returns the company permission values.
     * NOTE(review): this returns the fixed value {@code "test"}, which looks like a
     * placeholder - confirm before relying on company permissions.
     */
    private List<String> getCompanyPermission() {
        return CollectionUtil.toList("test");
    }

    /**
     * Resolves the vendor codes visible to {@code user}: the permission configuration rows
     * stored for the user's account yield company codes, and those company codes yield vendors.
     *
     * @param user the current user; must not be {@code null}
     * @return distinct, non-null vendor codes, or an empty list when either lookup finds nothing
     * @throws NullPointerException if {@code user} is {@code null}
     */
    private List<String> getSupplierPermission(User user) {
        // Fail with a clear message instead of an anonymous NPE further down.
        Objects.requireNonNull(user, "current user is required to resolve supplier permissions");
        Query query = QueryFactory.createQuery(false);
        query.addOpEqualFilter("account", user.getUsername());
        List<SysPermissionConfigDTO> companyList = sysPermissionConfigService.queryByParams(query);
        if (CollectionUtils.isEmpty(companyList)) {
            return Collections.emptyList();
        }
        List<String> companyCodes = companyList.stream()
                .map(SysPermissionConfigDTO::getCompanyCode)
                .collect(Collectors.toList());
        List<VendorMasterCompanyInfoDTO> vendorCompanyList =
                masterCompanyInfoService.queryVendorByCompanyCodes(companyCodes);
        if (CollectionUtils.isEmpty(vendorCompanyList)) {
            return Collections.emptyList();
        }
        // The codes are used for membership checks, so null and repeated entries carry no
        // information and are dropped here.
        return vendorCompanyList.stream()
                .map(VendorMasterCompanyInfoDTO::getVendorCode)
                .filter(Objects::nonNull)
                .distinct()
                .collect(Collectors.toList());
    }
}
1.6、注册拦截器
/**
 * Spring configuration that installs {@link MyBatisInterceptor} on the injected
 * {@link SqlSessionFactory}.
 */
@Configuration
@AllArgsConstructor
public class MyBatisConfig {

    private SqlSessionFactory sqlSessionFactory;

    /**
     * Creates the data-permission interceptor and adds it to the MyBatis configuration of the
     * injected session factory.
     *
     * @param dataPermissionHandler handler passed to the interceptor
     * @return an empty string; the bean exists only so that this method runs
     */
    @Bean
    public String addSqlInterceptor(DataPermissionHandler dataPermissionHandler) {
        MyBatisInterceptor interceptor = new MyBatisInterceptor(dataPermissionHandler);
        sqlSessionFactory.getConfiguration().addInterceptor(interceptor);
        return StringUtils.EMPTY;
    }
}