一、背景
在很多业务场景下，我们需要拦截 SQL，以便在不侵入原有业务代码的前提下统一处理一些逻辑，例如分页、数据权限过滤、SQL 执行耗时监控等。这时就可以使用 MyBatis 的拦截器（Interceptor）。
二、MyBatis 核心对象介绍
从 MyBatis 代码实现的角度来看，MyBatis 的核心部件主要有以下几个：
Configuration：初始化基础配置，例如 MyBatis 的别名，以及插件、映射器、ObjectFactory、TypeHandler 等重要对象；MyBatis 的所有配置信息都保存在 Configuration 对象中。
SqlSessionFactory：SqlSession 的工厂。
SqlSession：MyBatis 工作的主要顶层 API，表示与数据库交互的会话，完成必要的数据库增删改查功能。
Executor：MyBatis 执行器，是 MyBatis 调度的核心，负责调度 SQL 语句的执行以及查询缓存的维护。
StatementHandler：封装了 JDBC Statement 操作，例如为 Statement 设置参数、将结果集转换成 List 集合等（下文的拦截器拦截的正是它的 prepare 方法）。
ParameterHandler：负责将用户传递的参数转换成 JDBC Statement 所需要的参数。
ResultSetHandler：负责将 JDBC 返回的 ResultSet 结果集对象转换成 List 类型的集合。
TypeHandler：负责 Java 数据类型和 JDBC 数据类型之间的映射和转换。
MappedStatement：封装了一条 <select|update|delete|insert> 节点的信息。
SqlSource：负责根据用户传递的 parameterObject 动态地生成 SQL 语句，将信息封装到 BoundSql 对象中并返回。
BoundSql：表示动态生成的 SQL 语句以及相应的参数信息。
三、MyBatis 执行概要图
四、实现如下：
package com.aimin.dal.interceptor;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.parser.ParserException;
import com.alibaba.druid.util.JdbcConstants;
import com.enn.common.util.OperatorInfoUtil;
import com.enn.common.vo.UserInfo;
import com.enn.dal.config.MapperConfig;
import com.enn.dal.constants.MybatisConstants;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import com.google.common.cache.Cache;
import java.sql.Connection;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * MyBatis plugin that appends a tenant filter to SELECT statements.
 *
 * <p>It intercepts {@code StatementHandler.prepare(Connection, Integer)}. For a SELECT statement that
 * reads from a {@code car_*} business table and does not already contain an equality condition on
 * {@code ent_id}, it appends {@code [alias.]ent_id = 'tenantId'} to the SQL held by the statement's
 * {@link BoundSql}. The tenant id comes from {@code OperatorInfoUtil.getUserInfo()}.
 *
 * @author zlq
 * @date 2021/11/11
 */
@Slf4j
@Intercepts({
        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})
})
@Component
public class MybatisInterceptor implements Interceptor {

    /** Name of the tenant column that is filtered on. */
    private static final String COLUMN_ENT_ID = "ent_id";

    /**
     * Matches an existing equality condition on the tenant column ({@code ent_id = ?},
     * {@code t.ent_id=?}, {@code `ent_id` = ?}), case-insensitively.
     *
     * <p>The negative look-behind requires the column name to start at an identifier boundary. This keeps
     * columns that merely end in "ent_id" ({@code parent_id}, {@code client_id}, {@code agent_id},
     * {@code event_id}, ...) from being mistaken for the tenant column. Such a mistake would skip the
     * tenant filter for the whole statement.
     *
     * <p>Any equality on ent_id counts, including join predicates such as {@code a.ent_id = b.ent_id}.
     */
    private static final Pattern PATTERN_CONDITION_TENANT_ID =
            Pattern.compile("(?i)(?<!\\w)`?" + COLUMN_ENT_ID + "`?\\s*=");

    /**
     * Captures, in group 1, the alias following the last FROM table reference that has one.
     *
     * <p>The match is made against bracket-stripped SQL. Table names may be back-quoted or
     * schema-qualified, and aliases may contain digits and underscores (e.g. {@code t1}). Keywords that
     * can directly follow a table reference are rejected as aliases through a whole-word look-ahead, so
     * an alias such as {@code orders} is still accepted.
     *
     * <p>Intended for {@link Matcher#matches()}: the leading and trailing {@code .*} cover the rest of
     * the SQL.
     */
    private static final Pattern PATTERN_MAIN_TABLE_ALIAS = Pattern.compile(
            "(?i).*\\bfrom\\s+[`\\w.]+\\s+(?:as\\s+)?"
                    + "(?!(?:left|right|full|inner|outer|cross|natural|straight_join|join|on|using|where|group"
                    + "|having|order|limit|union|for|lock|use|force|ignore|partition|window)\\b)"
                    + "([a-z_][a-z0-9_]*).*",
            Pattern.DOTALL);

    /** Business tables are those whose name starts with "car_" and that appear right after FROM. */
    private static final Pattern PATTERN_CT_TABLE_SQL = Pattern.compile("(?i)from *car_");

    /**
     * Recognises a count query wrapped around another query
     * ({@code select count(...) from (inner) table_count}) and captures the inner query in group 1.
     */
    private static final Pattern PATTERN_COUNT_SQL = Pattern.compile("(?i) *select *count\\(.*\\) *from *\\((.*)\\) table_count$", Pattern.DOTALL);

    /**
     * Shared reflector factory. DefaultReflectorFactory caches per-class reflection metadata, so reusing
     * one instance avoids rebuilding that metadata for every intercepted statement.
     */
    private static final DefaultReflectorFactory REFLECTOR_FACTORY = new DefaultReflectorFactory();

    /** Cache of SQL-derived {@link MapperConfig}s; see {@link #getMapperConfig(String, String)}. */
    @Autowired
    private Cache<String, MapperConfig> mybatisInterceptorCache;

    /** Mapper namespaces to rewrite. When null or empty, every namespace is eligible. */
    @Value(MybatisConstants.INCLUDED_MAPPER_IDS)
    private Set<String> includedMapperIds;

    /**
     * Appends the tenant condition to eligible SELECT statements, then continues the invocation.
     *
     * <p>The SQL is left unchanged, and the invocation simply proceeds, in any of these cases:
     * <ul>
     *   <li>the statement is not a SELECT;</li>
     *   <li>its namespace is not whitelisted;</li>
     *   <li>the SQL is blank;</li>
     *   <li>the SQL does not read a business table;</li>
     *   <li>the SQL already has an ent_id equality condition;</li>
     *   <li>there is no current user;</li>
     *   <li>the tenant id is not alphanumeric.</li>
     * </ul>
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call
     * @return the result of {@link Invocation#proceed()}
     * @throws Throwable whatever the underlying invocation throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        MetaObject metaObject = MetaObject.forObject(statementHandler, SystemMetaObject.DEFAULT_OBJECT_FACTORY,
                SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY, REFLECTOR_FACTORY);
        // The target is a RoutingStatementHandler. Its "delegate" (e.g. PreparedStatementHandler)
        // carries the MappedStatement built from the mapper definition.
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        if (mappedStatement.getSqlCommandType() != SqlCommandType.SELECT) {
            return invocation.proceed();
        }
        String sqlId = mappedStatement.getId();
        log.info("sqlId:{}", sqlId);
        // Skip statements whose namespace is not whitelisted (only when a whitelist is configured).
        if (CollectionUtils.isNotEmpty(includedMapperIds) && !includedMapperIds.contains(namespaceOf(sqlId))) {
            return invocation.proceed();
        }
        BoundSql boundSql = statementHandler.getBoundSql();
        String originalSql = StringUtils.trim(boundSql.getSql());
        if (StringUtils.isBlank(originalSql)) {
            return invocation.proceed();
        }
        MapperConfig mapperConfig = getMapperConfig(sqlId, originalSql);
        // Only business tables are tenant-scoped.
        if (!mapperConfig.isCtTable()) {
            return invocation.proceed();
        }
        // The statement already filters by tenant itself.
        if (mapperConfig.isTenantIdCondition()) {
            return invocation.proceed();
        }
        final UserInfo userInfo = OperatorInfoUtil.getUserInfo();
        if (userInfo == null) {
            return invocation.proceed();
        }
        // Read the tenant id once, so the value that is validated is exactly the value embedded.
        String entId = userInfo.getEntId();
        // isAlphanumeric rejects null, empty and anything containing quotes or other SQL syntax.
        // That makes it safe to embed the id in the quoted literal below.
        if (!StringUtils.isAlphanumeric(entId)) {
            log.warn("非法的租户id或者租户id为空,entId:{},未能拼接租户查询条件", entId);
            return invocation.proceed();
        }
        log.info("entId:{}", entId);
        String alias = mapperConfig.getMainTableAlias();
        String column = StringUtils.isNotBlank(alias) ? alias + "." + COLUMN_ENT_ID : COLUMN_ENT_ID;
        String scopeCondition = column + MybatisConstants.EQUAL
                + MybatisConstants.SINGLE_QUOTATION_MARK + entId + MybatisConstants.SINGLE_QUOTATION_MARK;
        // NOTE(review): SQLUtils.addCondition presumably parses the SQL as well. A ParserException could
        // then surface here, outside the try below, and fail the query. Confirm that failing closed in
        // that case is intended.
        String sql = addCondition(originalSql, scopeCondition);
        String effectiveSql = originalSql;
        try {
            // Sanity-check that the rewritten SQL still parses; if it does not, keep the original SQL.
            SQLUtils.formatMySql(sql);
            effectiveSql = sql;
        } catch (ParserException e) {
            log.warn("动态添加SQL数据过滤条件失败:{}", sql, e);
        }
        log.info("sql:{}", effectiveSql);
        metaObject.setValue("delegate.boundSql.sql", effectiveSql);
        return invocation.proceed();
    }

    /**
     * Returns the SQL-derived settings (main table alias, business-table flag, tenant-condition flag),
     * computing and caching them on a miss.
     *
     * <p>The cache key includes the SQL text, not only the statement id. With dynamic SQL (if/where/choose
     * tags) one statement id can render different SQL, e.g. with and without an ent_id condition. Flags
     * computed from one rendering must not be reused for another, because that could skip the tenant
     * filter.
     *
     * <p>NOTE(review): this gives one cache entry per distinct SQL rendering. Make sure the injected cache
     * is size-bounded.
     *
     * @param sqlId       mapped statement id
     * @param originalSql trimmed SQL about to be prepared
     * @return the settings for this exact SQL; never {@code null}
     */
    private MapperConfig getMapperConfig(String sqlId, String originalSql) {
        String cacheKey = sqlId + '\n' + originalSql;
        MapperConfig mapperConfig = mybatisInterceptorCache.getIfPresent(cacheKey);
        if (mapperConfig == null) {
            mapperConfig = new MapperConfig();
            mapperConfig.setMainTableAlias(getMainTableAlias(originalSql));
            mapperConfig.setCtTable(isCtTable(originalSql));
            mapperConfig.setTenantIdCondition(hasTenantIdCondition(originalSql));
            mybatisInterceptorCache.put(cacheKey, mapperConfig);
        }
        return mapperConfig;
    }

    /**
     * Returns the part of a statement id before its last '.', used as the mapper namespace.
     *
     * @param sqlId statement id
     * @return the part before the last '.', or the whole id when it contains no '.'
     */
    private static String namespaceOf(String sqlId) {
        int lastDot = sqlId.lastIndexOf('.');
        return lastDot < 0 ? sqlId : sqlId.substring(0, lastDot);
    }

    /**
     * Adds {@code scopeCondition} to {@code originalSql} with Druid's {@code SQLUtils.addCondition}
     * (MySQL dialect).
     *
     * <p>For a wrapped count query the condition is added to the inner query. The result is then
     * re-wrapped as {@code select count(0) from(inner) table_count}.
     *
     * @param originalSql    SQL to rewrite
     * @param scopeCondition condition to add, e.g. {@code t.ent_id='abc'}
     * @return the rewritten SQL
     */
    private static String addCondition(String originalSql, String scopeCondition) {
        String innerSql = getNotCountSql(originalSql);
        if (innerSql != null) {
            String scopedInner = SQLUtils.addCondition(innerSql, scopeCondition, JdbcConstants.MYSQL);
            return "select count(0) from(" + scopedInner + ") table_count";
        }
        return SQLUtils.addCondition(originalSql, scopeCondition, JdbcConstants.MYSQL);
    }

    /**
     * Returns whether the SQL already contains an equality condition on the tenant column.
     *
     * @param sql SQL to inspect
     * @return {@code true} if an ent_id equality condition is present
     */
    private static boolean hasTenantIdCondition(String sql) {
        return PATTERN_CONDITION_TENANT_ID.matcher(removeLinefeed(sql)).find();
    }

    /**
     * Extracts the main table alias.
     *
     * <p>For a wrapped count query the inner query is inspected. Parenthesised groups (sub-queries,
     * function arguments) are removed first, so FROM clauses nested inside them are ignored.
     *
     * @param sql SQL statement
     * @return the main table alias, or {@code null} if none was found
     */
    private static String getMainTableAlias(String sql) {
        String normalized = removeLinefeed(sql);
        String innerSql = getNotCountSql(normalized);
        if (innerSql != null) {
            normalized = innerSql;
        }
        normalized = removeStrInBrackets(normalized);
        Matcher matcher = PATTERN_MAIN_TABLE_ALIAS.matcher(normalized);
        // The pattern starts and ends with ".*", so matches() finds the same alias as find() would.
        // Unlike find(), it does not retry from every start offset when there is no match.
        return matcher.matches() ? matcher.group(1) : null;
    }

    /**
     * Returns whether the SQL selects from a business table, i.e. contains "from car_" (case-insensitive).
     *
     * @param sql SQL to inspect
     * @return {@code true} if a business table follows FROM
     */
    private static boolean isCtTable(String sql) {
        return PATTERN_CT_TABLE_SQL.matcher(removeLinefeed(sql)).find();
    }

    /**
     * Unwraps a {@code select count(...) from (inner) table_count} query.
     *
     * @param sql SQL to inspect
     * @return the inner query, or {@code null} if {@code sql} is not such a count query
     */
    private static String getNotCountSql(String sql) {
        Matcher matcher = PATTERN_COUNT_SQL.matcher(removeLinefeed(sql));
        return matcher.find() ? matcher.group(1) : null;
    }

    /**
     * Removes every occurrence of {@code MybatisConstants.SEP_LINEFEED} (applied as a regex).
     *
     * @param str input, may be {@code null}
     * @return the stripped string, or {@code null} for {@code null} input
     */
    private static String removeLinefeed(String str) {
        return str == null ? null : str.replaceAll(MybatisConstants.SEP_LINEFEED, StringUtils.EMPTY);
    }

    /**
     * Wraps the target in a MyBatis plugin proxy.
     *
     * <p>Only the signatures declared in {@code @Intercepts} are intercepted.
     */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /** No configurable properties; configuration is injected by Spring instead. */
    @Override
    public void setProperties(Properties properties) {
    }

    /**
     * Repeatedly removes the first balanced parenthesised group, including its brackets, until none is left.
     *
     * <p>The loop also stops when the remaining parentheses are unbalanced, e.g. a stray ')' before the
     * first '(' or an unclosed '('. In that case the string is returned as-is from that point on.
     *
     * <p>Iterative, so SQL with very many bracket groups cannot exhaust the call stack.
     *
     * @param str input, may be {@code null}
     * @return the string without bracketed groups, or {@code null} for {@code null} input
     */
    private static String removeStrInBrackets(String str) {
        if (str == null) {
            return null;
        }
        String result = str;
        while (true) {
            int open = -1;
            int close = -1;
            int depth = 0;
            for (int i = 0; i < result.length(); i++) {
                char c = result.charAt(i);
                if (c == '(') {
                    if (open < 0) {
                        open = i;
                    }
                    depth++;
                } else if (c == ')') {
                    depth--;
                }
                if (open >= 0 && depth == 0) {
                    close = i;
                    break;
                }
            }
            // No '(' at all, no matching ')', or a degenerate group: nothing (more) to remove.
            if (open >= close) {
                return result;
            }
            result = result.substring(0, open) + result.substring(close + 1);
            if (!result.contains(MybatisConstants.LEFT_PARENTHESIS_EN)) {
                return result;
            }
        }
    }
}