基于mybatis Interceptor 机制实现的自动添加数据过滤组件。
背景
- 我们的系统要求所有的数据都按照系统编码过滤,那么所有查询 SQL 都需要加上 SYS_CODE IN (所属系统编码) 这样的条件
- 我们的系统都要求按创建人来过滤,那么所有的查询、修改、删除都需要加上创建人的条件
- 使用我们的组件就不需要改动 SQL,能自动完成条件的添加,就像现在很流行的 PageHelper 分页插件一样简单。(注意:当前实现只改写 SELECT 查询,修改和删除语句不会被自动追加条件。)
核心类
- AbstractSqlInterceptor 实现 MyBatis Interceptor 接口的抽象拦截器
- DataFilterSqlInterceptor 继承AbstractSqlInterceptor 的数据过滤组件
- DataFilterContextUtil 过滤配置上下文
- MybatisDataFilterProperties 配置文件
- AbstractAopPointcut 抽象的 Spring AOP 拦截器,拦截 @DataFilter 注解
- DataFilterAopPointcut 继承 AbstractAopPointcut 的数据过滤 AOP 拦截器。把 @DataFilter 的配置设置到 DataFilterContextUtil 中,并且调用 DataFilterValueHandle 设置模板变量值
- DataFilterValueHandle 业务系统集成设置模板替换的值
核心代码
AbstractSqlInterceptor 实现 Interceptor 接口的抽象拦截器
package mybatis.interceptor;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Properties;
/**
 * Abstract MyBatis interceptor (intended for {@code Executor.query}) that rewrites the SQL of
 * SELECT statements.
 *
 * <p>Subclasses decide whether to act ({@link #match}) and produce the new SQL ({@link #buildSql}).
 * This class rebuilds the BoundSql/MappedStatement and swaps it into the invocation arguments.
 */
public abstract class AbstractSqlInterceptor implements Interceptor {
    public Logger logger = LoggerFactory.getLogger("mybatisDataFilter");

    /**
     * Decides whether this invocation should be rewritten. Implemented by subclasses.
     *
     * @param invocation the intercepted executor call
     * @return {@code true} to rewrite the SQL, {@code false} to proceed untouched
     */
    public abstract boolean match(Invocation invocation);

    /**
     * Builds the rewritten SQL. Implemented by subclasses.
     *
     * @param boundSql the statement's original bound SQL
     * @param origSql  the original SQL text ({@code boundSql.getSql()})
     * @return the new SQL; {@code null}, a blank string or {@code origSql} itself leaves the
     *         statement unchanged
     */
    public abstract String buildSql(BoundSql boundSql, String origSql);

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        if (!match(invocation)) {
            return invocation.proceed();
        }
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        // Only SELECT statements are rewritten; insert/update/delete pass through unchanged.
        if (ms.getSqlCommandType() != SqlCommandType.SELECT) {
            return invocation.proceed();
        }
        Object parameterObject = args[1];
        BoundSql boundSql = ms.getBoundSql(parameterObject);
        String origSql = boundSql.getSql();
        logger.debug("原始SQL: {}", origSql);
        String newSql = buildSql(boundSql, origSql);
        // Without this guard an empty result would replace the query with "" and fail at the driver.
        if (newSql == null || newSql.trim().isEmpty() || newSql.equals(origSql)) {
            return invocation.proceed();
        }
        BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), newSql,
                boundSql.getParameterMappings(), boundSql.getParameterObject());
        // Carry over additional parameters (values bound by dynamic SQL) so placeholders still resolve.
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)) {
                newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }
        // Invocation.proceed() uses this same args array, so replacing element 0 swaps the statement.
        args[0] = newMappedStatement(ms, new BoundSqlSqlSource(newBoundSql));
        logger.debug("改写的SQL: {}", newSql);
        return invocation.proceed();
    }

    /**
     * SqlSource that always returns a pre-built BoundSql (the rewritten statement).
     * Static because it needs no access to the enclosing interceptor.
     */
    static class BoundSqlSqlSource implements SqlSource {
        private final BoundSql boundSql;

        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }

    /**
     * Copies {@code ms} into a new MappedStatement that uses {@code newSqlSource}.
     *
     * @param ms           the original statement
     * @param newSqlSource source returning the rewritten SQL
     * @return a statement identical to {@code ms} except for its SQL source
     */
    private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new MappedStatement.Builder(
                ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        // Key properties/columns and result sets are stored as arrays, but the builder expects a
        // comma-separated list; copying only element [0] silently dropped the others.
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
            builder.keyProperty(String.join(",", ms.getKeyProperties()));
        }
        if (ms.getKeyColumns() != null && ms.getKeyColumns().length > 0) {
            builder.keyColumn(String.join(",", ms.getKeyColumns()));
        }
        if (ms.getResultSets() != null && ms.getResultSets().length > 0) {
            builder.resultSets(String.join(",", ms.getResultSets()));
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.resultOrdered(ms.isResultOrdered());
        builder.databaseId(ms.getDatabaseId());
        builder.lang(ms.getLang());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
    }
}
DataFilterSqlInterceptor 继承AbstractSqlInterceptor 的数据过滤组件
package mybatis.interceptor;
import com.alibaba.fastjson.JSON;
import mybatis.properties.MybatisDataFilterProperties;
import mybatis.util.DataFilterContextUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringSubstitutor;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.Map;
/**
 * MyBatis interceptor that adds data-filter conditions (by default {@code SYS_CODE IN (...)}) to
 * SELECT statements executed while a DataFilterContextUtil context is active.
 *
 * <p>Configuration priority: {@code @DataFilter}, then DataFilterContextUtil, then
 * MybatisDataFilterProperties.
 *
 * <p>Modes:
 * <ul>
 *   <li>WHERE: the rendered condition is AND-ed in front of the statement's top-level WHERE
 *       predicate, which is wrapped in parentheses. If there is no WHERE clause, a new one is
 *       inserted before the first top-level GROUP/HAVING/ORDER/LIMIT/OFFSET/FETCH/FOR/UNION
 *       keyword. Compound (UNION) queries are only filtered in their first branch; use SELECT
 *       mode for those.</li>
 *   <li>SELECT: the original statement is wrapped as a sub-query by the select template.</li>
 * </ul>
 *
 * <p>Keyword detection is a lightweight scan, not a full SQL parser. It is case-insensitive,
 * matches whole words only, and skips quoted text and parenthesised sub-queries.
 *
 * <p>NOTE(review): template values are spliced into the SQL text verbatim (no bind parameters).
 * DataFilterValueHandle implementations must only supply trusted, properly quoted values.
 */
@Intercepts({@Signature(type = Executor.class, method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})})
@Component
@Slf4j
public class DataFilterSqlInterceptor extends AbstractSqlInterceptor {
    @Autowired
    private DataFilterContextUtil dataFilterContextUtil;
    private final static String WHERE_SQL = "WHERE SYS_CODE IN (${SYS_CODE})";
    /** Only written by {@link #setWhereSqlFormat}; formats are resolved per query into locals. */
    private String whereSqlFormat;
    private final static String AND = "AND";
    private final static String SELECT_FORMAT = "SELECT * FROM (${SQL}) t WHERE ${CONDITION}";
    /** Only written by {@link #setSelectFormat}; formats are resolved per query into locals. */
    private String selectFormat;
    private final static String POINT = ".";
    /** Top-level keywords that end a WHERE clause; a WHERE condition must be placed before them. */
    private final static String[] CHECK_KEY = {"GROUP", "HAVING", "ORDER", "LIMIT", "OFFSET", "FETCH", "FOR", "UNION"};
    private final static String WHERE = "WHERE";
    private final static String ALIAS_NAME = "ALIAS_NAME";
    private final static String ALIAS_NAME_FORMAT = "${ALIAS_NAME}";
    @Autowired
    MybatisDataFilterProperties mybatisDataFilterProperties;

    /**
     * Filters only when a context exists for the current thread, and the caller is not an
     * administrator while admin mode is enabled.
     */
    @Override
    public boolean match(Invocation invocation) {
        if (!dataFilterContextUtil.isDataFilterContext()) {
            return false;
        }
        DataFilterContextUtil.DataFilterContext dataFilterContext = dataFilterContextUtil.getDataFilterContext();
        Boolean isAdmin = dataFilterContext.getIsAdmin() == null ? false : dataFilterContext.getIsAdmin();
        if (getAdminMode() && isAdmin) {
            logger.debug("----------admin is not filter");
            return false;
        }
        return true;
    }

    /**
     * Rewrites {@code origSql} according to the resolved mode.
     *
     * @return the rewritten SQL, or {@code origSql} unchanged when there are no template values
     * @throws RuntimeException if the resolved mode is neither WHERE nor SELECT
     */
    @Override
    public String buildSql(BoundSql boundSql, String origSql) {
        DataFilterContextUtil.DataFilterContext dataFilterContext = dataFilterContextUtil.getDataFilterContext();
        DataFilterContextUtil.SysCodeSqlModeEnum mode = getMode();
        logger.debug("开始构建sql mod:{}", mode);
        Map<String, String> valueMap = dataFilterContext.getValueMap();
        logger.debug("buildSql valueMap:{} format:{}", JSON.toJSONString(valueMap), dataFilterContext.getWhereSqlFormat());
        if (valueMap == null || valueMap.isEmpty()) {
            logger.warn("valueMap is null 不进行sql处理");
            // Leave the statement untouched (returning "" would replace it with empty SQL).
            // NOTE(review): this is fail-open; supply values from DataFilterValueHandle if
            // unfiltered queries must never run.
            return origSql;
        }
        if (DataFilterContextUtil.SysCodeSqlModeEnum.WHERE.equals(mode)) {
            return buildWhereSql(origSql, valueMap);
        }
        if (DataFilterContextUtil.SysCodeSqlModeEnum.SELECT.equals(mode)) {
            return buildSelectSql(origSql, valueMap);
        }
        throw new RuntimeException("mode is not WHERE OR SELECT");
    }

    /**
     * SELECT mode: renders the select template, e.g. {@code SELECT * FROM (<origSql>) t WHERE <condition>}.
     * The original SQL is used as-is (not upper-cased) so string literals are preserved.
     */
    private String buildSelectSql(String origSql, Map<String, String> valueMap) {
        // The default select template aliases the sub-query as "t", so conditions are qualified with it.
        dataFilterContextUtil.getDataFilterContext().setAliaName("t");
        String whereCondition = withWherePrefix(buildWhereCondition(valueMap));
        logger.debug("改写的whereCondition:{}", whereCondition);
        // Drop the template's own WHERE keyword: the rendered condition already starts with one.
        String selFormat = getSelectFormat().replaceAll("(?i)\\bWHERE\\b", "");
        Map<String, String> selectValueMap = new HashMap<>();
        selectValueMap.put("SQL", origSql);
        selectValueMap.put("CONDITION", whereCondition);
        return format(selFormat, selectValueMap);
    }

    /**
     * WHERE mode. The filter is AND-ed in front of the existing top-level predicate, which is
     * parenthesised so that an OR inside it cannot bypass the filter
     * ({@code WHERE a OR b -> WHERE <filter> AND (a OR b)}). Without a top-level WHERE, a new
     * clause is inserted via {@link #addWhereCondition}.
     */
    private String buildWhereSql(String origSql, Map<String, String> valueMap) {
        String whereCondition = withWherePrefix(buildWhereCondition(valueMap));
        logger.debug("改写的whereCondition:{}", whereCondition);
        int whereIndex = indexOfTopLevelKeyword(origSql, 0, WHERE);
        if (whereIndex == -1) {
            return addWhereCondition(origSql, whereCondition);
        }
        int conditionStart = whereIndex + WHERE.length();
        int conditionEnd = indexOfTopLevelKeyword(origSql, conditionStart, CHECK_KEY);
        if (conditionEnd == -1) {
            conditionEnd = origSql.length();
        }
        String existing = origSql.substring(conditionStart, conditionEnd).trim();
        String tail = origSql.substring(conditionEnd);
        // Plain concatenation: unlike replaceFirst, '$' or '\' in the values are not interpreted.
        return origSql.substring(0, whereIndex) + whereCondition + " " + AND + " (" + existing + ")"
                + (tail.isEmpty() ? "" : " " + tail);
    }

    /** Prefixes the condition with WHERE unless it already starts with that keyword (any case). */
    private static String withWherePrefix(String condition) {
        String trimmed = condition.trim();
        if (startsWithKeyword(trimmed, 0, WHERE)) {
            return trimmed;
        }
        return WHERE + " " + trimmed;
    }

    /**
     * Renders the where template with the context values (plus ALIAS_NAME when an alias is set).
     * Works on a copy so the shared context map is not modified.
     */
    private String buildWhereCondition(Map<String, String> valueMap) {
        if (valueMap == null || valueMap.isEmpty()) {
            return "";
        }
        Map<String, String> templateValues = new HashMap<>(valueMap);
        String aliaName = dataFilterContextUtil.getDataFilterContext().getAliaName();
        if (StringUtils.isNotEmpty(aliaName)) {
            templateValues.put(ALIAS_NAME, aliaName);
        }
        String whereFormat = getWhereSqlFormat();
        logger.debug("buildWhereCondition valueMap:{} format:{}", JSON.toJSONString(templateValues), whereFormat);
        return format(whereFormat, templateValues);
    }

    /** Replaces {@code ${name}} placeholders in {@code format} with values from {@code valueMap}. */
    private String format(String format, Map<String, String> valueMap) {
        return new StringSubstitutor(valueMap).replace(format);
    }

    /**
     * Adds a WHERE clause to a statement that has none. The clause goes before the first
     * top-level GROUP/HAVING/ORDER/LIMIT/OFFSET/FETCH/FOR/UNION keyword, or is appended at the end:
     * {@code select * from tt order by x -> select * from tt <whereCondition> order by x}.
     *
     * @param origSql        the original statement
     * @param whereCondition the condition, including its WHERE keyword
     * @return the statement with the condition inserted
     */
    public String addWhereCondition(String origSql, String whereCondition) {
        int index = indexOfTopLevelKeyword(origSql, 0, CHECK_KEY);
        if (index == -1) {
            return origSql + " " + whereCondition;
        }
        // Spaces on both sides: the condition must not be glued to the following keyword.
        return origSql.substring(0, index) + " " + whereCondition + " " + origSql.substring(index);
    }

    /**
     * Returns the index of the first occurrence, at or after {@code from}, of any of
     * {@code keywords}. An occurrence counts only as a whole word outside quotes and at
     * parenthesis depth 0, so sub-queries are skipped. Matching is case-insensitive.
     *
     * @return the index, or -1 if none is found
     */
    private static int indexOfTopLevelKeyword(String sql, int from, String... keywords) {
        int depth = 0;
        char quote = 0;
        for (int i = 0; i < sql.length(); i++) {
            char c = sql.charAt(i);
            if (quote != 0) {
                if (c == quote) {
                    quote = 0;
                }
                continue;
            }
            if (c == '\'' || c == '"' || c == '`') {
                quote = c;
            } else if (c == '(') {
                depth++;
            } else if (c == ')') {
                depth--;
            } else if (depth == 0 && i >= from) {
                for (String keyword : keywords) {
                    if (startsWithKeyword(sql, i, keyword)) {
                        return i;
                    }
                }
            }
        }
        return -1;
    }

    /** True when {@code keyword} occurs at {@code index} as a whole word (case-insensitive). */
    private static boolean startsWithKeyword(String sql, int index, String keyword) {
        int end = index + keyword.length();
        return sql.regionMatches(true, index, keyword, 0, keyword.length())
                && (index == 0 || !isIdentifierChar(sql.charAt(index - 1)))
                && (end == sql.length() || !isIdentifierChar(sql.charAt(end)));
    }

    /** Characters that may be part of an identifier (so e.g. ORDERS or t.order are not keywords). */
    private static boolean isIdentifierChar(char c) {
        return Character.isLetterOrDigit(c) || c == '_' || c == '$' || c == '.';
    }

    /**
     * Resolves the where template from the context, then the global properties, then
     * {@link #WHERE_SQL}. Uses locals only: this interceptor is a shared singleton, and writing
     * a field here raced between concurrent queries.
     */
    public String getWhereSqlFormat() {
        String contextFormat = dataFilterContextUtil.getDataFilterContext().getWhereSqlFormat();
        if (StringUtils.isNotEmpty(contextFormat)) {
            String whereSql = aliasNameHandle(contextFormat);
            logger.debug("whereSqlFormat sql:{}", whereSql);
            return whereSql;
        }
        String globalFormat = mybatisDataFilterProperties.getWhereSqlFormat();
        logger.debug("whereSqlFormat:{}", globalFormat);
        if (StringUtils.isNotEmpty(globalFormat)) {
            String whereSql = aliasNameHandle(globalFormat);
            logger.debug("whereSqlFormat sql:{}", whereSql);
            return whereSql;
        }
        return WHERE_SQL;
    }

    /** Strips the "${ALIAS_NAME}." qualifier from the template when no alias is configured. */
    private String aliasNameHandle(String whereSql) {
        DataFilterContextUtil.DataFilterContext dataFilterContext = dataFilterContextUtil.getDataFilterContext();
        if (whereSql != null && StringUtils.isEmpty(dataFilterContext.getAliaName())) {
            return whereSql.replace(ALIAS_NAME_FORMAT + POINT, "");
        }
        return whereSql;
    }

    /**
     * Resolves the mode from the context, then the global properties (cached into the context),
     * defaulting to WHERE.
     */
    public DataFilterContextUtil.SysCodeSqlModeEnum getMode() {
        DataFilterContextUtil.DataFilterContext dataFilterContext = dataFilterContextUtil.getDataFilterContext();
        DataFilterContextUtil.SysCodeSqlModeEnum sysCodeSqlModeEnum = dataFilterContext.getMode();
        if (sysCodeSqlModeEnum == null) {
            String globleMode = mybatisDataFilterProperties.getMode();
            if (StringUtils.isNotEmpty(globleMode)) {
                dataFilterContextUtil.setMode(globleMode);
                sysCodeSqlModeEnum = dataFilterContext.getMode();
            }
        }
        return sysCodeSqlModeEnum == null ? DataFilterContextUtil.SysCodeSqlModeEnum.WHERE : sysCodeSqlModeEnum;
    }

    /**
     * Resolves the select template from the context, then the global properties, then
     * {@link #SELECT_FORMAT}. Uses locals only, for thread safety.
     */
    public String getSelectFormat() {
        String contextFormat = dataFilterContextUtil.getDataFilterContext().getSelectFormat();
        if (StringUtils.isNotEmpty(contextFormat)) {
            return contextFormat;
        }
        String globalFormat = mybatisDataFilterProperties.getSelectFormat();
        if (StringUtils.isNotEmpty(globalFormat)) {
            return globalFormat;
        }
        return SELECT_FORMAT;
    }

    /**
     * Resolves admin mode from the context, then the global properties (cached into the context),
     * defaulting to {@code true}.
     */
    public boolean getAdminMode() {
        DataFilterContextUtil.DataFilterContext dataFilterContext = dataFilterContextUtil.getDataFilterContext();
        Boolean adminMode = dataFilterContext.getAdminMode();
        if (adminMode == null) {
            adminMode = mybatisDataFilterProperties.getAdminMode();
            if (adminMode != null) {
                dataFilterContextUtil.setAdminMode(adminMode);
                adminMode = dataFilterContext.getAdminMode();
            }
        }
        return adminMode == null ? true : adminMode;
    }

    public DataFilterContextUtil getDataFilterContextUtil() {
        return dataFilterContextUtil;
    }

    public void setDataFilterContextUtil(DataFilterContextUtil dataFilterContextUtil) {
        this.dataFilterContextUtil = dataFilterContextUtil;
    }

    public static String getWhereSql() {
        return WHERE_SQL;
    }

    /** Retained for API compatibility; the where template is resolved per query and this value is not consulted. */
    public void setWhereSqlFormat(String whereSqlFormat) {
        this.whereSqlFormat = whereSqlFormat;
    }

    /** Retained for API compatibility; the select template is resolved per query and this value is not consulted. */
    public void setSelectFormat(String selectFormat) {
        this.selectFormat = selectFormat;
    }

    public MybatisDataFilterProperties getMybatisDataFilterProperties() {
        return mybatisDataFilterProperties;
    }

    public void setMybatisDataFilterProperties(MybatisDataFilterProperties mybatisDataFilterProperties) {
        this.mybatisDataFilterProperties = mybatisDataFilterProperties;
    }
}
DataFilterContextUtil 过滤配置上下文
package mybatis.util;
import lombok.Data;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.Map;
/**
 * Holds the data-filter settings of the current thread in a ThreadLocal: mode, alias, admin
 * flags, template values and templates.
 *
 * <p>Setters return {@code this} so calls can be chained. {@link #getDataFilterContext()} (and
 * every setter) lazily creates the context. Call {@link #clear()} when the filtered call ends;
 * otherwise the context stays attached to the (possibly pooled) thread.
 */
@Component
public class DataFilterContextUtil {
    private static final ThreadLocal<DataFilterContext> threadLocal = new ThreadLocal<>();

    /** Replaces the current thread's context with a fresh, empty one. */
    public DataFilterContextUtil init() {
        threadLocal.set(new DataFilterContext());
        return this;
    }

    public DataFilterContextUtil setMode(SysCodeSqlModeEnum mode) {
        getDataFilterContext().setMode(mode);
        return this;
    }

    /**
     * Sets the mode from its name (case-insensitive, surrounding blanks ignored).
     *
     * @throws IllegalArgumentException if the name is not a SysCodeSqlModeEnum constant
     */
    public DataFilterContextUtil setMode(String mode) {
        SysCodeSqlModeEnum sysCodeSqlModeEnum = SysCodeSqlModeEnum.valueOf(mode.trim().toUpperCase());
        getDataFilterContext().setMode(sysCodeSqlModeEnum);
        return this;
    }

    public DataFilterContextUtil setAliaName(String aliaName) {
        getDataFilterContext().setAliaName(aliaName);
        return this;
    }

    /** Adds or replaces one template value. */
    public DataFilterContextUtil setValue(String name, String value) {
        valueMapOf(getDataFilterContext()).put(name, value);
        return this;
    }

    /** Merges {@code value} into the template values; a {@code null} map is ignored. */
    public DataFilterContextUtil setValueMap(Map<String, String> value) {
        if (value != null) {
            valueMapOf(getDataFilterContext()).putAll(value);
        }
        return this;
    }

    public DataFilterContextUtil setWhereSqlFormat(String whereSqlFormat) {
        getDataFilterContext().setWhereSqlFormat(whereSqlFormat);
        return this;
    }

    public DataFilterContextUtil setSelectFormat(String selectFormat) {
        getDataFilterContext().setSelectFormat(selectFormat);
        return this;
    }

    public DataFilterContextUtil setAdminMode(Boolean adminMode) {
        getDataFilterContext().setAdminMode(adminMode);
        return this;
    }

    public DataFilterContextUtil setIsAdmin(Boolean admin) {
        getDataFilterContext().setIsAdmin(admin);
        return this;
    }

    /** Returns the current thread's context, creating it (via {@link #init()}) if absent. */
    public DataFilterContext getDataFilterContext() {
        DataFilterContext dataFilterContext = threadLocal.get();
        if (dataFilterContext == null) {
            init();
            dataFilterContext = threadLocal.get();
        }
        return dataFilterContext;
    }

    /** True when a context exists for the current thread (does not create one). */
    public boolean isDataFilterContext() {
        return threadLocal.get() != null;
    }

    /** Removes the current thread's context. */
    public void clear() {
        threadLocal.remove();
    }

    /** Returns the context's value map, recreating it if it was set to {@code null}. */
    private static Map<String, String> valueMapOf(DataFilterContext context) {
        Map<String, String> valueMap = context.getValueMap();
        if (valueMap == null) {
            valueMap = new HashMap<>();
            context.setValueMap(valueMap);
        }
        return valueMap;
    }

    /**
     * Context model (accessors generated by Lombok {@code @Data}).
     */
    @Data
    public static class DataFilterContext {
        public DataFilterContext(SysCodeSqlModeEnum mode) {
            this.mode = mode;
            valueMap = new HashMap<>();
        }

        public DataFilterContext() {
            valueMap = new HashMap<>();
        }

        /** SQL generation mode; null means "use the global setting". */
        private SysCodeSqlModeEnum mode;
        /** Table alias substituted for ${ALIAS_NAME} in the where template. */
        private String aliaName;
        /** Whether admin mode is enabled; null means "use the global setting". */
        private Boolean adminMode;
        /** Whether the current caller is an administrator. */
        private Boolean isAdmin;
        /** Template values keyed by placeholder name. */
        private Map<String, String> valueMap;
        /** Where-clause template; empty means "use the global setting". */
        String whereSqlFormat = "";
        /** Select-mode template; empty means "use the global setting". */
        String selectFormat = "";
    }

    public enum SysCodeSqlModeEnum {
        /** Modify the WHERE clause. */
        WHERE,
        /** Wrap the query as a sub-query. */
        SELECT;
    }
}
MybatisDataFilterProperties 配置文件
package mybatis.properties;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
/**
 * Global data-filter settings bound from the {@code mybatis.data-filter} configuration prefix
 * (e.g. {@code mybatis.data-filter.where-sql-format}).
 *
 * <p>Each value is a fallback. It is used when the current call's context (populated from
 * {@code @DataFilter}) does not set the corresponding option.
 */
@Component
@ConfigurationProperties(prefix = "mybatis.data-filter")
@Data
public class MybatisDataFilterProperties {
/**
 * Where-clause template, e.g. {@code WHERE SYS_CODE IN (${SYS_CODE})}; placeholders are filled
 * from the template values. When empty, the interceptor uses that same template as its default.
 */
private String whereSqlFormat;
/**
 * Template for SELECT mode (the query is wrapped as a sub-query), e.g.
 * {@code SELECT * FROM (${SQL}) t WHERE ${CONDITION}}; when empty, the interceptor uses that
 * template as its default.
 */
private String selectFormat;
/**
 * Generation mode name: WHERE or SELECT. The value is upper-cased before parsing, so it is
 * case-insensitive. When empty, the interceptor falls back to WHERE.
 */
private String mode;
/**
 * Whether admin mode is enabled. When it is, callers reported as administrators are not
 * filtered. When unset, the interceptor treats admin mode as enabled.
 */
private Boolean adminMode;
}
AbstractAopPointcut 抽象的 Spring AOP 拦截器,拦截 @DataFilter 注解
package mybatis.aop;
import mybatis.annotation.DataFilter;
import mybatis.util.ClassUtil;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.Signature;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.reflect.MethodSignature;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
/**
 * Base class providing the {@code @Around} advice for methods annotated with {@link DataFilter}.
 *
 * <p>Concrete aspects declare {@link #pointcut()} and override the preHandle / afterHandle /
 * exceptionHandle hooks.
 */
public abstract class AbstractAopPointcut {

    /** Pointcut declared by the concrete aspect; combined with the {@code DataFilter} binding. */
    public abstract void pointcut();

    /**
     * Runs the hooks around the intercepted method.
     *
     * <p>Normal flow: preHandle, then the method, then afterHandle. If preHandle, the method or
     * afterHandle throws an {@link Exception}, exceptionHandle is called: a non-null result from
     * it is returned instead, otherwise the exception is rethrown. Other throwables (e.g.
     * {@link Error}) cannot be passed to exceptionHandle. They still trigger afterHandle, so
     * per-call state such as thread-locals is released, and are then rethrown.
     */
    @Around("pointcut() && @annotation(dataFilter)")
    public Object around(ProceedingJoinPoint pjp, DataFilter dataFilter) throws Throwable {
        PointParam pointParam = buildPointParam(pjp);
        pointParam.setMatchAnnotation(dataFilter);
        Object returnObj;
        try {
            // Inside the try so a failing preHandle still reaches exceptionHandle for cleanup.
            preHandle(pointParam);
            Object[] arguments = pointParam.getArguments();
            returnObj = arguments == null ? pjp.proceed() : pjp.proceed(arguments);
            pointParam.setReturnObj(returnObj);
            afterHandle(pointParam);
        } catch (Exception e) {
            Object fallback = exceptionHandle(pointParam, e);
            if (fallback == null) {
                throw e;
            }
            return fallback;
        } catch (Throwable t) {
            afterHandle(pointParam);
            throw t;
        }
        return returnObj;
    }

    /**
     * Hook run before the intercepted method.
     *
     * @param pointParam details of the call
     */
    public void preHandle(PointParam pointParam) {
    }

    /**
     * Hook run after the intercepted method (also after a non-Exception throwable).
     *
     * @param pointParam details of the call, including its return value when it completed
     */
    public void afterHandle(PointParam pointParam) {
    }

    /**
     * Hook run when the call throws an Exception.
     *
     * @param pointParam details of the call
     * @param e          the exception
     * @return a fallback result, or {@code null} to rethrow {@code e}
     */
    public Object exceptionHandle(PointParam pointParam, Exception e) {
        return null;
    }

    /** Collects target, class annotations, method, method annotations and arguments of the call. */
    private PointParam buildPointParam(ProceedingJoinPoint pjp) {
        PointParam pointParam = new PointParam();
        Object target = pjp.getTarget();
        Class<?> pointClass = target.getClass();
        pointParam.setPointClass(pointClass);
        pointParam.setPointObject(target);
        pointParam.setClassAnnotations(ClassUtil.getAnnotations(pointClass));
        setMethodParam(pjp, pointClass, pointParam);
        return pointParam;
    }

    /**
     * Fills the method, its annotations and the call arguments.
     *
     * <p>Arguments are left null when the method takes none.
     */
    private void setMethodParam(ProceedingJoinPoint pjp, Class<?> pointClass, PointParam pointParam) {
        Signature signature = pjp.getSignature();
        Object[] args = pjp.getArgs();
        if (args != null && args.length > 0) {
            pointParam.setArguments(args);
        }
        Method pointMethod = signature instanceof MethodSignature
                ? ((MethodSignature) signature).getMethod()
                : findMethod(pointClass, signature.getName(), args);
        pointParam.setPointMethod(pointMethod);
        pointParam.setMethodAnnotations(pointMethod == null
                ? Arrays.<Annotation>asList()
                : Arrays.asList(pointMethod.getAnnotations()));
    }

    /**
     * Fallback lookup when the signature is not a MethodSignature. Returns the first public method
     * with the given name and argument count, else the first with the given name, else null.
     */
    private static Method findMethod(Class<?> pointClass, String methodName, Object[] args) {
        int argCount = args == null ? 0 : args.length;
        Method byName = null;
        for (Method method : pointClass.getMethods()) {
            if (!methodName.equals(method.getName())) {
                continue;
            }
            if (method.getParameterTypes().length == argCount) {
                return method;
            }
            if (byName == null) {
                byName = method;
            }
        }
        return byName;
    }
}
DataFilterAopPointcut 继承 AbstractAopPointcut 的数据过滤 AOP 拦截器,把 @DataFilter 的配置设置到 DataFilterContextUtil 中
package mybatis.aop;
import mybatis.annotation.DataFilter;
import mybatis.util.DataFilterContextUtil;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.stereotype.Component;
import java.util.Map;
/**
 * Default aspect for {@link DataFilter}.
 *
 * <p>Before the annotated method runs, it fills the thread-bound DataFilterContextUtil context
 * from the annotation and from the optional DataFilterValueHandle bean. The context is cleared
 * when the method completes or throws.
 */
@Component
@Aspect
@ConditionalOnMissingBean(AbstractAopPointcut.class)
public class DataFilterAopPointcut extends AbstractAopPointcut {
    public Logger logger = LoggerFactory.getLogger("mybatisDataFilter");
    @Autowired(required = false)
    DataFilterValueHandle dataFilterValueHandle;
    @Autowired
    private DataFilterContextUtil dataFilterContextUtil;

    /**
     * Matches methods annotated with {@link DataFilter}. The fully-qualified name must match the
     * package the annotation is declared in ({@code mybatis.annotation}); the previous
     * {@code com.sf.mybatis.annotation.DataFilter} did not refer to it.
     */
    @Override
    @Pointcut(value = "@annotation(mybatis.annotation.DataFilter)")
    public void pointcut() {
    }

    /** Copies the annotation settings and the handle's values/admin flag into a fresh context. */
    @Override
    public void preHandle(PointParam pointParam) {
        DataFilter dataFilter = (DataFilter) pointParam.getMatchAnnotation();
        // Start from a fresh context so settings left on this (possibly pooled) thread by an
        // earlier call cannot leak into this one. A nested @DataFilter call therefore does not
        // inherit the outer call's settings.
        dataFilterContextUtil.init();
        if (dataFilterValueHandle == null) {
            // The handle is optional (required = false); without it no values are supplied.
            logger.warn("no DataFilterValueHandle bean found, data filter values are not set");
        } else {
            Map<String, String> valueMap = dataFilterValueHandle.valueMap(pointParam);
            if (valueMap != null) {
                dataFilterContextUtil.setValueMap(valueMap);
            }
            dataFilterContextUtil.setIsAdmin(dataFilterValueHandle.isAdmin(pointParam));
        }
        dataFilterContextUtil.setAliaName(dataFilter.aliaName());
        if (StringUtils.isNotEmpty(dataFilter.adminMode())) {
            dataFilterContextUtil.setAdminMode(Boolean.valueOf(dataFilter.adminMode()));
        }
        if (StringUtils.isNotEmpty(dataFilter.mode())) {
            dataFilterContextUtil.setMode(dataFilter.mode());
        }
        if (StringUtils.isNotEmpty(dataFilter.whereSqlFormat())) {
            dataFilterContextUtil.setWhereSqlFormat(dataFilter.whereSqlFormat());
        }
        if (StringUtils.isNotEmpty(dataFilter.selectFormat())) {
            dataFilterContextUtil.setSelectFormat(dataFilter.selectFormat());
        }
    }

    /** Removes the thread-bound context after the method completes. */
    @Override
    public void afterHandle(PointParam pointParam) {
        dataFilterContextUtil.clear();
    }

    /** Removes the thread-bound context and lets the exception propagate. */
    @Override
    public Object exceptionHandle(PointParam pointParam, Exception e) {
        dataFilterContextUtil.clear();
        return null;
    }

    public DataFilterValueHandle getDataFilterValueHandle() {
        return dataFilterValueHandle;
    }

    public void setDataFilterValueHandle(DataFilterValueHandle dataFilterValueHandle) {
        this.dataFilterValueHandle = dataFilterValueHandle;
    }
}
@DataFilter
package mybatis.annotation;
import mybatis.util.DataFilterContextUtil;
import java.lang.annotation.*;
/**
 * Enables automatic data filtering of the MyBatis SELECT queries issued inside the annotated method.
 *
 * <p>String elements default to "" meaning "not set". For mode, whereSqlFormat, selectFormat and
 * adminMode, an unset value falls back to the global mybatis.data-filter properties and then to
 * the interceptor's built-in defaults.
 *
 * <p>NOTE(review): TYPE is a permitted target, but the bundled aspect's pointcut uses the AspectJ
 * annotation designator, which matches method-level annotations. Class-level use is presumably
 * ignored; confirm before relying on it.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE,ElementType.METHOD})
public @interface DataFilter {
/**
 * SQL generation mode, see DataFilterContextUtil.SysCodeSqlModeEnum. The name is upper-cased
 * before parsing.
 *
 * <p>WHERE: modifies the WHERE clause using the whereSqlFormat template. Example: the template
 * {@code WHERE SYS_CODE IN (${SYS_CODE}) and aa = ${aa}} turns {@code SELECT A,B,C FROM TT} into
 * {@code SELECT A,B,C FROM TT WHERE SYS_CODE IN (1,2,3) and aa = aa}.
 *
 * <p>SELECT: wraps the query as a sub-query using the selectFormat template. Example: the
 * template {@code SELECT * FROM (${SQL}) t where ${CONDITION}} turns {@code SELECT A,B,C FROM TT}
 * into {@code SELECT * FROM (SELECT A,B,C FROM TT) t WHERE SYS_CODE IN (1,2,3) and aa = aa}.
 * This mode requires the original select list to include SYS_CODE.
 *
 * @return the mode name, or "" to use the global setting (WHERE when nothing is configured)
 */
String mode() default "";
/**
 * Table alias substituted for ${ALIAS_NAME} in the where template. When empty, a
 * "${ALIAS_NAME}." prefix is stripped from the template instead. SELECT mode overrides the
 * alias with "t".
 */
String aliaName() default "";
/**
 * Where-clause template, e.g. {@code WHERE SYS_CODE IN (${SYS_CODE})}. Placeholders are filled
 * from DataFilterValueHandle#valueMap.
 */
String whereSqlFormat() default "";
/**
 * SELECT-mode template, e.g. {@code SELECT * FROM (${SQL}) t where ${CONDITION}}. ${SQL}
 * receives the original query and ${CONDITION} the rendered where condition.
 */
String selectFormat() default "";
/**
 * Admin-mode switch. When enabled and DataFilterValueHandle#isAdmin reports an administrator,
 * no filtering is applied. The string is parsed with Boolean semantics: only "true" (ignoring
 * case) enables it.
 */
String adminMode() default "";
}
PointParam
package mybatis.aop;
import lombok.Data;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.List;
/**
 * Details of an intercepted method call, passed to the AbstractAopPointcut hooks and to
 * DataFilterValueHandle. Getters and setters are generated by Lombok {@code @Data}.
 */
@Data
public class PointParam {
/**
 * Runtime class of the target object.
 */
private Class pointClass;
/**
 * The target object whose method was intercepted.
 */
private Object pointObject;
/**
 * Annotations present on the target class.
 */
private List<Annotation> classAnnotations;
/**
 * The intercepted method, as resolved from the join point.
 */
private Method pointMethod;
/**
 * Annotations present on pointMethod.
 */
private List<Annotation> methodAnnotations;
/**
 * Call arguments; left null when the method takes no arguments.
 */
private Object[] arguments;
/**
 * Return value of the intercepted call, if the aspect records it.
 */
private Object returnObj;
/**
 * The annotation instance bound by the advice (the matched DataFilter).
 */
private Annotation matchAnnotation;
}
ClassUtil
package mybatis.util;
import java.lang.annotation.Annotation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Reflection helpers for reading annotations.
 *
 * <p>Annotation instances returned by reflection are JDK proxies, so {@code getClass()} on them
 * yields a generated proxy class. The annotation's own interface must be obtained with
 * {@link Annotation#annotationType()}; comparing {@code getClass()} never matched.
 */
public class ClassUtil {

    /** Returns the runtime-visible annotations present on {@code aClass}, including inherited ones. */
    public static List<Annotation> getAnnotations(Class<?> aClass) {
        return Arrays.asList(aClass.getAnnotations());
    }

    /** Maps each annotation to its annotation interface (e.g. {@code FunctionalInterface.class}). */
    public static List<Class<? extends Annotation>> getAnnotationsClass(List<Annotation> annotations) {
        List<Class<? extends Annotation>> list = new ArrayList<>(annotations.size());
        for (Annotation annotation : annotations) {
            list.add(annotation.annotationType());
        }
        return list;
    }

    /**
     * Returns the annotation of type {@code matchAnClass} present on {@code aClass}.
     *
     * @return the annotation, or {@code null} if absent
     */
    public static Annotation getAnnotation(Class<?> aClass, Class<? extends Annotation> matchAnClass) {
        return matchAnnotation(getAnnotations(aClass), matchAnClass);
    }

    /**
     * Finds the annotation of type {@code matchAnClass} in {@code list}.
     *
     * @param list         annotations to search
     * @param matchAnClass the annotation interface to look for
     * @return the first matching annotation, or {@code null} if none matches
     */
    public static Annotation matchAnnotation(List<Annotation> list, Class<? extends Annotation> matchAnClass) {
        for (Annotation an : list) {
            if (an.annotationType().equals(matchAnClass)) {
                return an;
            }
        }
        return null;
    }
}
DataFilterValueHandle 业务系统集成设置模板替换的值
package mybatis.aop;
import java.util.Map;
/**
 * Extension point implemented by the business system to supply data-filter values for a call
 * annotated with DataFilter.
 *
 * <p>NOTE(review): the interceptor substitutes these values into the SQL templates as plain text
 * (no bind parameters). Implementations must only return trusted, properly quoted values.
 */
public interface DataFilterValueHandle {
/**
 * Returns the template values for the current call, keyed by placeholder name
 * (e.g. "SYS_CODE" for a template containing ${SYS_CODE}).
 *
 * <p>NOTE(review): how an empty or null map is handled depends on
 * DataFilterSqlInterceptor.buildSql and the aspect; check them before returning one.
 *
 * @param pointParam details of the intercepted method call
 * @return placeholder values for the SQL templates
 */
Map<String,String> valueMap(PointParam pointParam);
/**
 * Whether the current caller is an administrator. When admin mode is enabled (the default when
 * unconfigured), an administrator's queries are not filtered.
 *
 * @param pointParam details of the intercepted method call
 * @return {@code true} if the caller is an administrator
 */
boolean isAdmin(PointParam pointParam);
}
spring.factories
# Class names must match the package declarations of the classes above (mybatis.*).
org.springframework.boot.autoconfigure.EnableAutoConfiguration=\
mybatis.properties.MybatisDataFilterProperties,\
mybatis.interceptor.DataFilterSqlInterceptor,\
mybatis.util.DataFilterContextUtil,\
mybatis.aop.DataFilterAopPointcut