Overview
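As the name suggests, it is an interceptor, much like a Spring MVC interceptor or a servlet filter: it steps in before execution and lets you run your own logic there.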
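I mostly use MyBatis-Plus, so I will take com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor as the example. Internally it keeps a list of inner interceptors and, when intercepting, calls them one after another in a for loop, so the execution order is simply the order (index) of the elements in that list.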
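To make the ordering concrete, here is a stdlib-only model of that loop. MiniInnerInterceptor, MiniInterceptorChain and ChainDemo are hypothetical names, not MyBatis-Plus classes. The sketch only mirrors two ideas: inner interceptors are visited by list index, and a willDoQuery that returns false stops the query before later interceptors run.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for MyBatis-Plus's InnerInterceptor (query hooks only).
interface MiniInnerInterceptor {
    default boolean willDoQuery(List<String> trace) { return true; }
    default void beforeQuery(List<String> trace) { }
}

// Hypothetical stand-in for MybatisPlusInterceptor: walks its list by index.
class MiniInterceptorChain {
    private final List<MiniInnerInterceptor> interceptors = new ArrayList<>();

    void addInnerInterceptor(MiniInnerInterceptor interceptor) {
        interceptors.add(interceptor); // appended at the end, so it runs last
    }

    /** Returns false when an interceptor vetoed the query (the real plugin returns an empty list). */
    boolean query(List<String> trace) {
        for (MiniInnerInterceptor interceptor : interceptors) {
            if (!interceptor.willDoQuery(trace)) {
                return false; // later interceptors are never reached
            }
            interceptor.beforeQuery(trace);
        }
        trace.add("execute");
        return true;
    }
}

class ChainDemo {
    static List<String> run(boolean secondVetoes) {
        List<String> trace = new ArrayList<>();
        MiniInterceptorChain chain = new MiniInterceptorChain();
        chain.addInnerInterceptor(new MiniInnerInterceptor() {
            @Override public void beforeQuery(List<String> t) { t.add("first.beforeQuery"); }
        });
        chain.addInnerInterceptor(new MiniInnerInterceptor() {
            @Override public boolean willDoQuery(List<String> t) { return !secondVetoes; }
            @Override public void beforeQuery(List<String> t) { t.add("second.beforeQuery"); }
        });
        chain.query(trace);
        return trace;
    }
}
```

ChainDemo.run(false) records first.beforeQuery, second.beforeQuery, execute. ChainDemo.run(true) stops right after the first interceptor.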
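Business requirements sometimes call for global data-permission isolation. Modifying the code query by query is tedious, and every new feature would have to remember to add the condition by hand, which makes it a heavyweight approach. A lighter option is a MyBatis interceptor combined with the open-source JSqlParser parser, which can append the conditions dynamically. MyBatis-Plus's tenant isolation works the same way.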
MyBatis-Plus version: 3.5.2
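According to the official MyBatis documentation, four types can be intercepted:
- Executor (intercepts executor methods). method=update covers inserts, updates and deletes; the MappedStatement tells you which of them is actually being executed.
- ParameterHandler (intercepts parameter handling)
- ResultSetHandler (intercepts result-set handling)
- StatementHandler (intercepts SQL statement construction)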
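Under the hood, MyBatis wraps these target types in JDK dynamic proxies; its Plugin class is an InvocationHandler. The sketch below imitates that mechanism with java.lang.reflect.Proxy alone. MiniExecutor, ProxyDemo and the string log are hypothetical. Only update(...) is intercepted, which is roughly what a signature with method = "update" selects.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

// Hypothetical target type, playing the role of org.apache.ibatis.executor.Executor.
interface MiniExecutor {
    int update(String statementId, Object parameter);
    String query(String statementId);
}

class ProxyDemo {
    /** Wraps the target so that only update(...) is intercepted; other calls pass straight through. */
    static MiniExecutor wrap(MiniExecutor target, List<String> log) {
        return (MiniExecutor) Proxy.newProxyInstance(
                MiniExecutor.class.getClassLoader(),
                new Class<?>[] {MiniExecutor.class},
                (proxy, method, args) -> {
                    if (method.getName().equals("update")) {
                        log.add("before update " + args[0]); // the "intercept" step
                    }
                    try {
                        return method.invoke(target, args); // proceed to the real method
                    } catch (InvocationTargetException e) {
                        throw e.getCause(); // rethrow the target's own exception
                    }
                });
    }

    static List<String> run() {
        List<String> log = new ArrayList<>();
        MiniExecutor real = new MiniExecutor() {
            @Override public int update(String id, Object p) { log.add("update " + id); return 1; }
            @Override public String query(String id) { log.add("query " + id); return "row"; }
        };
        MiniExecutor proxied = wrap(real, log);
        proxied.update("UserMapper.updateById", 42);
        proxied.query("UserMapper.selectById");
        return log;
    }
}
```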
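In day-to-day business code, the operations intercepted most often are inserts, deletes, updates and queries. The classic example is the pagination interceptor, which is a query interceptor.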
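Roughly, a pagination interceptor turns one SELECT into two statements: a count query and a page query with a limit clause. The sketch below is a naive, MySQL-only string version under that assumption. PageSqlSketch is hypothetical. MyBatis-Plus's PaginationInnerInterceptor instead parses and optimizes the SQL with JSqlParser and picks a dialect per database.

```java
// Hypothetical, naive illustration of the two statements a MySQL pagination interceptor needs.
final class PageSqlSketch {
    private PageSqlSketch() { }

    /** Wraps the original query so the count respects whatever WHERE clause it already has. */
    static String countSql(String sql) {
        return "SELECT COUNT(*) FROM (" + sql + ") TOTAL";
    }

    /** MySQL-style "LIMIT offset,size"; current is assumed 1-based, like MyBatis-Plus's Page. */
    static String pageSql(String sql, long current, long size) {
        if (current < 1 || size < 1) {
            throw new IllegalArgumentException("current and size must be >= 1");
        }
        long offset = (current - 1) * size;
        return sql + " LIMIT " + offset + "," + size;
    }
}
```

For example, page 3 with 10 rows per page becomes LIMIT 20,10.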
MyBatis-Plus interceptor demo
This demo is modeled on MyBatis-Plus's com.baomidou.mybatisplus.extension.plugins.inner.OptimisticLockerInnerInterceptor.
After writing the interceptor, remember to add it to MyBatis-Plus's interceptor collection.
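If you need to read parameters conveniently inside an interceptor, intercepting MyBatis mapper methods is the simpler route. Extracting parameters from MyBatis-Plus lambda wrappers is considerably more complicated. Lambda wrappers are convenient to write at first, but they become painful once later iterations need to touch this logic. So keep your code well encapsulated, and route important operations through a single entry point.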
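This is part of why mapper methods are easier. For a method with several @Param arguments, MyBatis hands the interceptor a map keyed by each @Param name plus the generic aliases param1, param2 and so on. The sketch below mimics that shape in plain Java. ParamMapSketch is hypothetical; the real logic, including how unannotated arguments are named, lives in MyBatis's ParamNameResolver. The real map, MapperMethod.ParamMap, throws a BindingException for a missing key, so check containsKey before get.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical imitation of how MyBatis names mapper arguments; not the real ParamNameResolver.
final class ParamMapSketch {
    private ParamMapSketch() { }

    static Map<String, Object> namedParams(String[] paramNames, Object[] args) {
        if (paramNames.length != args.length) {
            throw new IllegalArgumentException("one name per argument expected");
        }
        Map<String, Object> params = new LinkedHashMap<>();
        for (int i = 0; i < args.length; i++) {
            params.put(paramNames[i], args[i]);   // key taken from @Param("...")
            String generic = "param" + (i + 1);   // generic alias: param1, param2, and so on
            if (!params.containsKey(generic)) {
                params.put(generic, args[i]);
            }
        }
        return params;
    }
}
```

For updateStatus(@Param("id") Long id, @Param("state") String state), the interceptor would see the keys id, param1, state and param2.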
```java
package com.xxx.xxx.xxx;

import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
import com.baomidou.mybatisplus.core.conditions.ISqlSegment;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.conditions.segments.NormalSegmentList;
import com.baomidou.mybatisplus.core.conditions.update.Update;
import com.baomidou.mybatisplus.core.enums.SqlKeyword;
import com.baomidou.mybatisplus.core.mapper.Mapper;
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.ReflectionKit;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;

import java.lang.reflect.Field;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

@Slf4j
public class MyInnerInterceptor implements InnerInterceptor {

    /** Placeholder MyBatis-Plus generates for wrapper values, e.g. #{ew.paramNameValuePairs.MPGENVAL1} */
    private static final Pattern VALUE_PATTERN =
            Pattern.compile("#\\{ew\\.paramNameValuePairs\\.(MPGENVAL\\d+)\\}");

    /**
     * Runs before every insert/update/delete; this demo only reacts to UPDATE statements.
     *
     * @param executor  Executor (may be a proxy object)
     * @param ms        MappedStatement
     * @param parameter parameter object of the mapper call
     * @throws SQLException declared by the interface; not thrown here
     */
    @Override
    @SuppressWarnings("unchecked")
    public void beforeUpdate(Executor executor, MappedStatement ms, Object parameter) throws SQLException {
        if (ms.getSqlCommandType() != SqlCommandType.UPDATE || !(parameter instanceof Map)) {
            return;
        }
        String msId = ms.getId();
        Map<String, Object> map = (Map<String, Object>) parameter;
        // Mapper method being called (updateById, update, or a custom one such as updateStatus);
        // available if only some update methods should be handled
        final String methodName = getMapperMethodName(msId);
        // MyBatis-Plus passes the entity under the alias "et" and the wrapper under "ew".
        // MyBatis's ParamMap throws for unknown keys, so check containsKey before get.
        Object et = map.containsKey("et") ? map.get("et") : null;
        Object ew = map.containsKey("ew") ? map.get("ew") : null;

        if (et != null || ew != null) {
            // Entity class from the mapper's BaseMapper<T> declaration, then its table metadata
            Class<?> entityClass = getEntityClass(getMapperClassName(msId));
            TableInfo tableInfo = entityClass == null ? null : TableInfoHelper.getTableInfo(entityClass);
            if (tableInfo == null) {
                return;
            }
            // Fields of the entity class
            List<TableFieldInfo> fieldList = tableInfo.getFieldList();
            if (et != null) {
                // updateById(et), update(et, wrapper)
                try {
                    for (TableFieldInfo fieldInfo : fieldList) {
                        Field field = fieldInfo.getField();
                        String column = fieldInfo.getColumn();
                        // Value this update is about to write (null if the property is not set)
                        Object newValue = field.get(et);
                        // TODO business logic, e.g. compare with the stored row or write a change log
                    }
                } catch (IllegalAccessException e) {
                    throw ExceptionUtils.mpe(e);
                }
            } else if (ew instanceof AbstractWrapper && ew instanceof Update) {
                // update(null, UpdateWrapper) or update(null, LambdaUpdateWrapper)
                AbstractWrapper<?, ?, ?> wrapper = (AbstractWrapper<?, ?, ?>) ew;
                Map<String, Object> paramNameValuePairs = wrapper.getParamNameValuePairs();
                for (TableFieldInfo fieldInfo : fieldList) {
                    String column = fieldInfo.getColumn();
                    // Value of "column = value" in the WHERE part; only equality conditions are found
                    String valueKey = getValueKey(column, wrapper);
                    Object conditionValue = valueKey == null ? null : paramNameValuePairs.get(valueKey);
                    // TODO business logic
                }
            }
        } else {
            // Custom mapper update method, e.g.
            // updateStatus(@Param("targetFillRecordId") Long id, @Param("state") String state).
            // The map holds the call arguments under their @Param names (plus param1, param2, ...),
            // i.e. the new values, not what is currently stored in the database.
            try {
                String state = map.containsKey("state") ? (String) map.get("state") : null;
                // Business-specific: the id parameter name tells which kind of record (1, 2, 3) is updated
                String[] idKeys = {"targetFillRecordId", "groupCollectId", "leaderCollectId"};
                for (int i = 0; i < idKeys.length; i++) {
                    Object idObj = map.containsKey(idKeys[i]) ? map.get(idKeys[i]) : null;
                    if (idObj instanceof Number) {
                        // TargetStatusLogUtil is the author's own business utility (not shown)
                        TargetStatusLogUtil.recordStatus(((Number) idObj).longValue(), i + 1, state);
                    }
                }
            } catch (Exception e) {
                log.error(e.getMessage(), e);
            }
        }
    }

    /**
     * Finds the parameter key of the value in a "column = value" WHERE condition,
     * e.g. "MPGENVAL2" for eq("id", 1). Returns null if the wrapper has no such condition.
     */
    private String getValueKey(String column, Wrapper<?> wrapper) {
        NormalSegmentList segments = wrapper.getExpression().getNormal();
        boolean columnFound = false;
        boolean eqFound = false;
        for (ISqlSegment segment : segments) {
            if (columnFound && eqFound) {
                // The segment after "column =" should be the value placeholder
                Matcher matcher = VALUE_PATTERN.matcher(segment.getSqlSegment());
                if (matcher.matches()) {
                    return matcher.group(1);
                }
                columnFound = false;
                eqFound = false;
            } else if (columnFound) {
                // Only "=" counts; any other operator (>, LIKE, ...) restarts the search
                eqFound = segment == SqlKeyword.EQ;
                columnFound = eqFound;
            } else if (segment instanceof Wrapper) {
                // Nested condition such as and(w -> w.eq(...))
                String nested = getValueKey(column, (Wrapper<?>) segment);
                if (nested != null) {
                    return nested;
                }
            } else if (Objects.equals(column, segment.getSqlSegment())) {
                columnFound = true;
            }
        }
        return null;
    }

    private String getMapperMethodName(String msId) {
        return msId.substring(msId.lastIndexOf('.') + 1);
    }

    private String getMapperClassName(String msId) {
        return msId.substring(0, msId.lastIndexOf('.'));
    }

    /**
     * Resolves the entity class T from the mapper's Mapper<T> declaration (null if it cannot be resolved).
     */
    private Class<?> getEntityClass(String className) {
        try {
            return ReflectionKit.getSuperClassGenericType(Class.forName(className), Mapper.class, 0);
        } catch (ClassNotFoundException e) {
            throw ExceptionUtils.mpe(e);
        }
    }
}
```
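getValueKey relies on the wrapper's WHERE segments appearing in the order column, =, value placeholder. The stdlib-only sketch below replays that scan over plain strings so the expected shape is visible. ValueKeySketch is hypothetical. The segment list in the test is an assumed example of what eq("status", x).eq("id", y) produces, not output captured from a real wrapper.

```java
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical, string-only replay of the column -> "=" -> placeholder scan.
final class ValueKeySketch {
    private static final Pattern PARAM =
            Pattern.compile("#\\{ew\\.paramNameValuePairs\\.(MPGENVAL\\d+)\\}");

    private ValueKeySketch() { }

    static String valueKey(String column, List<String> segments) {
        for (int i = 0; i + 2 < segments.size(); i++) {
            // Look for three consecutive segments: the column, "=", then the placeholder
            if (column.equals(segments.get(i)) && "=".equals(segments.get(i + 1))) {
                Matcher m = PARAM.matcher(segments.get(i + 2));
                if (m.matches()) {
                    return m.group(1);
                }
            }
        }
        return null; // the column is not compared with "=" in this wrapper
    }
}
```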
Registering the interceptor
```java
@Bean
public MyInnerInterceptor myInnerInterceptor(ApplicationContext applicationContext) {
    MyInnerInterceptor myInnerInterceptor = new MyInnerInterceptor();
    // Append to the existing MybatisPlusInterceptor; inner interceptors run in list order
    MybatisPlusInterceptor bean = applicationContext.getBean(MybatisPlusInterceptor.class);
    bean.addInnerInterceptor(myInnerInterceptor);
    return myInnerInterceptor;
}
```
Data permission interceptor demo
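The idea is to define a custom annotation and have the interceptor check whether the mapper method carries it.
Most projects already have interceptors in place, so it is enough to add your own to the existing interceptor chain.
For reference, see MyBatis-Plus's tenant interceptor, which handles all kinds of subqueries:
com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor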
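The annotation check itself only needs reflection on the mapper interface, using the method name taken from the MappedStatement id (the part after the last dot). A stdlib-only sketch follows. DemoPermission, DemoMapper and DataScopeAnnotationCheck are hypothetical. The sketch also treats an annotation on the interface as covering every method. That is an assumption based on the annotation's @Target including ElementType.TYPE.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

// Hypothetical annotation with the same shape as UserDataPermission.
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@interface DemoPermission { }

// Hypothetical mapper interface.
interface DemoMapper {
    @DemoPermission
    java.util.List<String> selectMine();

    java.util.List<String> selectAll();
}

final class DataScopeAnnotationCheck {
    private DataScopeAnnotationCheck() { }

    /** msId looks like "com.example.DemoMapper.selectMine"; the method name follows the last dot. */
    static boolean needsDataScope(Class<?> mapperClass, String msId) {
        if (mapperClass.isAnnotationPresent(DemoPermission.class)) {
            return true; // assumed: annotation on the interface covers all its methods
        }
        String methodName = msId.substring(msId.lastIndexOf('.') + 1);
        for (Method m : mapperClass.getMethods()) {
            if (m.getName().equals(methodName) && m.isAnnotationPresent(DemoPermission.class)) {
                return true;
            }
        }
        return false;
    }
}
```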
Custom annotation:
```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/** Marks mapper methods whose queries should be filtered by data permission */
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface UserDataPermission {
}
```
The interceptor, which delegates the actual rewriting to a handler:
```java
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import lombok.*;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SetOperationList;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.sql.SQLException;
import java.util.List;

@Data
@NoArgsConstructor
@AllArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class MyDataPermissionInterceptor extends JsqlParserSupport implements InnerInterceptor {

    /** Data permission handler that builds the WHERE condition */
    private MyDataPermissionHandler dataPermissionHandler;

    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds,
                            ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        // Skip statements marked with @InterceptorIgnore(dataPermission = "true")
        if (InterceptorIgnoreHelper.willIgnoreDataPermission(ms.getId())) {
            return;
        }
        PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
        // Parse the SQL, let processSelect rewrite it, and write it back; ms.getId() is passed on as obj
        mpBs.sql(this.parserSingle(mpBs.sql(), ms.getId()));
    }

    @Override
    protected void processSelect(Select select, int index, String sql, Object obj) {
        SelectBody selectBody = select.getSelectBody();
        if (selectBody instanceof PlainSelect) {
            this.setWhere((PlainSelect) selectBody, (String) obj);
        } else if (selectBody instanceof SetOperationList) {
            // UNION and similar: apply the condition to every plain SELECT branch
            List<SelectBody> selectBodyList = ((SetOperationList) selectBody).getSelects();
            selectBodyList.stream()
                    .filter(PlainSelect.class::isInstance)
                    .forEach(s -> this.setWhere((PlainSelect) s, (String) obj));
        }
    }

    /**
     * Sets the WHERE condition built by the handler.
     *
     * @param plainSelect  the SELECT being rewritten
     * @param whereSegment the MappedStatement id passed through from beforeQuery
     */
    private void setWhere(PlainSelect plainSelect, String whereSegment) {
        Expression sqlSegment = this.dataPermissionHandler.getSqlSegment(plainSelect, whereSegment);
        if (null != sqlSegment) {
            plainSelect.setWhere(sqlSegment);
        }
    }
}
```
A simple handler for single-table queries:
```java
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.HexValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.PlainSelect;

import java.lang.reflect.Method;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

// SpringUtil, RemoteRoleService, RemoteUserService, SecurityUtils, User, DataPermission and
// DataScope are project-specific classes that the article does not show.
@Slf4j
public class MyDataPermissionHandler {

    /**
     * Builds the WHERE expression with the data permission condition applied.
     *
     * @param plainSelect  the SELECT being rewritten
     * @param whereSegment the MappedStatement id, e.g. com.xxx.mapper.UserMapper.selectList
     * @return JSqlParser condition expression (null leaves the WHERE clause unchanged)
     */
    @SneakyThrows(Exception.class)
    public Expression getSqlSegment(PlainSelect plainSelect, String whereSegment) {
        // WHERE expression of the statement about to run; may be null
        Expression where = plainSelect.getWhere();
        log.info("Applying data permission, where: {}, mappedStatementId: {}", where, whereSegment);
        // Mapper interface name and method name
        String className = whereSegment.substring(0, whereSegment.lastIndexOf('.'));
        String methodName = whereSegment.substring(whereSegment.lastIndexOf('.') + 1);
        // Single-table demo: a subquery in FROM fails with a ClassCastException here
        Table fromItem = (Table) plainSelect.getFromItem();
        // Use the alias if there is one, otherwise the table name, to avoid ambiguous columns
        Alias fromItemAlias = fromItem.getAlias();
        String mainTableName = fromItemAlias == null ? fromItem.getName() : fromItemAlias.getName();
        // Check every mapper method with this name for @UserDataPermission
        for (Method m : Class.forName(className).getMethods()) {
            if (!Objects.equals(m.getName(), methodName)) {
                continue;
            }
            if (m.getAnnotation(UserDataPermission.class) == null) {
                return where; // not annotated: leave the statement unchanged
            }
            // Project-specific services; reading permissions from the token is usually faster
            RemoteRoleService remoteRoleService = SpringUtil.getBean(RemoteRoleService.class);
            RemoteUserService remoteUserService = SpringUtil.getBean(RemoteUserService.class);
            // 1. Current user
            User user = SecurityUtils.getUser();
            // 2. Role types of the current user (there may be several) and the resulting data scope
            Set<String> roleTypeSet = remoteRoleService.currentUserRoleType();
            DataScope scopeType = DataPermission.getScope(roleTypeSet);
            switch (scopeType) {
                case ALL:
                    // May see all data
                    return where;
                case DEPT: {
                    // May see data created by users of the same department: creator_code IN (...)
                    List<String> deptUserList = remoteUserService.listUserCodesByDeptCodes(user.getDeptCode());
                    if (deptUserList == null || deptUserList.isEmpty()) {
                        break; // "IN ()" is invalid SQL; treat as no access
                    }
                    ExpressionList deptList = new ExpressionList(deptUserList.stream()
                            .map(code -> (Expression) new StringValue(code))
                            .collect(Collectors.toList()));
                    InExpression inDept = new InExpression(new Column(mainTableName + ".creator_code"), deptList);
                    return and(where, inDept);
                }
                case MYSELF: {
                    // May only see own data: creator_code = current user
                    EqualsTo equalsTo = new EqualsTo();
                    equalsTo.setLeftExpression(new Column(mainTableName + ".creator_code"));
                    equalsTo.setRightExpression(new StringValue(user.getUserCode()));
                    return and(where, equalsTo);
                }
                default:
                    break;
            }
        }
        // No permission to view: HexValue is used as a trick to emit the raw SQL text
        return new HexValue(" 1 = 2 ");
    }

    /** ANDs the condition onto the original WHERE, parenthesized so an OR inside it cannot bypass the filter. */
    private Expression and(Expression where, Expression condition) {
        return where == null ? condition : new AndExpression(new Parenthesis(where), condition);
    }
}
```
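When a permission condition is ANDed onto an existing WHERE clause, the original expression needs parentheses. SQL's AND binds tighter than OR, so appending AND creator_code = 'me' to a = 1 OR b = 2 only restricts the second branch. Java's && and || have the same relative precedence. The hypothetical PrecedenceDemo below uses that to show the effect for a row that matches a = 1 but belongs to another user.

```java
// Hypothetical demo: SQL AND binds tighter than OR, just like Java's && and ||.
final class PrecedenceDemo {
    private PrecedenceDemo() { }

    /** WHERE a = 1 OR b = 2 AND creator_code = 'me'  is read as  a = 1 OR (b = 2 AND mine) */
    static boolean withoutParens(boolean a, boolean b, boolean mine) {
        return a || b && mine;
    }

    /** WHERE (a = 1 OR b = 2) AND creator_code = 'me' */
    static boolean withParens(boolean a, boolean b, boolean mine) {
        return (a || b) && mine;
    }
}
```

withoutParens(true, false, false) is true, so another user's row leaks. withParens(true, false, false) is false.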
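Registering the interceptors:
```java
@Bean
public MyDataPermissionInterceptor myInterceptor(MybatisPlusInterceptor mybatisPlusInterceptor) {
    MyDataPermissionInterceptor sql = new MyDataPermissionInterceptor();
    sql.setDataPermissionHandler(new MyDataPermissionHandler());
    // Inner interceptors are called in list order: the data permission interceptor goes first
    // so the SQL is already rewritten when the pagination interceptor builds its count query
    List<InnerInterceptor> list = new ArrayList<>();
    // Data permission plugin
    list.add(sql);
    // Pagination plugin
    list.add(new PaginationInnerInterceptor(DbType.MYSQL));
    // Note: setInterceptors replaces any inner interceptors registered earlier
    mybatisPlusInterceptor.setInterceptors(list);
    return sql;
}
```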
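Here is a simplified model of why the permission interceptor is registered before pagination. Pagination builds its count query from the SQL as it stands at that point, so a rewrite registered after it reaches the page query but not the count. OrderDemo and its string rewriting are hypothetical; the real interceptors rewrite SQL through JSqlParser.

```java
import java.util.List;

// Hypothetical model: steps run in list order; the count is taken when "pagination" is reached.
final class OrderDemo {
    private OrderDemo() { }

    /** Returns {countSql, pageSql} after applying the named steps in list order. */
    static String[] run(List<String> steps, String sql) {
        String countSql = null;
        boolean paged = false;
        for (String step : steps) {
            if ("permission".equals(step)) {
                sql = sql + " WHERE creator_code = 'me'"; // data permission rewrites the shared SQL
            } else if ("pagination".equals(step)) {
                countSql = "SELECT COUNT(*) FROM (" + sql + ") TOTAL"; // counts the SQL as it is now
                paged = true;
            } else {
                throw new IllegalArgumentException("unknown step: " + step);
            }
        }
        // The limit is appended last only to keep this string model well-formed
        String pageSql = paged ? sql + " LIMIT 0,10" : sql;
        return new String[] {countSql, pageSql};
    }
}
```

With the order reversed, the count covers every row while the page only shows the user's own rows, so the total and the page disagree.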