MyBatis 通过 SQL 拦截实现数据权限校验
- 自定义权限拦截注解
/**
 * Marks a mapper method whose SELECT statements must be rewritten with a data-permission filter.
 *
 * <p>Only methods may carry it, and it is kept at runtime so that an interceptor can discover it
 * reflectively (in this project, {@code MybatisDataPermissionIntercept} only checks for its presence).
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface CompanyDataScope {

    /**
     * Scope identifier supplied by the annotated method.
     *
     * <p>NOTE(review): the visible interceptor never reads this value — confirm its intended use.
     *
     * @return the scope identifier
     */
    String value();
}
- 定义拦截器注册配置
import com.github.pagehelper.autoconfigure.PageHelperAutoConfiguration; import com.yan.authority.permission.MybatisDataPermissionIntercept; import org.apache.ibatis.session.SqlSessionFactory; import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.AutoConfigureAfter; import org.springframework.context.annotation.Configuration; import javax.annotation.PostConstruct; import java.util.List; @AutoConfigureAfter(PageHelperAutoConfiguration.class) //引入这块可以和PageHelper分页兼容 @Configuration public class MybatisInterceptorAutoConfiguration implements InitializingBean { @Autowired private List<SqlSessionFactory> sqlSessionFactoryList; @Override @PostConstruct public void afterPropertiesSet() { MybatisDataPermissionIntercept mybatisInterceptor = new MybatisDataPermissionIntercept(); for (SqlSessionFactory sqlSessionFactory : sqlSessionFactoryList) { org.apache.ibatis.session.Configuration configuration = sqlSessionFactory.getConfiguration(); //自己添加 configuration.addInterceptor(mybatisInterceptor); } } }
- 定义具体的拦截器实现
import com.yan.authority.annotation.CompanyDataScope; import com.yan.core.constant.SecurityConstants; import com.yan.core.domain.LoginUserEntity; import com.yan.module.contex.SecurityContextHolder; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.HexValue; import net.sf.jsqlparser.parser.CCJSqlParserManager; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.statement.select.PlainSelect; import net.sf.jsqlparser.statement.select.Select; import net.sf.jsqlparser.statement.select.SelectBody; import net.sf.jsqlparser.statement.select.SetOperationList; import org.apache.commons.lang3.StringUtils; import org.apache.ibatis.cache.CacheKey; import org.apache.ibatis.executor.Executor; import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.plugin.*; import org.apache.ibatis.session.ResultHandler; import org.apache.ibatis.session.RowBounds; import java.io.StringReader; import java.lang.reflect.Method; import java.util.List; import java.util.Objects; import java.util.Properties; @Intercepts( { @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}), @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}), } ) @Slf4j public class MybatisDataPermissionIntercept implements Interceptor { CCJSqlParserManager parserManager = new CCJSqlParserManager(); @Override public Object intercept(Invocation invocation) throws Throwable { Object[] args = invocation.getArgs(); MappedStatement ms = (MappedStatement) args[0]; Object parameter = args[1]; RowBounds rowBounds = (RowBounds) args[2]; ResultHandler resultHandler = (ResultHandler) args[3]; Executor executor = (Executor) invocation.getTarget(); CacheKey cacheKey; 
BoundSql boundSql; String id = ms.getId(); //获取mapper名称 String className = id.substring(0, id.lastIndexOf(".")); //获取方法名 String methodName = id.substring(id.lastIndexOf(".") + 1); //获取当前mapper 的方法 Method[] methods = Class.forName(className).getMethods(); boolean havaAnnno=false; //通过对所有方法遍历,核实当前方法是否有自定义注解,只有自定义注解的方法才可以实现拦截 for (Method m : methods) { if (Objects.equals(m.getName(), methodName)){ CompanyDataScope annotation = m.getAnnotation(CompanyDataScope.class); if (Objects.nonNull(annotation)) { havaAnnno=true; break; } } } //由于逻辑关系,只会进入一次 if (args.length == 4) { //4 个参数时 boundSql = ms.getBoundSql(parameter); cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql); } else { //6 个参数时 cacheKey = (CacheKey) args[4]; boundSql = (BoundSql) args[5]; } if(!havaAnnno){ return executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql); } //TODO 自己要进行的各种处理 String sql = boundSql.getSql(); log.info("原始SQL: {}", sql); // 这块是自己写的方法拦截,通过用户中的部门标识进行数据拦截,用户会保存于redis中,可以实现服务器共享,这块大家可以按照自己的业务逻辑方式写 LoginUserEntity loginUserEntity = SecurityContextHolder.get(SecurityConstants.LOGIN_USER, LoginUserEntity.class); if (Objects.nonNull(loginUserEntity)){ // 增强sql Select select = (Select) parserManager.parse(new StringReader(sql)); SelectBody selectBody = select.getSelectBody(); if (selectBody instanceof PlainSelect) { this.setWhere((PlainSelect) selectBody, loginUserEntity); } else if (selectBody instanceof SetOperationList) { SetOperationList setOperationList = (SetOperationList) selectBody; List<SelectBody> selectBodyList = setOperationList.getSelects(); selectBodyList.forEach((s) -> this.setWhere((PlainSelect) s, loginUserEntity)); } String dataPermissionSql = select.toString(); log.info("增强SQL: {}", dataPermissionSql); BoundSql dataPermissionBoundSql = new BoundSql(ms.getConfiguration(), dataPermissionSql, boundSql.getParameterMappings(), parameter); return executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, dataPermissionBoundSql); }else 
{ throw new RuntimeException("获取用户信息失败,请联系管理员"); } } protected void setWhere(PlainSelect plainSelect, LoginUserEntity loginUserEntity) { Expression sqlSegment = this.getSqlSegment(plainSelect.getWhere(), loginUserEntity); if (null != sqlSegment) { plainSelect.setWhere(sqlSegment); } } @SneakyThrows public Expression getSqlSegment(Expression where, LoginUserEntity loginUse) { String dataPerm = loginUse.getDataPerm(); if (StringUtils.isEmpty(dataPerm)){ return where; } StringBuilder sqlString = new StringBuilder(); checkDataFlag(sqlString,loginUse.getDataFlag(),dataPerm); if (StringUtils.isNotBlank(sqlString.toString())) { if (where == null){ where = new HexValue(" 1 = 1 "); } //判断是不是分页, 分页完成之后 清除权限标识 return CCJSqlParserUtil.parseCondExpression(where + sqlString.toString()); }else { return where; } } private void checkDataFlag(StringBuilder sb, String dataFlag, String pjSql) { switch (dataFlag){ case "1": sb.append(" AND (create_by IN "); sb.append(pjSql); sb.append(" OR update_by IN "); sb.append(pjSql); sb.append(" )"); break; case "2": sb.append(" AND create_by IN "); sb.append(pjSql); break; case "3": sb.append(" AND update_by IN "); sb.append(pjSql); break; default: break; } } @Override public Object plugin(Object target) { return Plugin.wrap(target, this); } @Override public void setProperties(Properties properties) { } }