背景
在实际工作中，经常需要根据用户所属的机构来过滤数据，也就是所谓的数据权限控制。
这类与具体业务无关的横切逻辑，很自然地适合用 Spring AOP 来实现。
思路
首先要约定一种标识，让 Spring 能识别出当前方法需要进行数据过滤。
使用自定义注解是一种优雅的做法。
注解中只需传入当前要进行数据过滤的实体类的 Class 对象即可。
业务实体类中需要声明一个公共的 orgIds 字段（List 类型），用来存放过滤后的查询条件（例如让实体类继承下文的公共 Vo）。
自定义 DataFilter 注解
package ;
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import cn.hutool.core.util.ReflectUtil;
/**
 * Marks a service method whose arguments should be rewritten before the method runs so that the
 * query is restricted to the current user's data scope.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface DataFilter {

    /** The argument type the filter is applied to. */
    Class<?> value();

    /** Filter fields to populate on the matching argument; defaults to {@code orgIds} only. */
    FilterStrategy[] filterField() default FilterStrategy.orgIds;

    /** List-valued filter fields; each constant's name doubles as the bean field name. */
    enum FilterStrategy {
        orgIds,
        projectIds;
    }

    /** Single-value organization query fields, consulted in declaration order. */
    enum OrgQueryParam {
        orgId,
        did,
        createOrgId,
        orgName;

        /**
         * Returns the value of the first field (in declaration order) whose value on {@code obj} is
         * non-null, or {@code null} when every candidate is null.
         * NOTE(review): presumably ReflectUtil reports fields the class does not declare as null —
         * confirm against the hutool version in use.
         */
        public static Object fieldHasValue(Object obj) {
            OrgQueryParam[] candidates = values();
            Object found = null;
            int i = 0;
            while (found == null && i < candidates.length) {
                found = ReflectUtil.getFieldValue(obj, candidates[i].name());
                i++;
            }
            return found;
        }

        /** Sets to {@code null} every candidate field of {@code obj} that currently holds a non-null value. */
        public static void fieldSetNull(Object obj) {
            for (OrgQueryParam param : values()) {
                String fieldName = param.name();
                // Only write fields that currently hold a value.
                if (ReflectUtil.getFieldValue(obj, fieldName) == null) {
                    continue;
                }
                ReflectUtil.setFieldValue(obj, fieldName, null);
            }
        }
    }
}
数据过滤切面
package ;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;
import com.alibaba.fastjson.JSON;
import cn.com.insurance.admin.internal.annotation.DataFilter;
import cn.com.insurance.admin.internal.annotation.DataFilter.FilterStrategy;
import cn.com.insurance.admin.internal.annotation.DataFilter.OrgQueryParam;
import cn.com.insurance.admin.internal.shiro.ShiroContext;
import cn.com.insurance.admin.internal.utils.StringUtil;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.ReflectUtil;
import lombok.extern.slf4j.Slf4j;
@Slf4j
@Aspect
@Component
public class DataFilterAspect {
@Pointcut("@annotation(cn.com.insurance.admin.internal.annotation.DataFilter)")
public void operatePointExpression() {
}
@Around("operatePointExpression()")
public Object aroundMethod(ProceedingJoinPoint joinPoint) throws Throwable {
try {
Object[] args = joinPoint.getArgs();
boolean isAdmin = ShiroContext.getCurrUser() == null || ShiroContext.getCurrUser().isSuperAdmin();
if (isAdmin) {
log.debug("数据过滤执行结果: 当前登录用户是超级管理员: {},过滤参数是: \r\n{}",
ShiroContext.getCurrUser(),
JSON.toJSONString(args, true));
}
Class<?> classTarget = joinPoint.getTarget().getClass();
Class<?>[] par = ((MethodSignature) joinPoint.getSignature()).getParameterTypes();
Method objMethod = classTarget.getMethod(joinPoint.getSignature().getName(), par);
DataFilter annotation = objMethod.getAnnotation(DataFilter.class);
doDataFilter(args, annotation, isAdmin);
log.debug("数据过滤执行结果: 当前登录用户是: {},过滤参数是: \r\n{}", ShiroContext.getCurrUser(), JSON.toJSONString(args, true));
} catch (Exception e) {
log.error("数据过滤异常: {}", e);
}
return joinPoint.proceed();
}
@SuppressWarnings("unchecked")
private void doDataFilter(Object[] args, DataFilter annotation, boolean isAdmin) {
Class<? extends Object> clazz = annotation.value();
for (Object obj : args) {
if (obj.getClass() != clazz) {
continue;
}
Object queryFieldHasValue = OrgQueryParam.fieldHasValue(obj);
Object hasSubOrg = ReflectUtil.getFieldValue(obj, "hasSubOrg");
if (hasSubOrg != null
&& (boolean) hasSubOrg
&& queryFieldHasValue != null
&& ShiroContext.getSubOrgs((String) queryFieldHasValue).size() > 1) {
ReflectUtil.invoke(obj,
"set" + StringUtil.upperFirstChar(FilterStrategy.orgIds.name()),
CollectionUtil.isNotEmpty(ShiroContext.getOrgIds()) && !isAdmin ? ShiroContext
.getOrgIds() : ShiroContext.getSubOrgs((String) queryFieldHasValue));
OrgQueryParam.fieldSetNull(obj);
} else {
if (queryFieldHasValue != null && queryFieldHasValue.toString().trim().length() > 0) {
ReflectUtil.invoke(obj,
"set" + StringUtil.upperFirstChar(FilterStrategy.orgIds.name()),
Arrays.asList(queryFieldHasValue));
OrgQueryParam.fieldSetNull(obj);
}
}
List<String> methods = Arrays.stream(ReflectUtil.getMethodsDirectly(obj.getClass(), true))
.map(Method::getName)
.collect(Collectors.toList());
for (FilterStrategy strategy : annotation.filterField()) {
String field = strategy.name();
String methodName = "set" + StringUtil.upperFirstChar(field);
if (methods.contains(methodName) && field.contains(FilterStrategy.orgIds.name())) {
Object fieldValue = ReflectUtil.getFieldValue(obj, field);
if (!isAdmin
&& (fieldValue == null
|| (fieldValue instanceof Collection
&& ((Collection<String>) fieldValue).isEmpty()))) {
ReflectUtil.invoke(obj, methodName, ShiroContext.getOrgIds());
OrgQueryParam.fieldSetNull(obj);
}
} else if (methods.contains(methodName) && field.contains(FilterStrategy.projectIds.name())) {
ReflectUtil.invoke(obj, methodName, ShiroContext.getProjectIds());
}
}
}
}
}
公共Vo
package ;
import java.io.Serializable;
import java.util.List;
import javax.persistence.Transient;
import com.alibaba.fastjson.annotation.JSONField;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import cn.com.insurance.admin.utils.response.Pagination;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import lombok.experimental.Accessors;
/**
 * Base query object carrying paging parameters plus the {@code orgIds}/{@code projectIds} filter
 * lists used for data-permission filtering. All fields are {@code @Transient}.
 */
// NOTE(review): Lombok's @Accessors only affects Lombok-generated accessors; every accessor here
// is hand-written and returns void, so chain = true has no effect — confirm intent.
@ToString
@Accessors(chain = true)
@EqualsAndHashCode
@JsonIgnoreProperties("pagination")
public class CommonVo implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Requested page number and page size; omitted from JSON output when null. */
    @Transient
    @JsonInclude(JsonInclude.Include.NON_NULL)
    private Integer pageNo, pageSize;

    /** Paging state kept in sync by {@link #setPageNo} and {@link #setPageSize}; not serialized. */
    @Transient
    @JSONField(serialize = false)
    private Pagination pagination = new Pagination();

    /** Project ids the query is restricted to. */
    @Transient
    private List<String> projectIds;

    /** Organization ids the query is restricted to. */
    @Transient
    private List<String> orgIds;

    /** Whether sub-organizations of the queried organization should be included. */
    @Transient
    private boolean hasSubOrg;

    public Integer getPageNo() {
        return pageNo;
    }

    /**
     * Stores the page number and forwards non-null values to {@link Pagination#setCurrentPage(int)}.
     *
     * @throws IllegalArgumentException if {@code pageNo} is zero or negative (thrown by Pagination)
     */
    public void setPageNo(Integer pageNo) {
        this.pageNo = pageNo;
        // Pagination takes an int: forwarding null would throw an unboxing NullPointerException.
        if (pageNo != null && this.pagination != null) {
            this.pagination.setCurrentPage(pageNo);
        }
    }

    public Integer getPageSize() {
        return pageSize;
    }

    /**
     * Stores the page size and forwards non-null values to {@link Pagination#setPageSize(int)}.
     *
     * @throws IllegalArgumentException if {@code pageSize} is zero or negative (thrown by Pagination)
     */
    public void setPageSize(Integer pageSize) {
        this.pageSize = pageSize;
        // Pagination takes an int: forwarding null would throw an unboxing NullPointerException.
        if (pageSize != null && this.pagination != null) {
            this.pagination.setPageSize(pageSize);
        }
    }

    public Pagination getPagination() {
        return pagination;
    }

    public void setPagination(Pagination pagination) {
        this.pagination = pagination;
    }

    public List<String> getProjectIds() {
        return projectIds;
    }

    public void setProjectIds(List<String> projectIds) {
        this.projectIds = projectIds;
    }

    public List<String> getOrgIds() {
        return orgIds;
    }

    public void setOrgIds(List<String> orgIds) {
        this.orgIds = orgIds;
    }

    public boolean isHasSubOrg() {
        return hasSubOrg;
    }

    public void setHasSubOrg(boolean hasSubOrg) {
        this.hasSubOrg = hasSubOrg;
    }
}
公共 Pagination
package ;
import java.io.Serializable;
import java.util.Objects;
/**
 * Mutable paging state: the requested one-based page, the page size, the total record count and
 * the derived zero-based offset of the first record on the current page.
 *
 * <p>The instance counts as initialized once a total record count has been supplied via
 * {@link #setTotalRecodes(int)} or {@link #init(int)}. Before that the offset is not recalculated.
 */
public class Pagination implements Serializable {
    private static final long serialVersionUID = 6787772888002596631L;
    /** Zero-based offset of the first record on the current page. */
    private int currentIndex = 0;
    /** One-based page number. */
    private int currentPage = 1;
    /** Total number of records; -1 until it has been set. */
    private int totalRecords = -1;
    /** Records per page. */
    private int pageSize = 10;
    private SortItem sortItem = null;

    public Pagination() {
    }

    /** Returns whether a total record count has been set. */
    private boolean validate() {
        return this.totalRecords > -1;
    }

    public int getCurrentPage() {
        return this.currentPage;
    }

    /**
     * Sets the one-based page number. Once initialized, the page is clamped to the last page and the
     * offset is recalculated.
     *
     * @throws IllegalArgumentException if {@code currentPage} is zero or negative
     */
    public void setCurrentPage(int currentPage) {
        if (currentPage <= 0) {
            throw new IllegalArgumentException("the parameter 'currentPage' can not be zero or negative.");
        }
        this.currentPage = currentPage;
        if (validate()) {
            prepare();
        }
    }

    /**
     * Returns the zero-based offset of the first record on the current page. Once initialized this
     * re-runs the normalization, which may also clamp the current page.
     */
    public int getCurrentIndex() {
        if (validate()) {
            prepare();
        }
        return this.currentIndex;
    }

    /** Returns the total record count, or -1 if not yet set (name spelling kept for compatibility). */
    public int getTotalRecodes() {
        return this.totalRecords;
    }

    /**
     * Sets the total record count and recalculates the offset; negative counts are treated as 0.
     * The method name's spelling is kept for compatibility.
     */
    public void setTotalRecodes(int totalRecords) {
        // Clamp to 0 rather than substituting an arbitrary count (the old code used 10), which would
        // invent records and make getTotalPage()/hasNext() report pages that do not exist.
        this.totalRecords = Math.max(0, totalRecords);
        prepare();
    }

    public int getPageSize() {
        return this.pageSize;
    }

    /**
     * Sets the number of records per page.
     *
     * @throws IllegalArgumentException if {@code pageSize} is zero or negative
     */
    public void setPageSize(int pageSize) {
        if (pageSize <= 0) {
            throw new IllegalArgumentException("the parameter 'pageSize' can not be zero or negative.");
        }
        this.pageSize = pageSize;
    }

    /**
     * Repairs invalid state, clamps the current page to the last page and recomputes
     * {@link #currentIndex}.
     */
    private void prepare() {
        // Defensive: the public setters never produce these states, but deserialized data might.
        if (this.totalRecords < 0) {
            this.totalRecords = 0;
        }
        if (this.currentPage == 0) {
            this.currentPage = 1;
        }
        if (this.pageSize == 0) {
            this.pageSize = 10;
        }
        int totalPage = this.getTotalPage();
        if (this.currentPage > totalPage) {
            this.currentPage = totalPage;
        }
        this.currentIndex = ((this.currentPage - 1) * this.pageSize);
        if (this.currentIndex > this.totalRecords) {
            this.currentIndex = (this.totalRecords - (this.currentIndex - this.totalRecords));
        }
    }

    /** Returns the number of pages, at least 1, counting a partial last page as a full page. */
    public int getTotalPage() {
        if (this.totalRecords <= this.pageSize) {
            return 1;
        }
        int m = this.totalRecords % this.pageSize;
        int totalPage = (this.totalRecords - m) / this.pageSize;
        if (m > 0) {
            totalPage++;
        }
        return totalPage;
    }

    /** Returns whether a total record count has been set. */
    public boolean isInitialized() {
        return this.validate();
    }

    /** Equivalent to {@link #setTotalRecodes(int)}. */
    public void init(int totalRecords) {
        this.setTotalRecodes(totalRecords);
    }

    public SortItem getSortItem() {
        return sortItem;
    }

    public void setSortItem(SortItem sortItem) {
        this.sortItem = sortItem;
    }

    /** Returns whether a page exists after the current one. */
    public boolean hasNext() {
        // Compare directly: currentPage + 1 would overflow at Integer.MAX_VALUE.
        return this.currentPage < getTotalPage();
    }

    /** Advances to the next page if there is one. */
    public void next() {
        if (hasNext()) {
            this.currentPage += 1;
        }
    }

    @Override
    public String toString() {
        return "Pagination{"
                + "currentIndex="
                + currentIndex
                + ", currentPage="
                + currentPage
                + ", totalRecords="
                + totalRecords
                + ", pageSize="
                + pageSize
                + ", sortItem="
                + sortItem
                + '}';
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Pagination that = (Pagination) o;
        return currentIndex == that.currentIndex
                && currentPage == that.currentPage
                && totalRecords == that.totalRecords
                && pageSize == that.pageSize
                && Objects.equals(sortItem, that.sortItem);
    }

    @Override
    public int hashCode() {
        return Objects.hash(currentIndex, currentPage, totalRecords, pageSize, sortItem);
    }
}
如何使用
/**
 * Queries one page of accounting bills. Before this runs, the {@code coreAccountingBill} argument
 * is rewritten according to the {@code @DataFilter} annotation. The total row count reported by
 * PageHelper is written back into {@code pagination}.
 */
@DataFilter(CoreAccountingBill.class)
@Override
public List<CoreAccountingBill> queryCoreAccountingBillList(CoreAccountingBill coreAccountingBill, Pagination pagination) {
    int pageNum = pagination.getCurrentPage();
    int size = pagination.getPageSize();
    PageInfo<CoreAccountingBill> page = PageHelper.startPage(pageNum, size)
            .doSelectPageInfo(() -> coreAccountingBillMapper.queryCoreAccountingBillList(coreAccountingBill));
    long total = page.getTotal();
    pagination.setTotalRecodes((int) total);
    return page.getList();
}