MyBatis插件原理
调用流程
1. MyBatis中的代理类Plugin实现了InvocationHandler接口。
2. 当调用ParameterHandler、ResultSetHandler、StatementHandler、Executor这四类对象(被代理后)的方法时,就会执行Plugin的invoke方法。
3. Plugin在invoke方法中根据@Intercepts的配置信息(方法名、参数等)动态判断是否需要拦截该方法;若需要拦截,则将目标对象、方法Method及参数封装成Invocation,并调用Interceptor的intercept方法,intercept内部再通过Invocation.proceed执行原方法。
Executor.Method->Plugin.invoke->Interceptor.intercept
->Invocation.proceed->method.invoke
Plugin类中的invoke方法:
/**
 * Dispatches a call made on the proxy: methods registered in {@code signatureMap} for the
 * method's declaring class are routed to the interceptor, everything else is invoked
 * directly on the wrapped target.
 *
 * @param proxy  the proxy instance the method was invoked on (unused here)
 * @param method the method being called
 * @param args   the call arguments
 * @return whatever the interceptor or the target method returns
 * @throws Throwable the result of {@code ExceptionUtil.unwrapThrowable} for any caught exception
 */
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    try {
        Set<Method> interceptedMethods = signatureMap.get(method.getDeclaringClass());
        boolean intercepted = interceptedMethods != null && interceptedMethods.contains(method);
        if (!intercepted) {
            // Not covered by any signature: call straight through to the target.
            return method.invoke(target, args);
        }
        return interceptor.intercept(new Invocation(target, method, args));
    } catch (Exception e) {
        throw ExceptionUtil.unwrapThrowable(e);
    }
}
MyBatis中可被拦截的接口
1. MyBatis的执行器,用于执行增删改查操作
Executor (update, query, flushStatements, commit, rollback, getTransaction, close, isClosed)
2. 处理SQL的参数对象
ParameterHandler (getParameterObject, setParameters)
3. 处理SQL的返回结果集
ResultSetHandler (handleResultSets, handleOutputParameters)
4. 处理SQL语句的构建与执行(Statement的创建、参数设置、执行)
StatementHandler (prepare, parameterize, batch, update, query)
MyBatis的执行流程
源码分析
plugin包中的类
拦截器链InterceptorChain
/**
 * Ordered collection of interceptors that can be applied, one after another, to a target object.
 */
public class InterceptorChain {

    /** Interceptors in registration order. */
    private final List<Interceptor> interceptors = new ArrayList<>();

    /**
     * Passes the target through every registered interceptor's {@code plugin} method,
     * feeding each result into the next interceptor.
     *
     * @param target the object to wrap
     * @return the value returned by the last interceptor, or {@code target} when none are registered
     */
    public Object pluginAll(Object target) {
        Object wrapped = target;
        for (Interceptor each : interceptors) {
            wrapped = each.plugin(wrapped);
        }
        return wrapped;
    }

    /**
     * Appends an interceptor to the end of the chain.
     *
     * @param interceptor the interceptor to register
     */
    public void addInterceptor(Interceptor interceptor) {
        interceptors.add(interceptor);
    }

    /**
     * @return a read-only view of the registered interceptors
     */
    public List<Interceptor> getInterceptors() {
        return Collections.unmodifiableList(interceptors);
    }
}
配置类中注入拦截器
// 处理SQL的参数对象中处理
/**
 * Creates a parameter handler via {@code mappedStatement.getLang()} and runs it through
 * the interceptor chain.
 *
 * @return the handler as returned by {@code interceptorChain.pluginAll}
 */
public ParameterHandler newParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
    ParameterHandler rawHandler =
            mappedStatement.getLang().createParameterHandler(mappedStatement, parameterObject, boundSql);
    return (ParameterHandler) interceptorChain.pluginAll(rawHandler);
}
// 处理SQL的返回结果集中处理
/**
 * Creates a {@code DefaultResultSetHandler} and runs it through the interceptor chain.
 *
 * @return the handler as returned by {@code interceptorChain.pluginAll}
 */
public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, ParameterHandler parameterHandler,
        ResultHandler resultHandler, BoundSql boundSql) {
    ResultSetHandler rawHandler =
            new DefaultResultSetHandler(executor, mappedStatement, parameterHandler, resultHandler, boundSql, rowBounds);
    return (ResultSetHandler) interceptorChain.pluginAll(rawHandler);
}
// 拦截Sql语法构建的处理中处理
/**
 * Creates a {@code RoutingStatementHandler} and runs it through the interceptor chain.
 *
 * @return the handler as returned by {@code interceptorChain.pluginAll}
 */
public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
    StatementHandler rawHandler =
            new RoutingStatementHandler(executor, mappedStatement, parameterObject, rowBounds, resultHandler, boundSql);
    return (StatementHandler) interceptorChain.pluginAll(rawHandler);
}
// MyBatis的执行器中处理
/**
 * Builds an executor of the requested type, optionally wraps it in a {@code CachingExecutor},
 * and runs the result through the interceptor chain.
 *
 * @param transaction  the transaction handed to the executor
 * @param executorType requested type; {@code null} falls back to {@code defaultExecutorType},
 *                     and then to {@code ExecutorType.SIMPLE}
 * @return the executor as returned by {@code interceptorChain.pluginAll}
 */
public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
    ExecutorType resolvedType = executorType != null ? executorType : defaultExecutorType;
    if (resolvedType == null) {
        resolvedType = ExecutorType.SIMPLE;
    }

    Executor executor;
    switch (resolvedType) {
        case BATCH:
            executor = new BatchExecutor(this, transaction);
            break;
        case REUSE:
            executor = new ReuseExecutor(this, transaction);
            break;
        default:
            executor = new SimpleExecutor(this, transaction);
            break;
    }

    if (cacheEnabled) {
        executor = new CachingExecutor(executor);
    }
    return (Executor) interceptorChain.pluginAll(executor);
}
Plugin代理类
public static Object wrap(Object target, Interceptor interceptor) {
// 获取Interceptor实现类上面的注解,判断要拦截的是哪个类(Executor,ParameterHandler,ResultSetHandler,StatementHandler)的哪个方法
Map<Class>, Set<Method>> signatureMap = getSignatureMap(interceptor);
Class> type = target.getClass();
// 取出要拦截的类以及方法
Class>[] interfaces = getAllInterfaces(type, signatureMap);
// 有配置拦截器的话就生成代理类处理
if (interfaces.length > 0) {
return Proxy.newProxyInstance(
type.getClassLoader(),
interfaces,
new Plugin(target, interceptor, signatureMap));
}
return target;
}
/**
 * Dispatches a call made on the proxy: methods registered in {@code signatureMap} for the
 * method's declaring class are routed to the interceptor, everything else is invoked
 * directly on the wrapped target.
 *
 * @param proxy  the proxy instance the method was invoked on (unused here)
 * @param method the method being called
 * @param args   the call arguments
 * @return whatever the interceptor or the target method returns
 * @throws Throwable the result of {@code ExceptionUtil.unwrapThrowable} for any caught exception
 */
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    try {
        // Decide whether this method is one of the intercepted ones.
        Set<Method> interceptedMethods = signatureMap.get(method.getDeclaringClass());
        boolean intercepted = interceptedMethods != null && interceptedMethods.contains(method);
        if (!intercepted) {
            // Not intercepted: run the original logic on the target.
            return method.invoke(target, args);
        }
        // Intercepted: hand over to Interceptor.intercept, i.e. the custom logic.
        return interceptor.intercept(new Invocation(target, method, args));
    } catch (Exception e) {
        throw ExceptionUtil.unwrapThrowable(e);
    }
}
Interceptor接口
/**
 * Contract implemented by MyBatis plugins.
 */
public interface Interceptor {
/**
 * Handles an intercepted call.
 *
 * @param invocation the intercepted call
 * @return the result to hand back to the caller
 * @throws Throwable any failure raised while handling the call
 */
Object intercept(Invocation invocation) throws Throwable;
/**
 * Gives the interceptor the chance to wrap a target object.
 *
 * @param target the object to wrap
 * @return the object to use in place of {@code target}
 */
Object plugin(Object target);
/**
 * Sets properties on the interceptor (extension point).
 *
 * @param properties the properties to apply
 */
void setProperties(Properties properties);
}
利用MyBatis插件进行一次SQL重写
业务场景
业务中需要对数据进行权限划分,于是对数据库中的所有表增加创建人、操作人等字段。考虑到历史数据问题,先判断表中是否已有该字段,没有则先添加字段,再执行公共数据的填充。
Plugin拦截器代码
@Slf4j
@Intercepts({@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})})
public class SqlPlugin implements Interceptor {
private static MyCommonMapper myCommonMapper;
@Override
public Object intercept(Invocation invocation) throws Throwable {
// 解决循环依赖问题,利用spring上下文注入mapper
if (myCommonMapper == null) {
myCommonMapper = SpringContextUtil.getBean("myCommonMapper");
}
final Object[] args = invocation.getArgs();
MappedStatement ms = (MappedStatement) args[0];
// 判断sql类型
SqlCommandType commandType = ms.getSqlCommandType();
if (! (SqlCommandType.INSERT.equals(commandType) || SqlCommandType.UPDATE.equals(commandType) || SqlCommandType.DELETE.equals(commandType))) {
return invocation.proceed();
}
// 获取sql
Object parameterObject = args[1];
BoundSql boundSql = ms.getBoundSql(parameterObject);
String sql = boundSql.getSql();
if (StringUtils.isBlank(sql)) {
return invocation.proceed();
}
// 如果为更新操作,则添加更新人数据,如果表中没有该字段则alter表增加updater_id字段后再添加数据
if (sql.indexOf("UPDATE") != -1 && sql.indexOf("WHERE") != -1) {
if (sql.indexOf("updater_id") == -1) {
String[] strings = sql.split("WHERE");
StringBuilder stringBuilders = new StringBuilder();
stringBuilders.append(strings[0]);
String table = strings[0].split("SET")[0].split("UPDATE")[1];
if (StringUtils.isNotBlank(table)) {
table = table.trim();
Boolean canAdd = false;
// 判断数据库表中是否已经有该字段,如果不想每次都判断,可以放入redis缓存中
if (StringUtils.isNotBlank(myCommonMapper.hadColumn(table, "updater_id", null))) {
canAdd = true;
} else {
try {
myCommonMapper.addColumn(table, "updater_id", "BIGINT(20)", "更新人");
canAdd = true;
} catch (Exception e) {
log.error("表{}增加字段失败{}", table, e.getMessage());
}
}
if (canAdd) {
Long userId = LOCAL_USER.get() != null ? LOCAL_USER.get().getId() : 0L;
// 重写sql
stringBuilders.append(" , updater_id = ").append(userId);
stringBuilders.append(" WHERE ").append(strings[1]);
// 包装sql后,重置到invocation中
resetSql2Invocation(invocation, String.valueOf(stringBuilders));
}
}
}
}
// 返回,继续执行
return invocation.proceed();
}
@Override
public Object plugin(Object obj) {
return Plugin.wrap(obj, this);
}
@Override
public void setProperties(Properties arg0) {
}
/**
* 包装sql后,重置到invocation中
* @param invocation
* @param sql
* @throws SQLException
*/
private void resetSql2Invocation(Invocation invocation, String sql) throws SQLException {
final Object[] args = invocation.getArgs();
MappedStatement statement = (MappedStatement) args[0];
Object parameterObject = args[1];
BoundSql boundSql = statement.getBoundSql(parameterObject);
MappedStatement newStatement = newMappedStatement(statement, new BoundSqlSqlSource(boundSql));
MetaObject msObject = MetaObject.forObject(newStatement, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(),new DefaultReflectorFactory());
msObject.setValue("sqlSource.boundSql.sql", sql);
args[0] = newStatement;
}
private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
MappedStatement.Builder builder =
new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
builder.resource(ms.getResource());
builder.fetchSize(ms.getFetchSize());
builder.statementType(ms.getStatementType());
builder.keyGenerator(ms.getKeyGenerator());
if (ms.getKeyProperties() != null && ms.getKeyProperties().length != 0) {
StringBuilder keyProperties = new StringBuilder();
for (String keyProperty : ms.getKeyProperties()) {
keyProperties.append(keyProperty).append(",");
}
keyProperties.delete(keyProperties.length() - 1, keyProperties.length());
builder.keyProperty(keyProperties.toString());
}
builder.timeout(ms.getTimeout());
builder.parameterMap(ms.getParameterMap());
builder.resultMaps(ms.getResultMaps());
builder.resultSetType(ms.getResultSetType());
builder.cache(ms.getCache());
builder.flushCacheRequired(ms.isFlushCacheRequired());
builder.useCache(ms.isUseCache());
return builder.build();
}
/**
* 定义一个内部辅助类,作用是包装sql
*/
class BoundSqlSqlSource implements SqlSource {
private BoundSql boundSql;
public BoundSqlSqlSource(BoundSql boundSql) {
this.boundSql = boundSql;
}
@Override
public BoundSql getBoundSql(Object parameterObject) {
return boundSql;
}
}
}
注入拦截器即可使用
/**
 * Spring configuration that exposes the {@code SqlPlugin} interceptor as a bean.
 */
@Configuration
public class MybatisPluginConfig {

    /**
     * @return a new {@code SqlPlugin} instance
     */
    @Bean
    public SqlPlugin sqlPlugin() {
        SqlPlugin plugin = new SqlPlugin();
        return plugin;
    }
}