项目中使用的 ShardingSphere 版本是 4.1.1,该版本不支持 UNION ALL 这类操作,而手动拆分 SQL 又太麻烦,于是通过 MyBatis 拦截器来解决:借助自定义注解找到需要分表的逻辑表,在 SQL 执行前把其中的逻辑表名提前替换为对应的实际分表名。
定义自定义注解
/**
 * Marks a mapper method whose SQL references sharded logic tables.
 *
 * <p>Retained at runtime so that the MyBatis interceptor {@code MybatisShardingPlugin} can
 * read it reflectively from the mapper method and rewrite the listed logic table names.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface ShardScopeAnnotation {

    /**
     * Comma-separated logic table names used in the method's SQL, e.g. {@code "table1, table2"}.
     *
     * @return the logic table list; the empty string by default
     */
    String logicTables() default "";
}
定义拦截器
/**
 * MyBatis interceptor that rewrites logic table names in a statement's SQL to the physical
 * shard table (e.g. {@code table1} -> {@code table1_<yearMonth>}) before the statement is prepared.
 *
 * <p>This works around ShardingSphere 4.1.1 not supporting {@code UNION ALL}: by the time
 * ShardingSphere sees the SQL it already names the physical tables. A statement is rewritten
 * only when both of the following hold:
 * <ul>
 *   <li>its mapper method carries {@link ShardScopeAnnotation};</li>
 *   <li>its parameter is, or contains, a {@code ShardingBaseEntity} with a non-null day.</li>
 * </ul>
 */
@Component
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class MybatisShardingPlugin implements Interceptor {

    /** Separator between a logic table name and its shard suffix. */
    private static final String TABLE_SPLIT_SYMBOL = "_";

    /**
     * Rewrites the bound SQL in place when the statement is shard-scoped, then continues the chain.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call
     * @return the result of {@link Invocation#proceed()}
     * @throws Throwable whatever the downstream chain throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        if (invocation.getTarget() instanceof RoutingStatementHandler) {
            RoutingStatementHandler statementHandler = (RoutingStatementHandler) invocation.getTarget();
            // The concrete handler and its MappedStatement are read reflectively from private fields.
            StatementHandler delegate = (StatementHandler) ReflectHelper.getFieldValue(statementHandler, "delegate");
            MappedStatement mappedStatement = (MappedStatement) ReflectHelper.getFieldValue(delegate, "mappedStatement");
            BoundSql boundSql = delegate.getBoundSql();
            ShardingBaseEntity entity = getShardingParam(boundSql.getParameterObject());
            // Cheap parameter checks first; the annotation lookup goes through reflection.
            if (entity != null && entity.getDay() != null && hasShardScopeAnnotation(mappedStatement.getId())) {
                modifySql(boundSql, mappedStatement, entity);
            }
        }
        return invocation.proceed();
    }

    /**
     * Replaces every whole-word occurrence of each annotated logic table in the SQL with its
     * shard table, but only if {@code ShardingTableCacheEnum} knows that shard table.
     */
    private void modifySql(BoundSql boundSql, MappedStatement mappedStatement, ShardingBaseEntity baseEntity) {
        ShardScopeAnnotation annotation = getFieldAnnotation(mappedStatement);
        if (annotation == null) {
            return;
        }
        String originalSql = boundSql.getSql();
        String sql = originalSql;
        // NOTE(review): intercept() gates on getDay() but the suffix is built from
        // getAttendanceDay(); confirm both refer to the same shard key.
        String suffix = TABLE_SPLIT_SYMBOL + ShardingUtil.getYearMonth(baseEntity.getAttendanceDay());
        for (String rawTable : annotation.logicTables().split(",")) {
            // "".split(",") yields [""], so blank entries must be skipped explicitly.
            String logicTableName = rawTable.trim().toLowerCase(java.util.Locale.ROOT);
            if (logicTableName.isEmpty()) {
                continue;
            }
            ShardingTableCacheEnum logicTable = ShardingTableCacheEnum.of(logicTableName);
            if (logicTable == null) {
                continue;
            }
            String resultTableName = logicTableName + suffix;
            if (logicTable.resultTableNamesCache().contains(resultTableName)) {
                // Quote both sides: the table name is matched literally, and '$' or '\' in the
                // replacement are not interpreted as group references.
                sql = java.util.regex.Pattern
                        .compile("\\b" + java.util.regex.Pattern.quote(logicTableName) + "\\b")
                        .matcher(sql)
                        .replaceAll(java.util.regex.Matcher.quoteReplacement(resultTableName));
            }
        }
        if (!sql.equals(originalSql)) {
            ReflectHelper.setFieldValue(boundSql, "sql", sql);
        }
    }

    /**
     * Finds the sharding entity in a statement parameter.
     *
     * <p>The entity is either the first {@code ShardingBaseEntity} value of a parameter map or
     * the parameter itself.
     *
     * @param paramObject the statement's parameter object; may be {@code null}
     * @return the entity, or {@code null} if none is present
     */
    private ShardingBaseEntity getShardingParam(Object paramObject) {
        if (paramObject instanceof Map) {
            for (Object value : ((Map<?, ?>) paramObject).values()) {
                if (value instanceof ShardingBaseEntity) {
                    return (ShardingBaseEntity) value;
                }
            }
            return null;
        }
        if (paramObject instanceof ShardingBaseEntity) {
            return (ShardingBaseEntity) paramObject;
        }
        return null;
    }

    /**
     * Returns whether the mapper method identified by {@code sqlId} carries {@link ShardScopeAnnotation}.
     *
     * @param sqlId a MappedStatement id of the form {@code fully.qualified.Mapper.method}
     */
    private boolean hasShardScopeAnnotation(String sqlId) {
        return findShardScopeAnnotation(sqlId) != null;
    }

    /**
     * Returns the {@link ShardScopeAnnotation} on the statement's mapper method.
     *
     * @return the annotation, or {@code null} if absent
     */
    private ShardScopeAnnotation getFieldAnnotation(MappedStatement mappedStatement) {
        return findShardScopeAnnotation(mappedStatement.getId());
    }

    /**
     * Resolves the mapper interface and method named by a statement id and returns the
     * {@link ShardScopeAnnotation} on the first same-named method that has one.
     *
     * <p>Returns {@code null} in any of these cases:
     * <ul>
     *   <li>the id is malformed;</li>
     *   <li>the namespace is not a loadable class;</li>
     *   <li>no matching method is annotated.</li>
     * </ul>
     */
    private ShardScopeAnnotation findShardScopeAnnotation(String sqlId) {
        if (sqlId == null) {
            return null;
        }
        int lastDot = sqlId.lastIndexOf('.');
        if (lastDot <= 0) {
            return null;
        }
        String className = sqlId.substring(0, lastDot);
        String methodName = sqlId.substring(lastDot + 1);
        Class<?> mapperClass;
        try {
            mapperClass = Class.forName(className);
        } catch (ClassNotFoundException e) {
            // The namespace is not a Java class (e.g. an XML-only mapper), so there is no
            // method to annotate: treat the statement as not shard-scoped.
            return null;
        }
        for (Method method : mapperClass.getMethods()) {
            if (method.getName().equals(methodName)) {
                ShardScopeAnnotation annotation = method.getAnnotation(ShardScopeAnnotation.class);
                if (annotation != null) {
                    return annotation;
                }
            }
        }
        return null;
    }
}
在 DAO(Mapper 接口)的方法上添加注解:
/**
 * Example mapper method whose SQL references the logic tables {@code table1} and {@code table2}.
 *
 * <p>Because of {@link ShardScopeAnnotation}, MybatisShardingPlugin may rewrite those names to
 * the shard table selected by the parameter before the statement is prepared.
 *
 * <p>NOTE(review): the rewrite only happens when both of these hold:
 * <ul>
 *   <li>the parameter is, or (inside a parameter map) contains, a ShardingBaseEntity with a
 *       non-null day;</li>
 *   <li>the computed shard table is known to ShardingTableCacheEnum.</li>
 * </ul>
 * Presumably UserBO extends ShardingBaseEntity; confirm.
 *
 * @param UserBO query parameter, assumed to carry the shard key
 * @return the rows mapped to {@code User}
 */
@ShardScopeAnnotation(logicTables = "table1, table2")
List<User> findList2(UserBO UserBO );
这样,代码执行到拦截器时,会根据传入的分表参数动态定位到准确的分表,并把 SQL 中的逻辑表名替换为实际分表名;ShardingSphere 执行时会原样执行改写后的 SQL,并返回结果。
注意:这里只支持单个分表的查询,不会同时查询多个分表;如果涉及多个分表,需要循环执行查询,再将各次结果合并。