背景
工作中经常需要根据用户所属的机构来过滤数据，也就是所谓的数据权限控制。
这类与业务无关的横切逻辑，很自然地适合用 Spring AOP 来实现。
思路
首先要约定一个标志，让 Spring 能识别出哪些方法需要进行数据过滤；
使用自定义注解是一种优雅的选择。
注解中只需传入需要进行数据过滤的参数类型（Class 对象）即可。
业务实体类中需要有一个公共的 List 类型字段 orgIds（见下文公共 Vo），用来存放过滤后的机构查询条件。
自定义 DataFilter 注解
package ;
import java. lang. annotation. Documented ;
import java. lang. annotation. ElementType ;
import java. lang. annotation. Retention ;
import java. lang. annotation. RetentionPolicy ;
import java. lang. annotation. Target ;
import cn. hutool. core. util. ReflectUtil ;
/**
 * Marks a method whose arguments are rewritten by {@code DataFilterAspect} before the method
 * runs, so that the query is restricted by organisation / project ids taken from ShiroContext.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface DataFilter {

    /** Exact runtime class of the method argument whose filter fields are rewritten. */
    Class<?> value();

    /** Filter fields to populate on the matched argument; defaults to {@link FilterStrategy#orgIds}. */
    FilterStrategy[] filterField() default FilterStrategy.orgIds;

    /**
     * Filter fields the aspect can fill in. Each constant name must equal the bean property name,
     * because the aspect derives the field name and the {@code set...} method name from it.
     */
    enum FilterStrategy {
        orgIds,
        projectIds;
    }

    /**
     * Query fields that carry an explicit organisation filter. They are inspected in declaration
     * order, and each constant name must equal the bean field name (read reflectively).
     */
    enum OrgQueryParam {
        orgId,
        did,
        createOrgId,
        orgName;

        /**
         * Returns the value of the first org query field that is set on {@code obj}.
         *
         * <p>Blank character sequences (empty or whitespace only) are treated as "not set". Without
         * this, an empty {@code orgId} would hide a populated {@code did}/{@code createOrgId}/
         * {@code orgName}, and the caller's real org filter would be ignored.
         *
         * @param obj the query object to inspect
         * @return the first non-null, non-blank value, or {@code null} if none is set
         */
        public static Object fieldHasValue(Object obj) {
            for (OrgQueryParam queryParam : OrgQueryParam.values()) {
                Object fieldValue = ReflectUtil.getFieldValue(obj, queryParam.name());
                if (fieldValue == null) {
                    continue;
                }
                if (fieldValue instanceof CharSequence && fieldValue.toString().trim().isEmpty()) {
                    continue;
                }
                return fieldValue;
            }
            return null;
        }

        /**
         * Sets to {@code null} every org query field on {@code obj} that currently holds a value.
         *
         * @param obj the query object to clear
         */
        public static void fieldSetNull(Object obj) {
            for (OrgQueryParam queryParam : OrgQueryParam.values()) {
                Object fieldValue = ReflectUtil.getFieldValue(obj, queryParam.name());
                if (fieldValue != null) {
                    ReflectUtil.setFieldValue(obj, queryParam.name(), null);
                }
            }
        }
    }
}
数据过滤切面
package ;
import java. lang. reflect. Method ;
import java. util. Arrays ;
import java. util. Collection ;
import java. util. List ;
import java. util. stream. Collectors ;
import org. aspectj. lang. ProceedingJoinPoint ;
import org. aspectj. lang. annotation. Around ;
import org. aspectj. lang. annotation. Aspect ;
import org. aspectj. lang. annotation. Pointcut ;
import org. aspectj. lang. reflect. MethodSignature ;
import org. springframework. stereotype. Component ;
import com. alibaba. fastjson. JSON;
import cn. com. insurance. admin. internal. annotation. DataFilter ;
import cn. com. insurance. admin. internal. annotation. DataFilter. FilterStrategy ;
import cn. com. insurance. admin. internal. annotation. DataFilter. OrgQueryParam ;
import cn. com. insurance. admin. internal. shiro. ShiroContext ;
import cn. com. insurance. admin. internal. utils. StringUtil ;
import cn. hutool. core. collection. CollectionUtil ;
import cn. hutool. core. util. ReflectUtil ;
import lombok. extern. slf4j. Slf4j ;
@Slf4j
@Aspect
@Component
public class DataFilterAspect {
@Pointcut ( "@annotation(cn.com.insurance.admin.internal.annotation.DataFilter)" )
public void operatePointExpression ( ) {
}
@Around ( "operatePointExpression()" )
public Object aroundMethod ( ProceedingJoinPoint joinPoint) throws Throwable {
try {
Object [ ] args = joinPoint. getArgs ( ) ;
boolean isAdmin = ShiroContext . getCurrUser ( ) == null || ShiroContext . getCurrUser ( ) . isSuperAdmin ( ) ;
if ( isAdmin) {
log. debug ( "数据过滤执行结果: 当前登录用户是超级管理员: {},过滤参数是: \r\n{}" ,
ShiroContext . getCurrUser ( ) ,
JSON. toJSONString ( args, true ) ) ;
}
Class < ? > classTarget = joinPoint. getTarget ( ) . getClass ( ) ;
Class < ? > [ ] par = ( ( MethodSignature ) joinPoint. getSignature ( ) ) . getParameterTypes ( ) ;
Method objMethod = classTarget. getMethod ( joinPoint. getSignature ( ) . getName ( ) , par) ;
DataFilter annotation = objMethod. getAnnotation ( DataFilter . class ) ;
doDataFilter ( args, annotation, isAdmin) ;
log. debug ( "数据过滤执行结果: 当前登录用户是: {},过滤参数是: \r\n{}" , ShiroContext . getCurrUser ( ) , JSON. toJSONString ( args, true ) ) ;
} catch ( Exception e) {
log. error ( "数据过滤异常: {}" , e) ;
}
return joinPoint. proceed ( ) ;
}
@SuppressWarnings ( "unchecked" )
private void doDataFilter ( Object [ ] args, DataFilter annotation, boolean isAdmin) {
Class < ? extends Object > clazz = annotation. value ( ) ;
for ( Object obj : args) {
if ( obj. getClass ( ) != clazz) {
continue ;
}
Object queryFieldHasValue = OrgQueryParam . fieldHasValue ( obj) ;
Object hasSubOrg = ReflectUtil . getFieldValue ( obj, "hasSubOrg" ) ;
if ( hasSubOrg != null
&& ( boolean ) hasSubOrg
&& queryFieldHasValue != null
&& ShiroContext . getSubOrgs ( ( String ) queryFieldHasValue) . size ( ) > 1 ) {
ReflectUtil . invoke ( obj,
"set" + StringUtil . upperFirstChar ( FilterStrategy . orgIds. name ( ) ) ,
CollectionUtil . isNotEmpty ( ShiroContext . getOrgIds ( ) ) && ! isAdmin ? ShiroContext
. getOrgIds ( ) : ShiroContext . getSubOrgs ( ( String ) queryFieldHasValue) ) ;
OrgQueryParam . fieldSetNull ( obj) ;
} else {
if ( queryFieldHasValue != null && queryFieldHasValue. toString ( ) . trim ( ) . length ( ) > 0 ) {
ReflectUtil . invoke ( obj,
"set" + StringUtil . upperFirstChar ( FilterStrategy . orgIds. name ( ) ) ,
Arrays . asList ( queryFieldHasValue) ) ;
OrgQueryParam . fieldSetNull ( obj) ;
}
}
List < String > methods = Arrays . stream ( ReflectUtil . getMethodsDirectly ( obj. getClass ( ) , true ) )
. map ( Method :: getName )
. collect ( Collectors . toList ( ) ) ;
for ( FilterStrategy strategy : annotation. filterField ( ) ) {
String field = strategy. name ( ) ;
String methodName = "set" + StringUtil . upperFirstChar ( field) ;
if ( methods. contains ( methodName) && field. contains ( FilterStrategy . orgIds. name ( ) ) ) {
Object fieldValue = ReflectUtil . getFieldValue ( obj, field) ;
if ( ! isAdmin
&& ( fieldValue == null
|| ( fieldValue instanceof Collection
&& ( ( Collection < String > ) fieldValue) . isEmpty ( ) ) ) ) {
ReflectUtil . invoke ( obj, methodName, ShiroContext . getOrgIds ( ) ) ;
OrgQueryParam . fieldSetNull ( obj) ;
}
} else if ( methods. contains ( methodName) && field. contains ( FilterStrategy . projectIds. name ( ) ) ) {
ReflectUtil . invoke ( obj, methodName, ShiroContext . getProjectIds ( ) ) ;
}
}
}
}
}
公共Vo
package ;
import java. io. Serializable ;
import java. util. List ;
import javax. persistence. Transient ;
import com. alibaba. fastjson. annotation. JSONField ;
import com. fasterxml. jackson. annotation. JsonIgnoreProperties ;
import com. fasterxml. jackson. annotation. JsonInclude ;
import cn. com. insurance. admin. utils. response. Pagination ;
import lombok. EqualsAndHashCode ;
import lombok. ToString ;
import lombok. experimental. Accessors ;
/**
 * Common base for query value objects. It carries the paging parameters (mirrored into a
 * {@link Pagination}) plus the org/project id lists used for data filtering. Every field is
 * {@code @Transient}.
 */
@ToString
// NOTE(review): @Accessors only affects Lombok-generated accessors. Every accessor below is
// hand-written, so this appears to have no effect.
@Accessors(chain = true)
@EqualsAndHashCode
@JsonIgnoreProperties("pagination")
public class CommonVo implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Requested page number and page size; omitted from JSON when null. */
    @Transient
    @JsonInclude(JsonInclude.Include.NON_NULL)
    private Integer pageNo, pageSize;

    /** Paging state kept in sync with pageNo/pageSize; excluded from JSON output. */
    @Transient
    @JSONField(serialize = false)
    private Pagination pagination = new Pagination();

    /** Project ids the query is restricted to. */
    @Transient
    private List<String> projectIds;

    /** Organisation ids the query is restricted to. */
    @Transient
    private List<String> orgIds;

    /** Whether an explicit org query should be expanded to that org's sub-organisations. */
    @Transient
    private boolean hasSubOrg;

    public Integer getPageNo() {
        return pageNo;
    }

    /**
     * Sets the requested page number and forwards it to the pagination when non-null.
     *
     * @param pageNo one-based page number, or {@code null} to keep the pagination's current page
     * @throws IllegalArgumentException if {@code pageNo} is zero or negative (from Pagination)
     */
    public void setPageNo(Integer pageNo) {
        this.pageNo = pageNo;
        // Pagination takes an int. Forwarding null used to auto-unbox and throw NullPointerException.
        if (pageNo != null) {
            this.pagination.setCurrentPage(pageNo);
        }
    }

    public Integer getPageSize() {
        return pageSize;
    }

    /**
     * Sets the requested page size and forwards it to the pagination when non-null.
     *
     * @param pageSize rows per page, or {@code null} to keep the pagination's current size
     * @throws IllegalArgumentException if {@code pageSize} is zero or negative (from Pagination)
     */
    public void setPageSize(Integer pageSize) {
        this.pageSize = pageSize;
        // Same auto-unboxing hazard as setPageNo.
        if (pageSize != null) {
            this.pagination.setPageSize(pageSize);
        }
    }

    public Pagination getPagination() {
        return pagination;
    }

    public void setPagination(Pagination pagination) {
        this.pagination = pagination;
    }

    public List<String> getProjectIds() {
        return projectIds;
    }

    public void setProjectIds(List<String> projectIds) {
        this.projectIds = projectIds;
    }

    public List<String> getOrgIds() {
        return orgIds;
    }

    public void setOrgIds(List<String> orgIds) {
        this.orgIds = orgIds;
    }

    public boolean isHasSubOrg() {
        return hasSubOrg;
    }

    public void setHasSubOrg(boolean hasSubOrg) {
        this.hasSubOrg = hasSubOrg;
    }
}
公共 Pagination
package ;
import java. io. Serializable ;
import java. util. Objects ;
/**
 * Mutable paging state. It holds the one-based current page, the page size, the total record
 * count (-1 until known) and the derived zero-based row offset of the current page.
 */
public class Pagination implements Serializable {
    private static final long serialVersionUID = 6787772888002596631L;

    /** Zero-based offset of the first row on the current page; recomputed by prepare(). */
    private int currentIndex = 0;
    /** One-based page number. */
    private int currentPage = 1;
    /** Total number of records; -1 means the total has not been supplied yet. */
    private int totalRecords = -1;
    /** Rows per page. */
    private int pageSize = 10;
    private SortItem sortItem = null;

    public Pagination() {
    }

    /** Returns true once a non-negative total record count has been recorded. */
    private boolean validate() {
        return totalRecords >= 0;
    }

    public int getCurrentPage() {
        return currentPage;
    }

    /**
     * Sets the one-based page number. Once the total is known, the paging state is recomputed
     * (which clamps the page to the last page).
     *
     * @throws IllegalArgumentException if {@code currentPage} is zero or negative
     */
    public void setCurrentPage(int currentPage) {
        if (currentPage < 1) {
            throw new IllegalArgumentException("the parameter 'currentPage' can not be zero or negative.");
        }
        this.currentPage = currentPage;
        if (validate()) {
            prepare();
        }
    }

    /** Returns the zero-based row offset, recomputing it first when the total is known. */
    public int getCurrentIndex() {
        if (validate()) {
            prepare();
        }
        return currentIndex;
    }

    public int getTotalRecodes() {
        return totalRecords;
    }

    /**
     * Records the total number of rows and recomputes the paging state. A negative total is
     * replaced by 10. NOTE(review): 10 rather than 0 looks arbitrary; confirm intent.
     */
    public void setTotalRecodes(int totalRecords) {
        this.totalRecords = totalRecords < 0 ? 10 : totalRecords;
        prepare();
    }

    public int getPageSize() {
        return pageSize;
    }

    /**
     * Sets the number of rows per page.
     *
     * @throws IllegalArgumentException if {@code pageSize} is zero or negative
     */
    public void setPageSize(int pageSize) {
        if (pageSize < 1) {
            throw new IllegalArgumentException("the parameter 'pageSize' can not be zero or negative.");
        }
        this.pageSize = pageSize;
    }

    /** Normalises the fields, clamps the page to the last page and recomputes the row offset. */
    private void prepare() {
        // Defensive defaults; the public setters already reject most of these values.
        if (totalRecords < 0) {
            totalRecords = 10;
        }
        if (currentPage == 0) {
            currentPage = 1;
        }
        if (pageSize == 0) {
            pageSize = 10;
        }
        currentPage = Math.min(currentPage, getTotalPage());
        currentIndex = (currentPage - 1) * pageSize;
        if (currentIndex > totalRecords) {
            int overshoot = currentIndex - totalRecords;
            currentIndex = totalRecords - overshoot;
        }
    }

    /** Returns the number of pages, which is at least 1 whenever totalRecords <= pageSize. */
    public int getTotalPage() {
        if (totalRecords <= pageSize) {
            return 1;
        }
        int fullPages = totalRecords / pageSize;
        return totalRecords % pageSize > 0 ? fullPages + 1 : fullPages;
    }

    public boolean isInitialized() {
        return validate();
    }

    /** Equivalent to {@link #setTotalRecodes(int)}. */
    public void init(int totalRecords) {
        setTotalRecodes(totalRecords);
    }

    public SortItem getSortItem() {
        return sortItem;
    }

    public void setSortItem(SortItem sortItem) {
        this.sortItem = sortItem;
    }

    /** Returns true if a page after the current one exists. */
    public boolean hasNext() {
        int lastPage = getTotalPage();
        return currentPage + 1 <= lastPage;
    }

    /** Advances to the next page when {@link #hasNext()} allows it. */
    public void next() {
        if (!hasNext()) {
            return;
        }
        currentPage++;
    }

    @Override
    public String toString() {
        return new StringBuilder("Pagination{")
                .append("currentIndex=").append(currentIndex)
                .append(", currentPage=").append(currentPage)
                .append(", totalRecords=").append(totalRecords)
                .append(", pageSize=").append(pageSize)
                .append(", sortItem=").append(sortItem)
                .append('}')
                .toString();
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        Pagination other = (Pagination) o;
        if (currentIndex != other.currentIndex || currentPage != other.currentPage) {
            return false;
        }
        if (totalRecords != other.totalRecords || pageSize != other.pageSize) {
            return false;
        }
        return Objects.equals(sortItem, other.sortItem);
    }

    @Override
    public int hashCode() {
        return Objects.hash(currentIndex, currentPage, totalRecords, pageSize, sortItem);
    }
}
如何使用
/**
 * Returns one page of accounting bills matching {@code coreAccountingBill} and records the
 * total row count on {@code pagination}.
 *
 * <p>When called through the Spring AOP proxy, DataFilterAspect rewrites the org/project filter
 * fields of {@code coreAccountingBill} before this body runs.
 */
@DataFilter(CoreAccountingBill.class)
@Override
public List<CoreAccountingBill> queryCoreAccountingBillList(CoreAccountingBill coreAccountingBill, Pagination pagination) {
    int pageNum = pagination.getCurrentPage();
    int pageSize = pagination.getPageSize();
    PageInfo<CoreAccountingBill> page = PageHelper.startPage(pageNum, pageSize)
            .doSelectPageInfo(() -> coreAccountingBillMapper.queryCoreAccountingBillList(coreAccountingBill));
    pagination.setTotalRecodes((int) page.getTotal());
    return page.getList();
}