一、使用背景
使用场景:在很多的saas系统中,对于不同的客户可能有不同的数据库,比如客户A:A数据库,客户B:B数据库,在同一个后台服务的时候,想要让两个客户都同时使用,必须在sql中的表名前加上:`${数据库名称}`来区分查的哪一个库,所以每个接口在service,dao中都要传一个相同的参数:”数据库名称“,这样做其实也行,但是可以更好的去达到我们想要的,下面就是对于这种情况的一个小小的优化--sql拦截。
二、环境配置
1、idea+springboot
2、maven引用:
<dependency>
<groupId>org.mybatis.spring.boot</groupId>
<artifactId>mybatis-spring-boot-starter</artifactId>
<version>2.0.1</version>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid-spring-boot-starter</artifactId>
<version>1.1.10</version>
</dependency>
三、创建注解
1、注解属性配置
/**
* @Author WEI
* @Date 2022/8/31 16:47
* @Describe sql拦截注解,作用域class或方法上,生命周期为RUNTIME
*/
@Target({ElementType.METHOD,ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface DatabaseSchool {
}
2、注解实现方法,原理是,先找到使用该注解的类路径,通过反射先判断该class上是否呗加了注解,如果没有在判断该方法上是否加了注解,如果有该注解则对该sql进行处理操作。注:如果注解加在了方法上,尽量避免该class下有相同名称的方法(如果有两个相同名称的方法,一个有注释,一个没有注释,再sql拦截器中的反射放获取是否有注解则会不准确)
package com.wei.config;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.alibaba.druid.stat.TableStat;
import com.wei.annotations.DatabaseSchool;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.springframework.util.CollectionUtils;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.stream.Collectors;
/**
* @Author WEI
* @Date 2022/8/31 16:32
* @Describe 使用了反射的方法获取是否使用了注解,所以尽量避免在同一个class下有相同名称的方法
*/
@Slf4j
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class SqlSchoolInterceptor implements Interceptor {
@Override
public Object intercept(Invocation invocation) throws Throwable {
try {
StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
MetaObject metaStatementHandler = SystemMetaObject.forObject(statementHandler);
MappedStatement mappedStatement = (MappedStatement) metaStatementHandler.getValue("delegate.mappedStatement");
// 不存在注解时,直接执行
if (!hasAnnotation(mappedStatement.getId())){
invocation.proceed();
}
BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");
metaStatementHandler.setValue("delegate.boundSql.sql", fillSql(boundSql.getSql()));
return invocation.proceed();
} catch (Exception e) {
return invocation.proceed();
}
}
@Override
public Object plugin(Object o) {
if (o instanceof StatementHandler) {
return Plugin.wrap(o, this);
}
return o;
}
@Override
public void setProperties(Properties properties) {
}
/**
* 判断是否有此注解
* @param classPath
* @return
* @throws ClassNotFoundException
*/
private boolean hasAnnotation(String classPath) throws ClassNotFoundException {
if(StringUtils.isBlank(classPath)){
return false;
}
//先判断类上面是否加上了注解
String mapperClass = classPath.substring(0, classPath.lastIndexOf("."));
Class<?> classType = Class.forName(mapperClass);
if(classType.isAnnotationPresent(DatabaseSchool.class)){
return true;
}
//如果类上面没有加上注解,判断方法上是否加上了注解
String methodName = classPath.substring(classPath.lastIndexOf(".") + 1);
Method method1 = Arrays.stream(classType.getMethods()).filter(method -> method.equals(methodName)).findFirst().orElse(null);
if(method1 == null){
return false;
}
return method1.isAnnotationPresent(DatabaseSchool.class);
}
/**
* 填充sql
* @param sql
* @return
*/
private String fillSql(String sql){
if(StringUtils.isBlank(sql)){
return sql;
}
SQLStatementParser parser = new MySqlStatementParser(sql);
// 使用Parser解析生成AST,这里SQLStatement就是AST
SQLStatement sqlStatement = parser.parseStatement();
MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
sqlStatement.accept(visitor);
Map<TableStat.Name, TableStat> tables = visitor.getTables();
if (MapUtils.isEmpty(tables)) {
return sql;
}
List<String> tableNames = tables.keySet().stream().map(TableStat.Name::getName)
.filter(name -> StringUtils.isNotBlank(name) && !name.contains(".")).collect(Collectors.toList());
if (CollectionUtils.isEmpty(tableNames)) {
return sql;
}
for (String name : tableNames) {
sql = sql.replaceAll(name, "`" + UserInfo.get().getSchoolNum() + "`." + name);
}
return sql;
}
}
四、使用演示
package com.wei.dao;
import com.wei.annotations.DatabaseSchool;
import com.wei.entity.User;
import org.apache.ibatis.annotations.Param;
/**
* @Author WEI
* @Date 2022/8/31 14:37
* @Describe 注解放到了class上,所有的方法都生效
*/
@DatabaseSchool
public interface TestDao {
/**
* 更具用户明和密码获取用户
* @param userName
* @param passWord
* @return
*/
User getUser(@Param("userName") String userName,@Param("passWord") String passWord);
}
这时候的getUser()方法,则会被拦截器拦截并作相应操作后再去执行修改后的sql。