假设原始SQL如下:
SELECT id,ip,port,host_addr,slave_addr,type,value,unit,create_time,update_time,del_flag,device_id,merchant_id,tenant_id FROM xxx_value
WHERE (del_flag = ?)
下面是改写的拦截器代码:
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.springframework.stereotype.Component;
import java.lang.reflect.Field;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
/**
 * One way to implement "dataScope" data permissions: a MyBatis interceptor that rewrites SQL.
 *
 * <p>Every SELECT prepared through a {@link StatementHandler} is wrapped in an outer query that
 * keeps only the rows whose {@code id} is in the caller's permitted-id list:
 *
 * <pre>{@code
 * select * from (<original sql>) scope where scope.id in (<ids>)
 * }</pre>
 *
 * <p>Non-SELECT statements pass through untouched. The permitted ids are currently simulated
 * by {@link #resolveScopeIds()}.
 *
 * @author hsj
 * @date 2021/10/8 14:18
 */
@Intercepts({
        @Signature(
                type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class}
        )
})
@Slf4j
@Component
public class MyOwnBatisInterceptor implements Interceptor {

    /** Alias given to the wrapped original query. */
    private static final String SCOPE_ALIAS = "scope";

    /** Column of the wrapped query that the data-permission filter is applied to. */
    private static final String SCOPE_COLUMN = "id";

    /**
     * Reused across invocations so MyBatis' reflection metadata is built once instead of on
     * every intercepted statement (the previous code created a new factory per call).
     */
    private static final DefaultReflectorFactory REFLECTOR_FACTORY = new DefaultReflectorFactory();

    /**
     * Rewrites the SQL of SELECT statements to apply the data-permission filter, then continues
     * the normal {@code prepare} call.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call
     * @return the result of the original {@code prepare}
     * @throws Throwable whatever the intercepted call throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        // If several plugins wrap StatementHandler, the target handed to this plugin can itself
        // be a plugin proxy; property paths such as "delegate.mappedStatement" only resolve on
        // the real handler, so strip the proxies first.
        MetaObject handlerMeta = forObject(unwrapPluginProxies(statementHandler));
        // The real handler is expected to be a RoutingStatementHandler whose "delegate" holds
        // the concrete handler, which in turn holds the MappedStatement.
        MappedStatement mappedStatement = (MappedStatement) handlerMeta.getValue("delegate.mappedStatement");
        SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
        if (sqlCommandType != SqlCommandType.SELECT) {
            // Only SELECT statements are filtered.
            return invocation.proceed();
        }
        // Fully qualified id of the mapper method, e.g. com.uv.dao.UserMapper.selectUser
        String id = mappedStatement.getId();
        log.info("拦截到当前请求方法的全路径名为--->: {}", id);

        BoundSql boundSql = statementHandler.getBoundSql();
        String sql = boundSql.getSql();
        Object parameter = statementHandler.getParameterHandler().getParameterObject();
        log.info("拦截到当前请求SQL为--->: {}", sql);
        log.info("拦截到当前请求类型为--->: {}", sqlCommandType);
        // Bound parameter values may contain personal or secret data; keep them out of INFO logs.
        log.debug("拦截到当前请求参数为--->: {}", parameter);

        String scopedSql = buildScopedSql(sql, resolveScopeIds());
        log.info("拦截到当前请求SQL改写之后为--->: {}", scopedSql);

        // BoundSql has no setter for its SQL text; MetaObject writes the field reflectively.
        forObject(boundSql).setValue("sql", scopedSql);
        // Continue the original prepare() with the rewritten SQL.
        return invocation.proceed();
    }

    /**
     * Wraps {@code sql} in an outer query that keeps only rows whose scope column is in
     * {@code ids}. The ids are numeric ({@code Long}), so joining them into the SQL text cannot
     * inject SQL.
     *
     * @param sql the original statement text
     * @param ids permitted ids; an empty list yields a filter that matches no rows, because
     *            {@code in ()} is not valid SQL
     * @return the rewritten statement text
     */
    private static String buildScopedSql(String sql, List<Long> ids) {
        String condition = ids.isEmpty()
                ? "1 = 0"
                : String.format("%s.%s in (%s)", SCOPE_ALIAS, SCOPE_COLUMN, StringUtils.join(ids, ","));
        return String.format(" select * from (%s) %s where %s ", sql, SCOPE_ALIAS, condition);
    }

    /**
     * Returns the ids the current caller is allowed to see.
     *
     * <p>Currently simulated with a single hard-coded id; replace with the real dataScope lookup.
     *
     * @return the permitted ids, never {@code null}
     */
    private List<Long> resolveScopeIds() {
        List<Long> ids = new ArrayList<Long>();
        ids.add(1442769582176501761L);
        return ids;
    }

    /**
     * Strips the JDK proxies that {@link Plugin#wrap} puts around a target (one per plugin) and
     * returns the underlying object. Proxies whose handler is not a MyBatis {@link Plugin} are
     * returned as-is.
     *
     * @param target the possibly proxied object
     * @return the innermost non-plugin-proxy object
     */
    private static Object unwrapPluginProxies(Object target) {
        Object current = target;
        while (Proxy.isProxyClass(current.getClass())) {
            Object handler = Proxy.getInvocationHandler(current);
            if (!(handler instanceof Plugin)) {
                break;
            }
            current = forObject(handler).getValue("target");
        }
        return current;
    }

    /**
     * Creates a {@link MetaObject} for {@code object} using MyBatis' default factories and the
     * shared reflector factory.
     */
    private static MetaObject forObject(Object object) {
        return MetaObject.forObject(object,
                SystemMetaObject.DEFAULT_OBJECT_FACTORY,
                SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY,
                REFLECTOR_FACTORY);
    }

    /**
     * Wraps {@code target} in a proxy that routes the {@code @Signature} methods to this
     * interceptor.
     */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /** No configurable properties. */
    @Override
    public void setProperties(Properties properties) {}
}
改写之后的SQL:
select * from (SELECT id,ip,port,host_addr,slave_addr,type,value,unit,create_time,update_time,del_flag,device_id,merchant_id,tenant_id FROM xxx_value
WHERE (del_flag = ?)) scope where scope.id in (1442769582176501761)