package com.interceptor;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.ISqlSegment;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.hbgg.common.exception.BusinessException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.Select;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.factory.DefaultObjectFactory;
import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.stereotype.Component;
import java.lang.reflect.Field;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@Component
@Intercepts({@Signature(
type = Executor.class,
method = "update",
args = {MappedStatement.class, Object.class}
),@Signature(
type = Executor.class,
method = "query",
args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
), @Signature(
type = Executor.class,
method = "query",
args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
)})
public class SplitTableInterceptor implements Interceptor {
private final static String tableName = "tableName";
public SplitTableInterceptor() {
}
@Override
public Object intercept(Invocation invocation) throws Throwable {
Object[] args = invocation.getArgs();
MappedStatement ms = (MappedStatement)args[0];
Object parameter = args[1];
BoundSql boundSql;
//获取配置文件中最原始的sql
boundSql = ms.getBoundSql(parameter);
if(args.length>4){
//其他拦截器处理后的sql
boundSql = (BoundSql)args[5];
}
String sql = "select 1";
Statement statement = CCJSqlParserUtil.parse(boundSql.getSql());
if((statement instanceof Select) && args.length != 4){
//查询
boundSql = (BoundSql)args[5];
Field field = boundSql.getClass().getDeclaredField("sql");
field.setAccessible(true);
field.set(boundSql, sql);
}else {
//增删改
Field field = boundSql.getClass().getDeclaredField("sql");
field.setAccessible(true);
field.set(boundSql, sql);
MappedStatement newStatement = newMappedStatement(ms, new BoundSqlSqlSource(boundSql));
MetaObject msObject = MetaObject.forObject(newStatement, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(),new DefaultReflectorFactory());
msObject.setValue("sqlSource.boundSql.sql", sql);
args[0] = newStatement;
}
return invocation.proceed();
}
private static MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
MappedStatement.Builder builder =
new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
builder.resource(ms.getResource());
builder.fetchSize(ms.getFetchSize());
builder.statementType(ms.getStatementType());
builder.keyGenerator(ms.getKeyGenerator());
if (ms.getKeyProperties() != null && ms.getKeyProperties().length != 0) {
StringBuilder keyProperties = new StringBuilder();
for (String keyProperty : ms.getKeyProperties()) {
keyProperties.append(keyProperty).append(",");
}
keyProperties.delete(keyProperties.length() - 1, keyProperties.length());
builder.keyProperty(keyProperties.toString());
}
builder.timeout(ms.getTimeout());
builder.parameterMap(ms.getParameterMap());
builder.resultMaps(ms.getResultMaps());
builder.resultSetType(ms.getResultSetType());
builder.cache(ms.getCache());
builder.flushCacheRequired(ms.isFlushCacheRequired());
builder.useCache(ms.isUseCache());
return builder.build();
}
@Override
public Object plugin(Object o) {
return Plugin.wrap(o, this);
}
@Override
public void setProperties(Properties properties) {
}
// 定义一个内部辅助类,作用是包装sq
static class BoundSqlSqlSource implements SqlSource {
private BoundSql boundSql;
public BoundSqlSqlSource(BoundSql boundSql) {
this.boundSql = boundSql;
}
@Override
public BoundSql getBoundSql(Object parameterObject) {
return boundSql;
}
}
}
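// 说明:查询语句(带 BoundSql 参数的 6 参 query)通过反射直接修改 BoundSql 中待执行的 sql 实现;增删改语句(以及 4 参 query)通过替换 MappedStatement 实现。
// (Queries using the 6-argument query overload are rewritten by patching the BoundSql's sql field via reflection; updates and 4-argument queries are rewritten by substituting a new MappedStatement.)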
public Object intercept(Invocation invocation) throws Throwable {
Object[] args = invocation.getArgs();
MappedStatement ms = (MappedStatement)args[0];
Object parameter = args[1];
String sid = null;
if(parameter instanceof Map){
Map param = (Map) parameter;
if(!param.containsKey("sid")){
int size = param.keySet().size()/2;
for(int i = 1; i<= size; i++){
String key = "param"+i;
Object o = param.get(key);
if(o instanceof QueryWrapper){
QueryWrapper q = (QueryWrapper) o;
String sqlSegment = q.getExpression().getSqlSegment();
String s = "sid =";
if(sqlSegment.contains(s)){
String pKey = sqlSegment.split(s)[1].split("paramNameValuePairs.")[1].split("}")[0];
sid = q.getParamNameValuePairs().get(pKey).toString();
break;
}
}
}
}else {
sid = param.get("sid").toString();
}
}else if(parameter instanceof String){
sid = parameter.toString();
}else{
JSONObject json = (JSONObject) JSON.toJSON(parameter);
sid = json.get("sid").toString();
}
if(sid == null){
throw new Exception("获取信息失败");
}
System.out.println("sid : "+sid)
return invocation.proceed();
}