通过拦截器在插入和查询的时候添加tenantId
mybatis拦截器执行顺序
通过调试源码可以看到：XML 中动态 SQL 的解析发生在 Executor 之后、ParameterHandler 之前。所以如果 XML 中有如下形式的写法，并且只拦截了 ParameterHandler 而没有拦截 Executor，那么在解析 XML 时还拿不到拦截器设置的 tenantId 参数：
insert into column_show_info
<trim prefix="(" suffix=")" suffixOverrides=",">
<if test="tenantId != null">
tenant_id,
</if>
</trim>
如果在 SQL 解析之后再设置 tenantId，就已经不起作用了：`<if>` 判断在解析阶段已经把 tenant_id 字段过滤掉了。Executor 在解析 XML 中的 SQL 之前执行，ParameterHandler 在解析之后执行，因此 tenantId 必须在 Executor 阶段设置进去。
后面我会再写一篇 MyBatis 源码解析的文章，详细讨论 MyBatis 的执行流程和 SQL 的生命周期。
下面是 insert 拦截器的写法：
@Slf4j
@Intercepts(
{
@Signature(type = ParameterHandler.class, method = "setParameters", args = {PreparedStatement.class}),
@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})
})
@Component
public class InsertInterceptor implements Interceptor {
// private static final String TENANTID = "tenantId";
@Override
public Object intercept(Invocation invocation) throws Throwable {
String tenantId = (String) ThreadLocalUtil.get(ThreadLocalUtil.TENANT_ID);
if(StringUtils.isEmpty(tenantId)){
return invocation.proceed();
}
if(invocation.getTarget() instanceof ParameterHandler){
ParameterHandler parameterHandler = (ParameterHandler) invocation.getTarget();
if(!this.checkInsert(parameterHandler)) {
return invocation.proceed();
}
Object parameterObject = parameterHandler.getParameterObject();
return this.dealTenantId(invocation,parameterObject,tenantId);
}
else {
Object[] args = invocation.getArgs();
Object parameterObject = args[1];
return this.dealTenantId(invocation,parameterObject,tenantId);
}
}
private Object dealTenantId(Invocation invocation,Object parameterObject,String tenantId) throws Throwable {
//参数是map
if(parameterObject instanceof Map){
//paramMap.get获取不到对象抛异常
try {
List<Object> list = (List<Object>) ((Map)parameterObject).get("list");
if(!CollectionUtils.isEmpty(list)){
for(Object o:list){
Field ageField1 =o.getClass().getDeclaredField(ThreadLocalUtil.TENANT_ID);
ageField1.setAccessible(true);
if(!StringUtils.hasText((CharSequence)ageField1.get(o))) {
ageField1.set(o, tenantId);
}
}
}
}catch (Exception ignored){
log.error("InsertInterceptor.intercept,获取list对象报错");
}finally {
((Map)parameterObject).put(ThreadLocalUtil.TENANT_ID,tenantId);
}
//参数是对象
}else {
try {
Field ageField = AnnotationUtil.getField(parameterObject.getClass(),ThreadLocalUtil.TENANT_ID);
if(Objects.nonNull(ageField)){
ageField.setAccessible(true);
if(!StringUtils.hasText((CharSequence) ageField.get(parameterObject))){
ageField.set(parameterObject, tenantId);
}
}
}catch (Exception ignored){
log.error("InsertInterceptor.intercept,获取parameterObject对象报错");
}
}
return invocation.proceed();
}
private boolean checkInsert(ParameterHandler parameterHandler) throws Throwable {
Field boundSqlF = parameterHandler.getClass().getDeclaredField("boundSql");
boundSqlF.setAccessible(true);
BoundSql boundSql = (BoundSql)boundSqlF.get(parameterHandler);
String sql = boundSql.getSql();
String sqlType = sql.substring(0,6);
return "INSERT".equals(sqlType) || "insert".equals(sqlType);
}
}
查询时同样需要把 tenantId 参数加进去，下面是查询拦截器的写法：
@Intercepts({
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
)})
@Slf4j
@Component
public class QueryInterceptor implements Interceptor {
// private static final String TENANTID = "tenantId";
@Override
public Object intercept(Invocation invocation) throws Throwable {
Object[] args = invocation.getArgs();
MappedStatement ms = (MappedStatement) args[0];
Object parameterObject = args[1];
String tenantId = (String) ThreadLocalUtil.get(ThreadLocalUtil.TENANT_ID);
if (!StringUtils.hasText(tenantId)) {
return invocation.proceed();
}
Executor executor = (Executor) invocation.getTarget();
RowBounds rowBounds = (RowBounds) args[2];
ResultHandler<?> resultHandler = (ResultHandler) args[3];
BoundSql boundSql;
CacheKey cacheKey;
//参数是map
if(parameterObject instanceof Map){
//paramMap.get获取不到对象抛异常
try {
List<Object> list = (List<Object>) ((Map)parameterObject).get("list");
if(!CollectionUtils.isEmpty(list)){
for(Object o:list){
Field ageField1 =o.getClass().getDeclaredField(ThreadLocalUtil.TENANT_ID);
ageField1.setAccessible(true);
ageField1.set(o, tenantId);
}
}
}catch (Exception ignored){
log.error("QueryInterceptor.intercept.获取list对象异常-->{}",ms.getId());
}finally {
((Map)parameterObject).put(ThreadLocalUtil.TENANT_ID,tenantId);
}
//参数是QueryConditionBase子类对象,分页的查询会经过这里
}else if(parameterObject instanceof QueryConditionBase){
Field[] declaredFields = parameterObject.getClass().getDeclaredFields();
Field[] superDeclaredFields = parameterObject.getClass().getSuperclass().getDeclaredFields();
MapperMethod.ParamMap<Object> paramMap = new MapperMethod.ParamMap<>();
for(Field field:declaredFields){
field.setAccessible(true);
paramMap.put(field.getName(),field.get(parameterObject));
}
for(Field field:superDeclaredFields){
field.setAccessible(true);
paramMap.put(field.getName(),field.get(parameterObject));
}
paramMap.put(ThreadLocalUtil.TENANT_ID,tenantId);
args[1] = paramMap;
//其他查询 分页的select(*)不会经过
}else{
if(args.length ==4 ){
boundSql = ms.getBoundSql(parameterObject);
cacheKey = executor.createCacheKey(ms,parameterObject,rowBounds,boundSql);
}else {
cacheKey = (CacheKey) args[4];
boundSql = (BoundSql) args[5];
}
Field field = boundSql.getClass().getDeclaredField("sql");
field.setAccessible(true);
String sql = boundSql.getSql();
List<Object> resultList = null;
try {
String newSql = "select * from (" + sql + ") as result where result.tenant_id= '"+tenantId +"'";
field.set(boundSql,newSql);
assert boundSql.getSql()!=null;
log.info("sql---------------:"+boundSql.getSql());
resultList = executor.query(ms, parameterObject, rowBounds, resultHandler, cacheKey, boundSql);
}catch (Exception e){
field.set(boundSql,sql);
resultList = executor.query(ms, parameterObject, rowBounds, resultHandler, cacheKey, boundSql);
}
return resultList;
}
log.info("执行的查询mapper---------------------:"+ms.getId());
return invocation.proceed();
}
}