最近项目中出现了因为Mybatis的动态where条件不满足导致实际sql语句的where条件为空,进而更新了全表
如何禁止这种情况,个人觉得三种措施:
● 1.在逻辑层面加充分的参数有效性检查;
● 2.在where条件中如果索引条件都不满足,加上1=2这种必然失败的条件;
● 3.Mybatis拦截器
前两种措施都是依赖人,从这个层面讲,是不靠谱的,即一个策略不是强制的,就是不靠谱的.相对而言,第三种是不依赖程序员的自觉性,是最靠谱的
MyBatis提供了一种插件(plugin)的功能,虽然叫做插件,但其实这是拦截器功能。
MyBatis 允许你在已映射语句执行过程中的某一点进行拦截调用。默认情况下,MyBatis 允许使用插件来拦截的方法调用包括:
Executor (update, query, flushStatements, commit, rollback, getTransaction, close, isClosed)
拦截执行器的方法
ParameterHandler (getParameterObject, setParameters)
拦截参数的处理
ResultSetHandler (handleResultSets, handleOutputParameters)
拦截结果集的处理
StatementHandler (prepare, parameterize, batch, update, query)
拦截Sql语法构建的处理
实现:
package org.apache.ibatis.plugin;
import java.util.Properties;
/**
- @author Clinton Begin
*/
public interface Interceptor {
Object intercept(Invocation invocation) throws Throwable;
Object plugin(Object target);
void setProperties(Properties properties);
}
intercept:它将直接覆盖你所拦截的对象,有个参数Invocation对象,通过该对象,可以反射调度原来对象的方法;
plugin:target是被拦截的对象,它的作用是给被拦截对象生成一个代理对象;
setProperties:允许在plugin元素中配置所需参数,该方法在插件初始化的时候会被调用一次;
明确拦截器对什么方法启用
因为我们是要对sql语句进行拦截,所以我们拦截的应该是StatementHandler的prepare方法
具体代码:
@Intercepts({ @Signature(type = StatementHandler.class,
method = “prepare”,
args = { Connection.class, Integer.class }) })
@Component
public class EmptyWhereInterceptor implements Interceptor {
private static final Logger logger = LoggerFactory.getLogger(EmptyWhereInterceptor.class);
/**
* 拦截的 COMMAND 类型
*/
private static final Set<String> INTERCEPTOR_COMMAND = new HashSet<String>() {{
add("update");
add("delete");
}};
@Override
public Object intercept(Invocation invocation) throws Throwable {
logger.info("----进入拦截器");
//对于StatementHandler其实只有两个实现类,一个是RoutingStatementHandler,另一个是抽象类BaseStatementHandler,
//BaseStatementHandler有三个子类,分别是SimpleStatementHandler,PreparedStatementHandler和CallableStatementHandler,
StatementHandler handler = (StatementHandler) invocation.getTarget();
//Mybatis在进行Sql语句处理的时候都是建立的RoutingStatementHandler,
// 里面有个StatementHandler类型的delegate变量,其实现类是BaseStatementHandler
if (handler instanceof RoutingStatementHandler) {
handler = (BaseStatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
}
//BaseStatementHandler的成员变量mappedStatement
//获取SqlCommandType
String commandType = getCommandType(handler);
if (INTERCEPTOR_COMMAND.contains(commandType)) {
//获取sql
String originSql = handler.getBoundSql().getSql().toLowerCase();
if (!originSql.contains("where")) {
logger.error("Prohibit the use of SQL statements without where conditions.originSql={}", originSql);
throw new RuntimeException("Prohibit the use of SQL statements without where conditions.originSql"+originSql);
}
}
return invocation.proceed();
}
@Override
public Object plugin(Object target) {
return Plugin.wrap(target, this);
}
@Override
public void setProperties(Properties properties) {
}
/**
* 获取Command类型,小写化返回
*
* @param handler
* @return
*/
private String getCommandType(StatementHandler handler) {
MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(handler, "mappedStatement");
return mappedStatement.getSqlCommandType().toString().toLowerCase();
}
}
涉及一个反射的工具类
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
/**
-
@author lihuiyue
-
@date 2019-01-21
*/
public class ReflectUtil {
public ReflectUtil() {
}/**
- 改变 Accessible,便于访问private等属性
- @param field
*/
private static void makeAccessible(Field field) {
if(!Modifier.isPublic(field.getModifiers())) {
field.setAccessible(true);
}
}
/**
-
获取 object 的字段,字段名称为filedName,获取不到返回null
-
@param object
-
@param filedName
-
@return
*/
private static Field getDeclaredField(Object object, String filedName) {
Class superClass = object.getClass();while(superClass != Object.class) {
try {
return superClass.getDeclaredField(filedName);
} catch (NoSuchFieldException var4) {
superClass = superClass.getSuperclass();
}
}return null;
}
/**
-
获取object字段fieldName的值,如果字段不存在直接抛异常
-
@param object
-
@param fieldName
-
@return
*/
public static Object getFieldValue(Object object, String fieldName) {
Field field = getDeclaredField(object, fieldName);
if(field == null) {
throw new IllegalArgumentException(“Could not find field [” + fieldName + “] on target [” + object + “]”);
} else {
makeAccessible(field);
Object result = null;try { result = field.get(object); } catch (IllegalAccessException var5) { var5.printStackTrace(); } return result;
}
}
/**
-
设置object字段fieldName的值,如果字段不存在直接抛异常
-
@param object
-
@param fieldName
-
@param value
*/
public static void setFieldValue(Object object, String fieldName, Object value) {
Field field = getDeclaredField(object, fieldName);
if(field == null) {
throw new IllegalArgumentException(“Could not find field [” + fieldName + “] on target [” + object + “]”);
} else {
makeAccessible(field);try { field.set(object, value); } catch (IllegalAccessException var5) { var5.printStackTrace(); }
}
}
}
测试:
问题:在batch项目中这个mybatis拦截器失效
原因:因为batch 是个多数据源的项目,每个数据源我们都自定义了SqlSessionFactory,导致此拦截器没有注入。在创建SqlSessionFactory的时候,具体代码如下:
注入拦截器
@Autowired
private EmptyWhereInterceptor emptyWhereInterceptor;
SqlSessionFactoryBean中设置拦截器
sqlSessionFactoryBean.setPlugins(new Interceptor[]{emptyWhereInterceptor});
这里碰到一个坑,就是设置plugins时必须在sqlSessionFactoryBean.getObject()之前
可跟踪源码看到:
sqlSessionFactory = sqlSessionFactoryBean.getObject();
@Override
public SqlSessionFactory getObject() throws Exception {
if (this.sqlSessionFactory == null) {
afterPropertiesSet();
}
return this.sqlSessionFactory;
}
@Override
public void afterPropertiesSet() throws Exception {
notNull(dataSource, “Property ‘dataSource’ is required”);
notNull(sqlSessionFactoryBuilder, “Property ‘sqlSessionFactoryBuilder’ is required”);
state((configuration == null && configLocation == null) || !(configuration != null && configLocation != null),
“Property ‘configuration’ and ‘configLocation’ can not specified with together”);
this.sqlSessionFactory = buildSqlSessionFactory();
}
buildSqlSessionFactory()
if (!isEmpty(this.plugins)) {
for (Interceptor plugin : this.plugins) {
configuration.addInterceptor(plugin);
if (LOGGER.isDebugEnabled()) {
LOGGER.debug(“Registered plugin: '” + plugin + “’”);
}
}
}
最后贴上正确的配置代码(DataSourceSqlSessionFactory代码片段)
@Bean
public SqlSessionFactory sqlSessionFactoryV1(){
SqlSessionFactoryBean sqlSessionFactoryBean = new SqlSessionFactoryBean();
sqlSessionFactoryBean.setDataSource(dataSourceV1);
try {
sqlSessionFactoryBean.setMapperLocations(new PathMatchingResourcePatternResolver().getResources(MAPPER_LOCATION));
} catch (IOException e) {
logger.error(“创建V1 SqlSessionFactory 异常”,e);
throw new RuntimeException(“创建V1 SqlSessionFactory 异常”);
}
SqlSessionFactory sqlSessionFactory = null;
try {
sqlSessionFactoryBean.setPlugins(new Interceptor[]{emptyWhereInterceptor});
sqlSessionFactory = sqlSessionFactoryBean.getObject();
} catch (Exception e) {
logger.error(“创建V1 SqlSessionFactory 异常”,e);
throw new RuntimeException(“创建V1 SqlSessionFactory 异常”);
}
return sqlSessionFactory;
}