基于Mybatis 数据过滤组件(一)

基于 MyBatis Interceptor 机制实现的组件,可自动为 SQL 添加数据过滤条件。

背景

  1. 我们的系统要求所有数据都按照系统编码过滤,因此所有查询 SQL 都需要写上 SYS_CODE IN (所属系统编码) 这样的条件。
  2. 我们的系统还要求按创建人过滤,因此所有的查询、修改、删除语句都需要加上创建人的条件(注意:下文的实现目前只改写 SELECT 语句)。
  3. 使用本组件后无需改动 SQL,即可自动完成过滤条件的添加,用法与流行的 PageHelper 分页插件一样简单。

核心类

  • AbstractSqlInterceptor 实现 MyBatis Interceptor 接口的抽象拦截器
  • DataFilterSqlInterceptor 继承AbstractSqlInterceptor 的数据过滤组件
  • DataFilterContextUtil 过滤配置上下文
  • MybatisDataFilterProperties 配置文件
  • AbstractAopPointcut 抽象的 Spring AOP 拦截器,拦截 @DataFilter 注解
  • DataFilterAopPointcut 继承 AbstractAopPointcut 的数据过滤 AOP 拦截器。将 @DataFilter 的配置设置到 DataFilterContextUtil 中,并调用 DataFilterValueHandle 设置模板变量的值
  • DataFilterValueHandle 业务系统集成设置模板替换的值

核心代码

AbstractSqlInterceptor 实现 MyBatis Interceptor 接口的抽象拦截器

package mybatis.interceptor;

import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Properties;

/**
 * Abstract SQL interceptor for MyBatis {@code Executor} calls.
 *
 * <p>Subclasses declare the {@code @Intercepts} signature and implement {@link #match} (whether the
 * current call should be rewritten) and {@link #buildSql} (the rewritten SQL). Only {@code SELECT}
 * statements are rewritten. The intercepted method is expected to take the {@link MappedStatement} as
 * {@code args[0]} and the parameter object as {@code args[1]}, as
 * {@code Executor.query(MappedStatement, Object, RowBounds, ResultHandler)} does.
 */
public abstract class AbstractSqlInterceptor implements Interceptor {

    /** Shared logger of the data-filter component (also used by subclasses). */
    public Logger logger = LoggerFactory.getLogger("mybatisDataFilter");

    /**
     * Decides whether the current invocation should be rewritten. Implemented by subclasses.
     *
     * @param invocation the intercepted executor call
     * @return {@code true} to rewrite the SQL, {@code false} to run it unchanged
     */
    public abstract boolean match(Invocation invocation);

    /**
     * Builds the new SQL. Implemented by subclasses.
     *
     * @param boundSql the statement's original bound SQL
     * @param origSql  the original SQL text
     * @return the rewritten SQL; {@code null}, a blank string or the unchanged {@code origSql} leaves the
     *         statement untouched
     */
    public abstract String buildSql(BoundSql boundSql,String origSql);

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        if (!match(invocation)) {
            return invocation.proceed();
        }

        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object parameterObject = args[1];

        // Only SELECT statements are rewritten.
        if (ms.getSqlCommandType() != SqlCommandType.SELECT) {
            return invocation.proceed();
        }

        BoundSql boundSql = ms.getBoundSql(parameterObject);
        String origSql = boundSql.getSql();
        logger.debug("原始SQL: {}", origSql);

        String newSql = buildSql(boundSql, origSql);
        if (newSql == null || newSql.trim().isEmpty() || newSql.equals(origSql)) {
            // Nothing to rewrite: never hand an empty statement to the driver.
            return invocation.proceed();
        }

        BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), newSql,
                boundSql.getParameterMappings(), boundSql.getParameterObject());
        // Carry over the additional parameters MyBatis dynamic SQL (e.g. <foreach>/<bind>) registered on the
        // original BoundSql; without them the rewritten statement cannot resolve those placeholders.
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)) {
                newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }

        // Swap the statement in the invocation's argument array; proceed() then runs the rewritten SQL.
        args[0] = newMappedStatement(ms, new BoundSqlSqlSource(newBoundSql));
        logger.debug("改写的SQL: {}", newSql);

        return invocation.proceed();
    }

    /**
     * {@link SqlSource} that always returns one pre-built {@link BoundSql}.
     * Static: it needs no reference to the enclosing interceptor.
     */
    static class BoundSqlSqlSource implements SqlSource {
        private final BoundSql boundSql;

        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }

    /**
     * Copies {@code ms} with a different SQL source, keeping all other statement settings.
     */
    private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new MappedStatement.Builder(
                ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        // The builder takes comma-separated lists; copying only the first entry would drop composite keys.
        builder.keyProperty(joinOrNull(ms.getKeyProperties()));
        builder.keyColumn(joinOrNull(ms.getKeyColumns()));
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.resultSets(joinOrNull(ms.getResultSets()));
        builder.resultOrdered(ms.isResultOrdered());
        builder.databaseId(ms.getDatabaseId());
        builder.lang(ms.getLang());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    /** Joins {@code values} with commas, or returns {@code null} when there is nothing to join. */
    private static String joinOrNull(String[] values) {
        return values == null || values.length == 0 ? null : String.join(",", values);
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // No plugin properties are used; configuration comes from the Spring context.
    }
}

DataFilterSqlInterceptor 继承AbstractSqlInterceptor 的数据过滤组件

package mybatis.interceptor;

import com.alibaba.fastjson.JSON;
import mybatis.properties.MybatisDataFilterProperties;
import mybatis.util.DataFilterContextUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringSubstitutor;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.HashMap;
import java.util.Map;


/**
 * 系统编码sql自动生成拦截器
 *
 * 配置优先级:DataFilter ->DataFilterContextUtil -> MybatisDataFilterProperties
 */
@Intercepts({@Signature(type = Executor.class, method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class,ResultHandler.class})})
@Component
@Slf4j
public class DataFilterSqlInterceptor extends AbstractSqlInterceptor{

    @Autowired
    private DataFilterContextUtil dataFilterContextUtil;

    private final static String WHERE_SQL = "WHERE SYS_CODE IN (${SYS_CODE})";

    private String whereSqlFormat;

    private final static String AND = "AND";

    private final static String SELECT_FORMAT = "SELECT * FROM (${SQL}) t WHERE ${CONDITION}";

    private String selectFormat;

    private final static String POINT = ".";
    //检查过滤key
    private final static String CHECK_KEY[] = {"ORDER","LIMIT"};

    private final static String WHERE = "WHERE";

    private final static String ALIAS_NAME = "ALIAS_NAME";

    private final static String ALIAS_NAME_FORMAT = "${ALIAS_NAME}";

    @Autowired
    MybatisDataFilterProperties mybatisDataFilterProperties;


    @Override
    public boolean match(Invocation invocation) {
        if (!dataFilterContextUtil.isDataFilterContext()){
            return false;
        }

        DataFilterContextUtil.DataFilterContext dataFilterContext = dataFilterContextUtil.getDataFilterContext();

        Boolean isAdmin = dataFilterContext.getIsAdmin() == null ? false : dataFilterContext.getIsAdmin();
        //是管理员
        if (getAdminMode() && isAdmin){
            logger.debug("----------admin is not filter");
            return false;
        }
        return true;
    }

    @Override
    public String buildSql(BoundSql boundSql, String origSql) {

        DataFilterContextUtil.DataFilterContext dataFilterContext = dataFilterContextUtil.getDataFilterContext();

        logger.debug("开始构建sql mod:{}",getMode());
//        List<UserSystemDetailDTO> userList = (List<UserSystemDetailDTO>) ApplicationContextUtil.getValue(ApplicationContextUtil.USER_SYS_CODE);
//        String sysCodes = ZkSysCodeFactory.getUserListStr(userList);
        //转换大写
        String upperSql = origSql.toUpperCase();
        String newSql = origSql;
        Map<String,String> valueMap = dataFilterContext.getValueMap();
        logger.debug("buildSql valueMap:{} format:{}",JSON.toJSONString(valueMap),dataFilterContextUtil.getDataFilterContext().getWhereSqlFormat());
        if (valueMap == null || valueMap.isEmpty()){
            logger.warn("valueMap is null 不进行sql处理");
            return "";
        }

        if (DataFilterContextUtil.SysCodeSqlModeEnum.WHERE.equals(getMode())){
            newSql = buildWhereSql(upperSql,valueMap);
        }else if (DataFilterContextUtil.SysCodeSqlModeEnum.SELECT.equals(getMode())){
            newSql = buildSelectSql(upperSql,valueMap);
        }else {
            throw new RuntimeException("mode is not WHERE OR SELECT");
        }
        return newSql;
    }

    /**
     * select模式的sql生成
     *
     * 生成select * ( 原sql) t where $whereCondition
     * @param origSql
     * @param valueMap
     * @return
     */
    private String buildSelectSql(String origSql,Map<String,String> valueMap){
        DataFilterContextUtil.DataFilterContext dataFilterContext = dataFilterContextUtil.getDataFilterContext();
        dataFilterContext.setAliaName("t");
        //条件
        String whereCondition = buildWhereCondition(valueMap);

        logger.debug("改写的whereCondition:"+whereCondition);

        String selFormat = getSelectFormat();
        //去除where
        selFormat = selFormat.replace(WHERE,"");

        selFormat = selFormat.replace(WHERE.toLowerCase(),"");

        if (whereCondition.indexOf(WHERE) == -1 && whereCondition.indexOf(WHERE.toLowerCase()) == -1){
            whereCondition = WHERE+" "+whereCondition;
        }

        //select模板数据
        Map<String,String> selectValueMap = new HashMap<>();
        selectValueMap.put("SQL",origSql);
        selectValueMap.put("CONDITION",whereCondition);
//        return String.format(SELECT_FORMAT,origSql)+""+whereCondition;

        return format(selFormat,selectValueMap);
    }

    /**
     * 将 where 条件替换成 WHERE SYS_CODE IN () 格式
     * 如果没有where 则新生成一个where
     * @param origSql
     * @return
     */
    private String buildWhereSql(String origSql,Map<String,String> valueMap){
        //where 条件
        String whereCondition = buildWhereCondition(valueMap);

        logger.debug("改写的whereCondition:"+whereCondition);

        if (whereCondition.indexOf(WHERE) == -1 && whereCondition.indexOf(WHERE.toLowerCase()) == -1){
            whereCondition = WHERE+" "+whereCondition;
        }

        if (origSql.contains(WHERE)){
            whereCondition = whereCondition+" "+AND+" ";
            //替换where
            return origSql.replaceFirst(WHERE,whereCondition);
        }else {
            //添加where
            return addWhereCondition(origSql,whereCondition);
        }
    }

    /**
     * 构建where 条件
     * @return
     */
    private String buildWhereCondition(Map<String,String> valueMap){
        if (valueMap == null || valueMap.isEmpty()){
            return "";
        }
        DataFilterContextUtil.DataFilterContext dataFilterContext = dataFilterContextUtil.getDataFilterContext();
        String aliaName = dataFilterContext.getAliaName();
        if (StringUtils.isNotEmpty(aliaName)){
            valueMap.put(ALIAS_NAME,aliaName);
        }
        logger.debug("buildWhereCondition valueMap:{} format:{}",JSON.toJSONString(valueMap),dataFilterContextUtil.getDataFilterContext().getWhereSqlFormat());
        return format(getWhereSqlFormat(),valueMap);
    }

    /**
     * 格式化
     * @param format
     * @param valueMap
     * @return
     */
    private String format(String format,Map<String,String> valueMap){
        StringSubstitutor sub = new StringSubstitutor(valueMap);
        String formatSql = sub.replace(format);
        return formatSql;
    }

    /**
     * 添加where条件
     * select * from tt order/limit  -> select * from tt $whereCondition order/limit
     * @param origSql
     * @param whereCondition
     * @return
     */
    public String addWhereCondition(String origSql,String whereCondition){
        //如果包含了CHECK_KEY的字段,则where条件在其之前
        for (String key : CHECK_KEY){
            int index = origSql.indexOf(key);
            if (index != -1){
                String subStr = origSql.substring(0,index);
                String lastSubStr = origSql.substring(index,origSql.length());
                return subStr+" "+whereCondition+lastSubStr;
            }
        }
        return origSql + " "+whereCondition;
    }

    public String getWhereSqlFormat() {
        DataFilterContextUtil.DataFilterContext dataFilterContext = dataFilterContextUtil.getDataFilterContext();
        whereSqlFormat = dataFilterContext.getWhereSqlFormat();
        if (StringUtils.isNotEmpty(whereSqlFormat)){
            String whereSql = aliasNameHandle(whereSqlFormat);
            logger.debug("whereSqlFormat sql:{}",whereSqlFormat);
            return whereSql;
        }

        whereSqlFormat = mybatisDataFilterProperties.getWhereSqlFormat();
        logger.debug("whereSqlFormat:{}",whereSqlFormat);
        if (StringUtils.isNotEmpty(whereSqlFormat)){
            String whereSql = aliasNameHandle(whereSqlFormat);
            logger.debug("whereSqlFormat sql:{}",whereSqlFormat);
            return whereSql;
        }
        return WHERE_SQL;
    }

    private String aliasNameHandle(String whereSql){
        DataFilterContextUtil.DataFilterContext dataFilterContext = dataFilterContextUtil.getDataFilterContext();
        if (whereSql != null && StringUtils.isEmpty(dataFilterContext.getAliaName())){
            return whereSql.replace(ALIAS_NAME_FORMAT+POINT,"");
        }
        return whereSql;
    }

    public DataFilterContextUtil.SysCodeSqlModeEnum getMode(){
        DataFilterContextUtil.DataFilterContext dataFilterContext = dataFilterContextUtil.getDataFilterContext();
        DataFilterContextUtil.SysCodeSqlModeEnum sysCodeSqlModeEnum = dataFilterContext.getMode();
        if (sysCodeSqlModeEnum == null){
            String globleMode = mybatisDataFilterProperties.getMode();
            if (StringUtils.isNotEmpty(globleMode)){
                dataFilterContextUtil.setMode(globleMode);
                sysCodeSqlModeEnum = dataFilterContext.getMode();
            }
        }

        return sysCodeSqlModeEnum == null ? DataFilterContextUtil.SysCodeSqlModeEnum.WHERE : sysCodeSqlModeEnum;
    }

    public String getSelectFormat() {
        DataFilterContextUtil.DataFilterContext dataFilterContext = dataFilterContextUtil.getDataFilterContext();
        selectFormat = dataFilterContext.getSelectFormat();
        if (StringUtils.isNotEmpty(selectFormat)){
            return selectFormat;
        }

        selectFormat = mybatisDataFilterProperties.getSelectFormat();
        if (StringUtils.isNotEmpty(selectFormat)){
            return selectFormat;
        }
        return SELECT_FORMAT;
    }

    public boolean getAdminMode(){
        DataFilterContextUtil.DataFilterContext dataFilterContext = dataFilterContextUtil.getDataFilterContext();
        Boolean adminMode = dataFilterContext.getAdminMode();
        if (adminMode == null){
            adminMode = mybatisDataFilterProperties.getAdminMode();
            if (adminMode != null){
                dataFilterContextUtil.setAdminMode(adminMode);
                adminMode = dataFilterContext.getAdminMode();
            }
        }
        return adminMode == null ? true : adminMode;
    }

    /*public static void main(String[] args) {
        String origSql = "SELECT * FROM AA Inner join b on AA.A=B.B ORDER AA.KEY DESC Limit 10";
        for (String key : CHECK_KEY){
            int index = origSql.indexOf(key);
            if (index != -1){
                String subStr = origSql.substring(0,index);
                String lastSubStr = origSql.substring(index,origSql.length());
                System.out.println(subStr);
                System.out.println(lastSubStr);
            }
        }
    }*/

    public DataFilterContextUtil getDataFilterContextUtil() {
        return dataFilterContextUtil;
    }

    public void setDataFilterContextUtil(DataFilterContextUtil dataFilterContextUtil) {
        this.dataFilterContextUtil = dataFilterContextUtil;
    }

    public static String getWhereSql() {
        return WHERE_SQL;
    }

    public void setWhereSqlFormat(String whereSqlFormat) {
        this.whereSqlFormat = whereSqlFormat;
    }


    public void setSelectFormat(String selectFormat) {
        this.selectFormat = selectFormat;
    }

    public MybatisDataFilterProperties getMybatisDataFilterProperties() {
        return mybatisDataFilterProperties;
    }

    public void setMybatisDataFilterProperties(MybatisDataFilterProperties mybatisDataFilterProperties) {
        this.mybatisDataFilterProperties = mybatisDataFilterProperties;
    }
}

DataFilterContextUtil 过滤配置上下文

package mybatis.util;

import lombok.Data;
import org.springframework.stereotype.Component;

import java.util.HashMap;
import java.util.Map;

/**
 * Holds the data-filter configuration of the current thread.
 *
 * <p>The context lives in a {@link ThreadLocal}. The setters create it lazily, {@link #isDataFilterContext()}
 * reports whether one exists, and {@link #clear()} removes it. It must be cleared when the filtered work ends,
 * otherwise the state stays on the (possibly pooled) thread. All setters return {@code this} for chaining.
 */
@Component
public class DataFilterContextUtil {

    private static final ThreadLocal<DataFilterContext> CONTEXT = new ThreadLocal<>();

    /** Replaces the current thread's context with a fresh, empty one. */
    public DataFilterContextUtil init(){
        CONTEXT.set(new DataFilterContext());
        return this;
    }

    public DataFilterContextUtil setMode(SysCodeSqlModeEnum mode){
        getDataFilterContext().setMode(mode);
        return this;
    }

    /**
     * Sets the mode by name (case-insensitive, surrounding blanks ignored).
     *
     * @throws IllegalArgumentException if the name is not a {@link SysCodeSqlModeEnum} constant
     */
    public DataFilterContextUtil setMode(String mode){
        // Parse first so an invalid name does not create a context as a side effect; Locale.ROOT keeps the
        // lookup independent of the JVM default locale.
        SysCodeSqlModeEnum sysCodeSqlModeEnum = SysCodeSqlModeEnum.valueOf(mode.trim().toUpperCase(java.util.Locale.ROOT));
        getDataFilterContext().setMode(sysCodeSqlModeEnum);
        return this;
    }

    public DataFilterContextUtil setAliaName(String aliaName){
        getDataFilterContext().setAliaName(aliaName);
        return this;
    }

    /** Adds one template value. */
    public DataFilterContextUtil setValue(String name,String value){
        getDataFilterContext().getValueMap().put(name, value);
        return this;
    }

    /** Merges {@code value} into the template values; a {@code null} map is ignored. */
    public DataFilterContextUtil setValueMap(Map<String,String> value){
        if (value != null) {
            getDataFilterContext().getValueMap().putAll(value);
        }
        return this;
    }

    public DataFilterContextUtil setWhereSqlFormat(String whereSqlFormat){
        getDataFilterContext().setWhereSqlFormat(whereSqlFormat);
        return this;
    }

    public DataFilterContextUtil setSelectFormat(String selectFormat){
        getDataFilterContext().setSelectFormat(selectFormat);
        return this;
    }

    public DataFilterContextUtil setAdminMode(Boolean adminMode) {
        getDataFilterContext().setAdminMode(adminMode);
        return this;
    }

    public DataFilterContextUtil setIsAdmin(Boolean admin) {
        getDataFilterContext().setIsAdmin(admin);
        return this;
    }

    /** Returns the current thread's context, creating an empty one if none exists. */
    public DataFilterContext getDataFilterContext(){
        DataFilterContext dataFilterContext = CONTEXT.get();
        if (dataFilterContext == null){
            dataFilterContext = new DataFilterContext();
            CONTEXT.set(dataFilterContext);
        }
        return dataFilterContext;
    }

    /** Whether the current thread has a context (does not create one). */
    public boolean isDataFilterContext(){
        return CONTEXT.get() != null;
    }

    /** Removes the current thread's context. */
    public  void clear(){
        CONTEXT.remove();
    }


    /**
     * Context model.
     */
    @Data
    public static class DataFilterContext {
        public DataFilterContext(SysCodeSqlModeEnum mode){
            this.mode = mode;
            valueMap = new HashMap<>();
        }

        public DataFilterContext(){
            valueMap = new HashMap<>();
        }

        /**
         * Rewrite mode.
         */
        private SysCodeSqlModeEnum mode;
        /**
         * Table alias.
         */
        private String aliaName;

        /**
         * Whether admin mode is enabled.
         */
        private Boolean adminMode;

        /**
         * Whether the current user is an administrator.
         */
        private Boolean isAdmin;

        /**
         * Template values.
         */
        private Map<String,String> valueMap;

        /**
         * Where-clause template.
         */
        String whereSqlFormat = "";

        /**
         * Select-mode template.
         */
        String selectFormat = "";

    }

    public enum SysCodeSqlModeEnum{
        /**
         * Rewrite the where condition.
         */
        WHERE,
        /**
         * Wrap the query as a sub-query.
         */
        SELECT;

    }
}

MybatisDataFilterProperties 配置文件

package mybatis.properties;

import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

/**
 * Global data-filter settings bound from the {@code mybatis.data-filter.*} properties.
 *
 * <p>These are the lowest-priority source. Values set in the thread's filter context (e.g. from
 * {@code @DataFilter}) take precedence.
 */
@Component
@ConfigurationProperties(prefix = "mybatis.data-filter")
@Data
public class MybatisDataFilterProperties {

    /**
     * Where-clause template, e.g. {@code WHERE SYS_CODE IN (${SYS_CODE})}.
     */
    private String whereSqlFormat;

    /**
     * Select-mode template, e.g. {@code SELECT * FROM (${SQL}) t WHERE ${CONDITION}}.
     */
    private String selectFormat;

    // Rewrite mode name ("WHERE" or "SELECT"), parsed case-insensitively into
    // DataFilterContextUtil.SysCodeSqlModeEnum; when unset the interceptor falls back to WHERE.
    private String mode;

    /**
     * Whether admin mode is enabled (administrators are not filtered). When unset the interceptor
     * treats admin mode as enabled.
     */
    private Boolean adminMode;

}

AbstractAopPointcut 抽象的 Spring AOP 拦截器,拦截 @DataFilter 注解

package mybatis.aop;

import mybatis.annotation.DataFilter;
import mybatis.util.ClassUtil;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.Signature;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.reflect.MethodSignature;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;

/**
 * Abstract {@code @Around} advice for {@code @DataFilter}-annotated methods. Aspects extend it, define
 * {@link #pointcut()} and override the hooks.
 *
 * <p>Hook order per call: {@link #preHandle} → method → {@link #afterHandle}. If any of these throws an
 * {@link Exception}, {@link #exceptionHandle} is called. It may return a fallback result; if it returns
 * {@code null}, the exception is rethrown.
 */
public abstract class AbstractAopPointcut {

    public abstract void pointcut();

    @Around("pointcut() && @annotation(dataFilter)")
    public Object around(ProceedingJoinPoint pjp, DataFilter dataFilter) throws Throwable {
        PointParam pointParam = buildPointParam(pjp);
        pointParam.setMatchAnnotation(dataFilter);
        Object result;
        try {
            // preHandle runs inside the try so a failing pre-hook still reaches exceptionHandle, letting the
            // subclass release anything it already set up (e.g. thread-local state).
            preHandle(pointParam);
            Object[] arguments = pointParam.getArguments();
            result = arguments == null ? pjp.proceed() : pjp.proceed(arguments);
            pointParam.setReturnObj(result);
            afterHandle(pointParam);
        } catch (Exception e) {
            result = exceptionHandle(pointParam, e);
            if (result == null){
                throw e;
            }
        }
        return result;
    }

    /**
     * Hook run before the intercepted method.
     *
     * @param pointParam details of the invocation
     */
    public void preHandle(PointParam pointParam) {
    }

    /**
     * Hook run after the intercepted method returned normally.
     *
     * @param pointParam details of the invocation, including the return value
     */
    public void afterHandle(PointParam pointParam) {
    }

    /**
     * Hook run when preHandle, the method or afterHandle threw.
     *
     * @param pointParam details of the invocation
     * @param e          the exception
     * @return a fallback result, or {@code null} to rethrow {@code e}
     */
    public Object exceptionHandle(PointParam pointParam, Exception e) {
        return null;
    }

    /**
     * Collects target class, target object, class annotations, method, method annotations and arguments.
     *
     * @throws NoSuchMethodException if the intercepted method cannot be resolved
     */
    private PointParam buildPointParam(ProceedingJoinPoint pjp) throws NoSuchMethodException {
        PointParam pointParam = new PointParam();
        Class<?> pointClass = pjp.getTarget().getClass();
        pointParam.setPointClass(pointClass);
        pointParam.setPointObject(pjp.getTarget());
        pointParam.setClassAnnotations(ClassUtil.getAnnotations(pointClass));
        setMethodParam(pjp, pointClass, pointParam);
        return pointParam;
    }

    /**
     * Resolves the intercepted method and stores it, its annotations and the (non-empty) arguments.
     */
    private void setMethodParam(ProceedingJoinPoint pjp, Class<?> pointClass, PointParam pointParam) throws NoSuchMethodException {
        Signature signature = pjp.getSignature();
        String methodName = signature.getName();
        Method pointMethod = signature instanceof MethodSignature
                ? ((MethodSignature) signature).getMethod()
                : null;

        Object[] args = pjp.getArgs();
        if (args != null && args.length > 0) {
            if (pointMethod == null) {
                Class<?>[] argsClass = new Class<?>[args.length];
                for (int i = 0; i < args.length; i++) {
                    argsClass[i] = args[i] == null ? null : args[i].getClass();
                }
                pointMethod = pointClass.getMethod(methodName, argsClass);
            }
            pointParam.setArguments(args);
        }

        if (pointMethod == null) {
            // Fall back to the last public method with the same name.
            for (Method method : pointClass.getMethods()) {
                if (methodName.equals(method.getName())) {
                    pointMethod = method;
                }
            }
        }
        if (pointMethod == null) {
            throw new NoSuchMethodException(pointClass.getName() + "." + methodName);
        }

        pointParam.setPointMethod(pointMethod);
        pointParam.setMethodAnnotations(Arrays.asList(pointMethod.getAnnotations()));
    }

}

DataFilterAopPointcut 继承 AbstractAopPointcut 的数据过滤 AOP 拦截器。将 @DataFilter 的配置设置到 DataFilterContextUtil 中

package mybatis.aop;

import mybatis.annotation.DataFilter;
import mybatis.util.DataFilterContextUtil;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.stereotype.Component;

import java.util.Map;

@Component
@Aspect
@ConditionalOnMissingBean(AbstractAopPointcut.class)
public class DataFilterAopPointcut extends AbstractAopPointcut{

    public Logger logger = LoggerFactory.getLogger("mybatisDataFilter");

    @Autowired(required = false)
    DataFilterValueHandle dataFilterValueHandle;

    @Autowired
    private DataFilterContextUtil dataFilterContextUtil;

    @Override
    @Pointcut(value = "@annotation(com.sf.mybatis.annotation.DataFilter)")
    public void pointcut() {

    }

    @Override
    public void preHandle(PointParam pointParam) {
        DataFilter dataFilter = (DataFilter) pointParam.getMatchAnnotation();
        Map<String,String> valueMap =  dataFilterValueHandle.valueMap(pointParam);
        boolean isAdmin = dataFilterValueHandle.isAdmin(pointParam);

        dataFilterContextUtil.setValueMap(valueMap);
        dataFilterContextUtil.setAliaName(dataFilter.aliaName());
        dataFilterContextUtil.setIsAdmin(isAdmin);

        if (StringUtils.isNotEmpty(dataFilter.adminMode())){
            dataFilterContextUtil.setAdminMode(new Boolean(dataFilter.adminMode()));
        }

        if (StringUtils.isNotEmpty(dataFilter.mode())){
            dataFilterContextUtil.setMode(dataFilter.mode());
        }

        if (StringUtils.isNotEmpty(dataFilter.whereSqlFormat())){
            dataFilterContextUtil.setWhereSqlFormat(dataFilter.whereSqlFormat());
        }
        if (StringUtils.isNotEmpty(dataFilter.selectFormat())){
            dataFilterContextUtil.setSelectFormat(dataFilter.selectFormat());
        }



    }

    @Override
    public void afterHandle(PointParam pointParam) {
        dataFilterContextUtil.clear();
    }

    @Override
    public Object exceptionHandle(PointParam pointParam, Exception e) {
        dataFilterContextUtil.clear();
        return null;
    }

    public DataFilterValueHandle getDataFilterValueHandle() {
        return dataFilterValueHandle;
    }

    public void setDataFilterValueHandle(DataFilterValueHandle dataFilterValueHandle) {
        this.dataFilterValueHandle = dataFilterValueHandle;
    }
}

@DataFilter

package mybatis.annotation;

import mybatis.util.DataFilterContextUtil;

import java.lang.annotation.*;


/**
 * Enables automatic data filtering for the SQL executed inside the annotated method.
 *
 * <p>{@code mode}, {@code whereSqlFormat}, {@code selectFormat} and {@code adminMode} left empty fall back to
 * the global {@code mybatis.data-filter.*} settings.
 *
 * <p>NOTE(review): the bundled aspect matches via {@code @annotation(...)}, which only sees method-level
 * annotations. Placing this on a type probably has no effect; confirm before relying on it.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE,ElementType.METHOD})
public @interface DataFilter {

    /**
     * Rewrite mode, {@code WHERE} or {@code SELECT} (case-insensitive).
     *
     * <p>WHERE edits the where condition directly using whereSqlFormat. For example, the template
     * {@code WHERE SYS_CODE IN (${SYS_CODE}) and aa = ${aa}} turns {@code SELECT A,B,C FROM TT} into
     * {@code SELECT A,B,C FROM TT WHERE SYS_CODE IN (1,2,3) and aa = aa}.
     *
     * <p>SELECT wraps the original query as a sub-query and filters outside it. For example, the template
     * {@code SELECT * FROM (${SQL}) t where ${CONDITION}} produces
     * {@code SELECT * FROM (SELECT A,B,C FROM TT) t WHERE SYS_CODE IN (1,2,3) and aa = aa}.
     * This mode requires the original select list to contain sys_code.
     *
     * see DataFilterContextUtil.SysCodeSqlModeEnum
     * @return the mode name, or empty to use the global setting
     */
    String mode() default "";
    /**
     * Table alias. {@code ${ALIAS_NAME}.} in the where template is replaced by it, or removed when it is empty.
     */
     String aliaName() default "";

    /**
     * Where-clause template, e.g. WHERE SYS_CODE IN (${SYS_CODE})
     */
    String whereSqlFormat() default "";

    /**
     * Select-mode template, e.g. SELECT * FROM (${SQL}) t where ${CONDITION}
     */
    String selectFormat() default "";

    /**
     * Whether admin mode is enabled, as "true"/"false". When enabled, administrators are not filtered.
     */
    String adminMode() default "";
}

PointParam

package mybatis.aop;

import lombok.Data;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.List;


/**
 * Details of one intercepted method invocation, handed to the AOP hooks and to DataFilterValueHandle.
 */
@Data
public class PointParam {
    /**
     * Class of the target object.
     */
    private Class pointClass;

    /**
     * Target object.
     */
    private Object pointObject;
    /**
     * Annotations declared on the target class.
     */
    private List<Annotation> classAnnotations;
    /**
     * Intercepted method.
     */
    private Method pointMethod;
    /**
     * Annotations declared on the intercepted method.
     */
    private List<Annotation> methodAnnotations;
    /**
     * Invocation arguments. AbstractAopPointcut only sets them when the call has at least one argument,
     * so this may be null.
     */
    private Object[] arguments;
    /**
     * Return value of the invocation. May be null if the pointcut implementation does not record it.
     */
    private Object returnObj;
    /**
     * The annotation instance that triggered the pointcut.
     */
    private Annotation matchAnnotation;
}

ClassUtil

package mybatis.util;

import java.lang.annotation.Annotation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;


/**
 * Reflection helpers for reading annotations.
 *
 * <p>Annotation instances returned by reflection are dynamic proxies. Their type must be obtained with
 * {@link Annotation#annotationType()}, because {@code getClass()} returns the proxy class, which never
 * equals the annotation interface.
 */
public class ClassUtil {

    /**
     * Returns the annotations present on {@code aClass}, including inherited ones.
     *
     * @param aClass class to inspect
     * @return fixed-size list view of the annotations
     */
    public static List<Annotation> getAnnotations(Class<?> aClass){
        return Arrays.asList(aClass.getAnnotations());
    }

    /**
     * Maps annotations to their annotation types.
     *
     * @param annotations annotation instances
     * @return the annotation interfaces, in the same order
     */
    public static List<Class<? extends Annotation>> getAnnotationsClass(List<Annotation> annotations){
        List<Class<? extends Annotation>> list = new ArrayList<>(annotations.size());
        for (Annotation annotation : annotations){
            list.add(annotation.annotationType());
        }
        return list;
    }

    /**
     * Finds the annotation of type {@code matchAnClass} on {@code aClass}.
     *
     * @return the annotation, or {@code null} if absent
     */
    public static Annotation getAnnotation(Class<?> aClass, Class<? extends Annotation> matchAnClass){
        return matchAnnotation(getAnnotations(aClass), matchAnClass);
    }

    /**
     * Finds the first annotation of type {@code matchAnClass} in {@code list}.
     *
     * @param list         annotations to search
     * @param matchAnClass annotation type to look for
     * @return the matching annotation, or {@code null} if none
     */
    public static Annotation matchAnnotation(List<Annotation> list, Class<? extends Annotation> matchAnClass){
        for (Annotation an : list){
            if (an.annotationType().equals(matchAnClass)){
                return an;
            }
        }
        return null;
    }
}

DataFilterValueHandle 业务系统集成设置模板替换的值

package mybatis.aop;

import java.util.Map;

/**
 * Integration point for the business system: supplies the inputs the data filter needs for an
 * intercepted {@code @DataFilter} call.
 */
public interface DataFilterValueHandle {

    /**
     * Tells whether the caller of the intercepted method is an administrator. While admin mode is on,
     * administrators are not filtered.
     *
     * @param pointParam details of the intercepted invocation
     * @return {@code true} for an administrator
     */
    boolean isAdmin(PointParam pointParam);

    /**
     * Supplies the values substituted into the where/select templates, keyed by placeholder name
     * (e.g. {@code SYS_CODE} for {@code ${SYS_CODE}}).
     *
     * @param pointParam details of the intercepted invocation
     * @return placeholder name to replacement text
     */
    Map<String,String> valueMap(PointParam pointParam);
}

spring.factories

org.springframework.boot.autoconfigure.EnableAutoConfiguration=\
mybatis.properties.MybatisDataFilterProperties,\
mybatis.interceptor.DataFilterSqlInterceptor,\
mybatis.util.DataFilterContextUtil,\
mybatis.aop.DataFilterAopPointcut
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值