- 在mybatis中可被拦截的类型有四种(按照拦截顺序):
1、Executor:拦截执行器的方法。
2、ParameterHandler:拦截参数的处理。
3、ResultHandler:拦截结果集的处理。
4、StatementHandler:拦截Sql语法构建的处理。- 拦截器作用:
1、分页查询
2、多租户添加条件过滤
3、对返回结果,过滤掉审计字段,敏感字段
4、对返回结果中的加密数据进行解密
5、对新增数据自动添加创建人,创建时间,更新时间,更新人 ,对更新数据自动新增更新时间,更新人
我们用Executor来实现一下拦截器的作用一(也是做常用的拦截作用)之一:
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import utn.app.daobase.model.PageInfo;
/**
*
* 分页拦截器,用于拦截需要进行分页查询的操作,然后对其进行分页处理。
* 利用拦截器实现Mybatis分页的原理:
* 要利用JDBC对数据库进行操作就必须要有一个对应的Statement对象,Mybatis在执行Sql语句前就会产生一个包含Sql语句的Statement对象,而且对应的Sql语句
* 是在Statement之前产生的,所以我们就可以在它生成Statement之前对用来生成Statement的Sql语句下手。在Mybatis中Statement语句是通过RoutingStatementHandler对象的
* prepare方法生成的。所以利用拦截器实现Mybatis分页的一个思路就是拦截StatementHandler接口的prepare方法,然后在拦截器方法中把Sql语句改成对应的分页查询Sql语句,之后再调用
* StatementHandler对象的prepare方法,即调用invocation.proceed()。
* 对于分页而言,在拦截器里面我们还需要做的一个操作就是统计满足当前条件的记录一共有多少,这是通过获取到了原始的Sql语句后,把它改为对应的统计语句再利用Mybatis封装好的参数和设
* 置参数的功能把Sql语句中的参数进行替换,之后再执行查询记录数的Sql语句进行总记录数的统计。
*
*/
@Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }),
@Signature(method = "query", type = Executor.class, args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class }) })
public class MybatisSpringPageInterceptor implements Interceptor {
private static final Logger log = LoggerFactory.getLogger(MybatisSpringPageInterceptor.class);
public static final String MYSQL = "mysql";
public static final String ORACLE = "oracle";
protected String databaseType;// 数据库类型,不同的数据库有不同的分页方法
@SuppressWarnings("rawtypes")
protected ThreadLocal<PageInfo> pageThreadLocal = new ThreadLocal<PageInfo>();
public String getDatabaseType() {
return databaseType;
}
public void setDatabaseType(String databaseType) {
if (!databaseType.equalsIgnoreCase(MYSQL) && !databaseType.equalsIgnoreCase(ORACLE)) {
throw new PageNotSupportException("Page not support for the type of database, database type [" + databaseType + "]");
}
this.databaseType = databaseType;
}
public Object plugin(Object target) {
return Plugin.wrap(target, this);
}
public void setProperties(Properties properties) {
String databaseType = properties.getProperty("databaseType");
if (databaseType != null) {
setDatabaseType(databaseType);
}
}
/*对于StatementHandler其实只有两个实现类,一个是RoutingStatementHandler,另一个是抽象类BaseStatementHandler,
BaseStatementHandler有三个子类,分别是SimpleStatementHandler,PreparedStatementHandler和CallableStatementHandler,
SimpleStatementHandler是用于处理Statement的,PreparedStatementHandler是处理PreparedStatement的,而CallableStatementHandler是
处理CallableStatement的。Mybatis在进行Sql语句处理的时候都是建立的RoutingStatementHandler,而在RoutingStatementHandler里面拥有一个
StatementHandler类型的delegate属性,RoutingStatementHandler会依据Statement的不同建立对应的BaseStatementHandler,即SimpleStatementHandler
PreparedStatementHandler或CallableStatementHandler,在RoutingStatementHandler里面所有StatementHandler接口方法的实现都是调用的delegate对应的方法*/
@SuppressWarnings({ "unchecked", "rawtypes" })
public Object intercept(Invocation invocation) throws Throwable {
if (invocation.getTarget() instanceof StatementHandler) { // 控制SQL和查询总数的地方
PageInfo page = pageThreadLocal.get();
if (page == null) { //不是分页查询
return invocation.proceed();
}
RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
BoundSql boundSql = delegate.getBoundSql();
Connection connection = (Connection) invocation.getArgs()[0];
prepareAndCheckDatabaseType(connection); // 准备数据库类型
if (page.getTotalPage() > -1) {
if (log.isTraceEnabled()) {
log.trace("已经设置了总页数, 不需要再查询总数.");
}
} else {
Object parameterObj = boundSql.getParameterObject();
MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
queryTotalRecord(page, parameterObj, mappedStatement, connection);
}
String sql = boundSql.getSql();
String pageSql = buildPageSql(page, sql);
if (log.isDebugEnabled()) {
log.debug("分页时, 生成分页pageSql: " + pageSql);
}
ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
return invocation.proceed();
} else { // 查询结果的地方
// 获取是否有分页Page对象
PageInfo<?> page = findPageObject(invocation.getArgs()[1]);
if (page == null) {
if (log.isTraceEnabled()) {
log.trace("没有Page对象作为参数, 不是分页查询.");
}
return invocation.proceed();
} else {
if (log.isTraceEnabled()) {
log.trace("检测到分页Page对象, 使用分页查询.");
}
}
//设置真正的parameterObj
invocation.getArgs()[1] = extractRealParameterObject(invocation.getArgs()[1]);
pageThreadLocal.set(page);
try {
Object resultObj = invocation.proceed(); // Executor.query(..)
if (resultObj instanceof List) {
/* @SuppressWarnings({ "unchecked", "rawtypes" }) */
page.setResults((List) resultObj);
}
return resultObj;
} finally {
pageThreadLocal.remove();
}
}
}
protected PageInfo<?> findPageObject(Object parameterObj) {
if (parameterObj instanceof PageInfo<?>) {
return (PageInfo<?>) parameterObj;
} else if (parameterObj instanceof Map) {
for (Object val : ((Map<?, ?>) parameterObj).values()) {
if (val instanceof PageInfo<?>) {
return (PageInfo<?>) val;
}
}
}
return null;
}
/**
* <pre>
* 把真正的参数对象解析出来
* Spring会自动封装zhe这个参数对象为Map<String, Object>对象
* 对于通过@Param指定key值参数我们不做处理,因为XML文件需要该KEY值
* 而对于没有@Param指定时,Spring会使用0,1作为主键
* 对于没有@Param指定名称的参数,一般XML文件会直接对真正的参数对象解析,
* 此时解析出真正的参数作为根对象
* </pre>
* @param parameterObj
* @return
*/
protected Object extractRealParameterObject(Object parameterObj) {
if (parameterObj instanceof Map<?, ?>) {
Map<?, ?> parameterMap = (Map<?, ?>) parameterObj;
if (parameterMap.size() == 2) {
boolean springMapWithNoParamName = true;
for (Object key : parameterMap.keySet()) {
if (!(key instanceof String)) {
springMapWithNoParamName = false;
break;
}
String keyStr = (String) key;
if (!"0".equals(keyStr) && !"1".equals(keyStr)) {
springMapWithNoParamName = false;
break;
}
}
if (springMapWithNoParamName) {
for (Object value : parameterMap.values()) {
if (!(value instanceof PageInfo<?>)) {
return value;
}
}
}
}
}
return parameterObj;
}
protected void prepareAndCheckDatabaseType(Connection connection) throws SQLException {
if (databaseType == null) {
String productName = connection.getMetaData().getDatabaseProductName();
if (log.isTraceEnabled()) {
log.trace("Database productName: " + productName);
}
productName = productName.toLowerCase();
if (productName.indexOf(MYSQL) != -1) {
databaseType = MYSQL;
} else if (productName.indexOf(ORACLE) != -1) {
databaseType = ORACLE;
} else {
throw new PageNotSupportException("Page not support for the type of database, database product name [" + productName + "]");
}
if (log.isInfoEnabled()) {
log.info("自动检测到的数据库类型为: " + databaseType);
}
}
}
/**
* <pre>
* 生成分页SQL
* </pre>
*
* @param page
* @param sql
* @return
*/
protected String buildPageSql(PageInfo<?> page, String sql) {
if (MYSQL.equalsIgnoreCase(databaseType)) {
return buildMysqlPageSql(page, sql);
} else if (ORACLE.equalsIgnoreCase(databaseType)) {
return buildOraclePageSql(page, sql);
}
return sql;
}
/**
* <pre>
* 生成Mysql分页查询SQL
* </pre>
*
* @param page
* @param sql
* @return
*/
protected String buildMysqlPageSql(PageInfo<?> page, String sql) {
// 计算第一条记录的位置,Mysql中记录的位置是从0开始的。
int offset = (page.getPageNo() - 1) * page.getPageSize();
return new StringBuilder(sql).append(" limit ").append(offset).append(",").append(page.getPageSize()).toString();
}
/**
* <pre>
* 生成Oracle分页查询SQL
* </pre>
*
* @param page
* @param sql
* @return
*/
protected String buildOraclePageSql(PageInfo<?> page, String sql) {
// 计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的
int offset = (page.getPageNo() - 1) * page.getPageSize() + 1;
StringBuilder sb = new StringBuilder(sql);
sb.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getPageSize());
sb.insert(0, "select * from (").append(") where r >= ").append(offset);
return sb.toString();
}
/**
* <pre>
* 查询总数
* </pre>
*
* @param page
* @param parameterObject
* @param mappedStatement
* @param connection
* @throws SQLException
*/
protected void queryTotalRecord(PageInfo<?> page, Object parameterObject, MappedStatement mappedStatement, Connection connection) throws SQLException {
BoundSql boundSql = mappedStatement.getBoundSql(page);
//获取到我们自己写在Mapper映射语句中对应的Sql语句
String sql = boundSql.getSql();
//通过查询Sql语句获取到对应的计算总记录数的sql语句
String countSql = this.buildCountSql(sql);
if (log.isDebugEnabled()) {
log.debug("分页时, 生成countSql: " + countSql);
}
//通过BoundSql获取对应的参数映射
List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
//利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。
BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, parameterObject);
//通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象
ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, countBoundSql);
//通过connection建立一个countSql对应的PreparedStatement对象。
PreparedStatement pstmt = null;
ResultSet rs = null;
try {
pstmt = connection.prepareStatement(countSql);
//通过parameterHandler给PreparedStatement对象设置参数
parameterHandler.setParameters(pstmt);
//执行获取总记录数的Sql语句和获取结果了。
rs = pstmt.executeQuery();
if (rs.next()) {
long totalRecord = rs.getLong(1);
//给当前的参数page对象设置总记录数
page.setTotalRecord(totalRecord);
}
} finally {
if (rs != null)
try {
rs.close();
} catch (Exception e) {
if (log.isWarnEnabled()) {
log.warn("关闭ResultSet时异常.", e);
}
}
if (pstmt != null) {
try {
pstmt.close();
} catch (Exception e) {
if (log.isWarnEnabled()) {
log.warn("关闭PreparedStatement时异常.", e);
}
}
}
}
}
/**
* 根据原Sql语句获取对应的查询总记录数的Sql语句
*
* @param sql
* @return
*/
protected String buildCountSql(String sql) {
int index = sql.toLowerCase().indexOf("from");
return "select count(*) " + sql.substring(index);
}
/**
* 利用反射进行操作的一个工具类
*
*/
private static class ReflectUtil {
/**
* 利用反射获取指定对象的指定属性
*
* @param obj 目标对象
* @param fieldName 目标属性
* @return 目标属性的值
*/
public static Object getFieldValue(Object obj, String fieldName) {
Object result = null;
Field field = ReflectUtil.getField(obj, fieldName);
if (field != null) {
field.setAccessible(true);
try {
result = field.get(obj);
} catch (IllegalArgumentException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IllegalAccessException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
return result;
}
/**
* 利用反射获取指定对象里面的指定属性
*
* @param obj 目标对象
* @param fieldName 目标属性
* @return 目标字段
*/
private static Field getField(Object obj, String fieldName) {
Field field = null;
for (Class<?> clazz = obj.getClass(); clazz != Object.class; clazz = clazz.getSuperclass()) {
try {
field = clazz.getDeclaredField(fieldName);
break;
} catch (NoSuchFieldException e) {
// 这里不用做处理,子类没有该字段可能对应的父类有,都没有就返回null。
}
}
return field;
}
/**
* 利用反射设置指定对象的指定属性为指定的值
*
* @param obj 目标对象
* @param fieldName 目标属性
* @param fieldValue 目标值
*/
public static void setFieldValue(Object obj, String fieldName, String fieldValue) {
Field field = ReflectUtil.getField(obj, fieldName);
if (field != null) {
try {
field.setAccessible(true);
field.set(obj, fieldValue);
} catch (IllegalArgumentException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IllegalAccessException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
}
public static class PageNotSupportException extends RuntimeException {
/** serialVersionUID*/
private static final long serialVersionUID = 1L;
public PageNotSupportException() {
super();
}
public PageNotSupportException(String message, Throwable cause) {
super(message, cause);
}
public PageNotSupportException(String message) {
super(message);
}
public PageNotSupportException(Throwable cause) {
super(cause);
}
}
}
mybatis配置如下:
别忘了在mybatis配置文件中这个配置
<!-- 配置管理器 -->
<configuration>
<properties>
<property name="dialect" value="mysql" />
</properties>
<settings>
<setting name="logImpl" value="STDOUT_LOGGING" />
</settings>
<plugins>
<plugin interceptor="utn.app.daobase.interceptor.MybatisSpringPageInterceptor"></plugin>
</plugins>
</configuration>
我们用Executor来实现一下拦截器的作用五如下:
javapackage com.fen.dou.inceptor;
import com.fen.dou.entity.BaseEntity;
import com.fen.dou.entity.User;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import java.util.Map;
import java.util.Properties;
@Slf4j
**@Intercepts( {@Signature(type = Executor.class,method = "update",args = {MappedStatement.class,Object.class})})**
@Component
public class MyExecutor implements Interceptor {
@Override
@SuppressWarnings("unchecked")
public Object intercept(Invocation invocation) throws Throwable {
// 根据签名指定的args顺序获取具体的实现类
// 1. 获取MappedStatement实例, 并获取当前SQL命令类型
MappedStatement ms = (MappedStatement) invocation.getArgs()[0];
SqlCommandType commandType = ms.getSqlCommandType();
// 2. 获取当前正在被操作的类, 有可能是Java Bean, 也可能是普通的操作对象, 比如普通的参数传递
// 普通参数, 即是 @Param 包装或者原始 Map 对象, 普通参数会被 Mybatis 包装成 Map 对象
// 即是 org.apache.ibatis.binding.MapperMethod$ParamMap
Object parameter = invocation.getArgs()[1];
// 获取拦截器指定的方法类型, 通常需要拦截 update
String methodName = invocation.getMethod().getName();
log.info("NormalPlugin, methodName; {}, commandType: {}", methodName, commandType);
// 3. 获取当前用户信息
User user = new User(1,"yangcai","sssss");
// 默认测试参数值
int creator = 2, updater = 3;
if (parameter instanceof BaseEntity) {
// 4. 实体类
BaseEntity entity = (BaseEntity) parameter;
if (user != null) {
creator = entity.getCreator();
updater = entity.getUpdater();
}
if (methodName.equals("update")) {
if (commandType.equals(SqlCommandType.INSERT)) {
entity.setCreator(creator);
entity.setUpdater(updater);
entity.setCreateTime(System.currentTimeMillis());
entity.setUpdateTime(System.currentTimeMillis());
} else if (commandType.equals(SqlCommandType.UPDATE)) {
entity.setUpdater(updater);
entity.setUpdateTime(System.currentTimeMillis());
}
}
} else if (parameter instanceof Map) {
// 5. @Param 等包装类
// 更新时指定某些字段的最新数据值
if (commandType.equals(SqlCommandType.UPDATE)) {
// 遍历参数类型, 检查目标参数值是否存在对象中, 该方式需要应用编写有一些统一的规范
// 否则均统一为实体对象, 就免去该重复操作
Map map = (Map) parameter;
if (map.containsKey("creator")) {
map.put("creator", creator);
}
if (map.containsKey("updateTime")) {
map.put("updateTime",System.currentTimeMillis());
}
}
}
// 6. 均不是需要被拦截的类型, 不做操作
return invocation.proceed();
}
@Override
public Object plugin(Object target) {
return Plugin.wrap(target, this);
}
@Override
public void setProperties(Properties properties) {
}
}