原文地址:MyBatis拦截器实现分页
拦截器的作用就是拦截某些方法的调用,我们可以在方法执行前后为方法加上某些逻辑,也可以只执行拦截的逻辑代码而不执行被拦截的方法。Mybatis拦截器设计的一个初衷就是为了供用户在某些时候可以实现自己的逻辑而不必去动Mybatis固有的逻辑。
拦截器分页原理
Mybatis拦截器常常会被用来进行分页处理。我们知道要利用JDBC对数据库进行操作就必须要有一个对应的Statement对象,Mybatis在执行Sql语句前也会产生一个包含Sql语句的Statement对象,而且对应的Sql语句是在Statement之前产生的,所以我们就可以在它成Statement之前对用来生成Statement的Sql语句下手。在Mybatis中Statement语句是通过RoutingStatementHandler对象的prepare方法生成的。所以利用拦截器实现Mybatis分页的一个思路就是拦截StatementHandler接口的prepare方法,然后在拦截器方法中把Sql语句改成对应的分页查询Sql语句,之后再调用StatementHandler对象的prepare方法,即调用invocation.proceed()。更改Sql语句这个看起来很简单,而事实上来说的话就没那么直观,因为包括sql等其他属性在内的多个属性都没有对应的方法可以直接取到,它们对外部都是封闭的,是对象的私有属性,所以这里就需要引入反射机制来获取或者更改对象的私有属性的值了。对于分页而言,在拦截器里面我们常常还需要做的一个操作就是统计满足当前条件的记录一共有多少,这是通过获取到了原始的Sql语句后,把它改为对应的统计语句再利用Mybatis封装好的参数和设置参数的功能把Sql语句中的参数进行替换,之后再执行查询记录数的Sql语句进行总记录数的统计
实现
分页的实体类Page:
public class Page<T> {
protected int pageNo = 1; // 当前页码
protected int pageSize = 10; // 页面大小,设置为“-1”表示不进行分页(分页无效)
protected long count;// 总记录数,设置为“-1”表示不查询总数
protected int first;// 首页索引
protected int last;// 尾页索引
protected int prev;// 上一页索引
protected int next;// 下一页索引
private boolean firstPage;//是否是第一页
private boolean lastPage;//是否是最后一页
protected int length = 8;// 显示页面长度
protected int slider = 1;// 前后显示页面长度
private List<T> list = new ArrayList<T>();
private String orderBy = ""; // 标准查询有效, 实例: updatedate desc, name asc
protected String funcName = "page"; // 设置点击页码调用的js函数名称,默认为page,在一页有多个分页对象时使用。
protected String funcParam = ""; // 函数的附加参数,第三个参数值。
private String message = ""; // 设置提示消息,显示在“共n条”之后
public Page() {
this.pageSize = -1;
}
/**
* 构造方法,
* @param pageNo 页码,
* @param orderBy 排序
*/
public Page(String pageNo,String orderBy){
this.setPageNo(Integer.parseInt(pageNo));
// 设置排序参数
if (StringUtils.isNotBlank(orderBy)){
this.setOrderBy(orderBy);
}
}
/**
* 构造方法
* @param pageNo 当前页码
* @param pageSize 分页大小
*/
public Page(int pageNo, int pageSize) {
this(pageNo, pageSize, 0);
}
/**
* 构造方法
* @param pageNo 当前页码
* @param pageSize 分页大小
* @param count 数据条数
*/
public Page(int pageNo, int pageSize, long count) {
this(pageNo, pageSize, count, new ArrayList<T>());
}
/**
* 构造方法
* @param pageNo 当前页码
* @param pageSize 分页大小
* @param count 数据条数
* @param list 本页数据对象列表
*/
public Page(int pageNo, int pageSize, long count, List<T> list) {
this.setCount(count);
this.setPageNo(pageNo);
this.pageSize = pageSize;
this.list = list;
}
/**
* 初始化参数
*/
public void initialize(){
//1
this.first = 1;
this.last = (int)(count / (this.pageSize < 1 ? 20 : this.pageSize) + first - 1);
if (this.count % this.pageSize != 0 || this.last == 0) {
this.last++;
}
if (this.last < this.first) {
this.last = this.first;
}
if (this.pageNo <= 1) {
this.pageNo = this.first;
this.firstPage=true;
}
if (this.pageNo >= this.last) {
this.pageNo = this.last;
this.lastPage=true;
}
if (this.pageNo < this.last - 1) {
this.next = this.pageNo + 1;
} else {
this.next = this.last;
}
if (this.pageNo > 1) {
this.prev = this.pageNo - 1;
} else {
this.prev = this.first;
}
//2
if (this.pageNo < this.first) {// 如果当前页小于首页
this.pageNo = this.first;
}
if (this.pageNo > this.last) {// 如果当前页大于尾页
this.pageNo = this.last;
}
}
/**
* 获取设置总数
* @return
*/
public long getCount() {
return count;
}
/**
* 设置数据总数
* @param count
*/
public void setCount(long count) {
this.count = count;
if (pageSize >= count){
pageNo = 1;
}
}
/**
* 获取当前页码
* @return
*/
public int getPageNo() {
return pageNo;
}
/**
* 设置当前页码
* @param pageNo
*/
public void setPageNo(int pageNo) {
this.pageNo = pageNo;
}
/**
* 获取页面大小
* @return
*/
public int getPageSize() {
return pageSize;
}
/**
* 设置页面大小(最大500)
* @param pageSize
*/
public void setPageSize(int pageSize) {
this.pageSize = pageSize <= 0 ? 10 : pageSize;// > 500 ? 500 : pageSize;
}
/**
* 首页索引
* @return
*/
@JsonIgnore
public int getFirst() {
return first;
}
/**
* 尾页索引
* @return
*/
@JsonIgnore
public int getLast() {
return last;
}
/**
* 获取页面总数
* @return getLast();
*/
@JsonIgnore
public int getTotalPage() {
return getLast();
}
/**
* 是否为第一页
* @return
*/
@JsonIgnore
public boolean isFirstPage() {
return firstPage;
}
/**
* 是否为最后一页
* @return
*/
@JsonIgnore
public boolean isLastPage() {
return lastPage;
}
/**
* 上一页索引值
* @return
*/
@JsonIgnore
public int getPrev() {
if (isFirstPage()) {
return pageNo;
} else {
return pageNo - 1;
}
}
/**
* 下一页索引值
* @return
*/
@JsonIgnore
public int getNext() {
if (isLastPage()) {
return pageNo;
} else {
return pageNo + 1;
}
}
/**
* 获取本页数据对象列表
* @return List<T>
*/
public List<T> getList() {
return list;
}
/**
* 设置本页数据对象列表
* @param list
*/
public Page<T> setList(List<T> list) {
this.list = list;
initialize();
return this;
}
/**
* 获取查询排序字符串
* @return
*/
@JsonIgnore
public String getOrderBy() {
// SQL过滤,防止注入
String reg = "(?:')|(?:--)|(/\\*(?:.|[\\n\\r])*?\\*/)|"
+ "(\\b(select|update|and|or|delete|insert|trancate|char|into|substr|ascii|declare|exec|count|master|into|drop|execute)\\b)";
Pattern sqlPattern = Pattern.compile(reg, Pattern.CASE_INSENSITIVE);
if (sqlPattern.matcher(orderBy).find()) {
return "";
}
return orderBy;
}
/**
* 设置查询排序,标准查询有效, 实例: updatedate desc, name asc
*/
public void setOrderBy(String orderBy) {
this.orderBy = orderBy;
}
/**
* 获取点击页码调用的js函数名称
* function ${page.funcName}(pageNo){location="${ctx}/list-${category.id}${urlSuffix}?pageNo="+i;}
* @return
*/
@JsonIgnore
public String getFuncName() {
return funcName;
}
/**
* 设置点击页码调用的js函数名称,默认为page,在一页有多个分页对象时使用。
* @param funcName 默认为page
*/
public void setFuncName(String funcName) {
this.funcName = funcName;
}
/**
* 获取分页函数的附加参数
* @return
*/
@JsonIgnore
public String getFuncParam() {
return funcParam;
}
/**
* 设置分页函数的附加参数
* @return
*/
public void setFuncParam(String funcParam) {
this.funcParam = funcParam;
}
/**
* 设置提示消息,显示在“共n条”之后
* @param message
*/
public void setMessage(String message) {
this.message = message;
}
/**
* 分页是否有效
* @return this.pageSize==-1
*/
@JsonIgnore
public boolean isDisabled() {
return this.pageSize==-1;
}
/**
* 是否进行总数统计
* @return this.count==-1
*/
@JsonIgnore
public boolean isNotCount() {
return this.count==-1;
}
/**
* 获取 Hibernate FirstResult
*/
@JsonIgnore
public int getFirstResult(){
int firstResult = (getPageNo() - 1) * getPageSize();
if (firstResult >= getCount() || firstResult<0) {
firstResult = 0;
}
return firstResult;
}
/**
* 获取 Hibernate MaxResults
*/
@JsonIgnore
public int getMaxResults(){
return getPageSize();
}
}
因为数据库的不同,分页的语言可能会有所不同。为了能够是程序兼容性好,最好把数据库方言的设置抽出,这里列出mysql和oracle的方言设置:
数据库方言配置接口
public interface Dialect {
/**
* 数据库本身是否支持分页当前的分页查询方式
* 如果数据库不支持的话,则不进行数据库分页
*
* @return true:支持当前的分页查询方式
*/
public boolean supportsLimit();
/**
* 将sql转换为分页SQL,分别调用分页sql
*
* @param sql SQL语句
* @param offset 开始条数
* @param limit 每页显示多少纪录条数
* @return 分页查询的sql
*/
public String getLimitString(String sql, int offset, int limit);
}
MySql方言实现类
public class MySqlDialect implements Dialect {
@Override
public boolean supportsLimit() {
return true;
}
@Override
public String getLimitString(String sql, int offset, int limit) {
return getLimitString(sql, offset, Integer.toString(offset),
Integer.toString(limit));
}
/**
* 将sql变成分页sql语句,提供offest及limit使用占位符号(placeholder)替换.
* 如mysql
* dialect.getLimitString("select * from user", 12, ":offset",0,":limit") 将返回
* select * from user limit :offset,:limit
*
* @param sql 实际SQL语句
* @param offset 分页开始纪录条数
* @param offsetPlaceholder 分页开始纪录条数-占位符号
* @param limitPlaceholder 分页纪录条数占位符号
* @return 包含占位符的分页sql
*/
public String getLimitString(String sql, int offset, String offsetPlaceholder, String limitPlaceholder) {
StringBuilder stringBuilder = new StringBuilder(sql);
stringBuilder.append(" limit ");
if (offset > 0) {
stringBuilder.append(offsetPlaceholder).append(",").append(limitPlaceholder);
} else {
stringBuilder.append(limitPlaceholder);
}
return stringBuilder.toString();
}
}
oracle方言实现类
public class OracleDialect implements Dialect {
@Override
public boolean supportsLimit() {
return true;
}
@Override
public String getLimitString(String sql, int offset, int limit) {
return getLimitString(sql, offset, Integer.toString(offset), Integer.toString(limit));
}
/**
* 将sql变成分页sql语句,提供将offset及limit使用占位符号(placeholder)替换.
* <pre>
* 如mysql
* dialect.getLimitString("select * from user", 12, ":offset",0,":limit") 将返回
* select * from user limit :offset,:limit
* </pre>
*
* @param sql 实际SQL语句
* @param offset 分页开始纪录条数
* @param offsetPlaceholder 分页开始纪录条数-占位符号
* @param limitPlaceholder 分页纪录条数占位符号
* @return 包含占位符的分页sql
*/
public String getLimitString(String sql, int offset, String offsetPlaceholder, String limitPlaceholder) {
sql = sql.trim();
boolean isForUpdate = false;
if (sql.toLowerCase().endsWith(" for update")) {
sql = sql.substring(0, sql.length() - 11);
isForUpdate = true;
}
StringBuilder pagingSelect = new StringBuilder(sql.length() + 100);
if (offset > 0) {
pagingSelect.append("select * from ( select row_.*, rownum rownum_ from ( ");
} else {
pagingSelect.append("select * from ( ");
}
pagingSelect.append(sql);
if (offset > 0) {
String endString = offsetPlaceholder + "+" + limitPlaceholder;
pagingSelect.append(" ) row_ where rownum <= "+endString+") where rownum_ > ").append(offsetPlaceholder);
} else {
pagingSelect.append(" ) where rownum <= "+limitPlaceholder);
}
if (isForUpdate) {
pagingSelect.append(" for update");
}
return pagingSelect.toString();
}
}
sql的工具类,主要有总记录查询方法、对占位符的设值和分页语句的生成
public class SQLHelp {
private static final Logger logger = LoggerFactory.getLogger(SQLHelp.class);
/**
* 对SQL参数(?)设值,参考org.apache.ibatis.executor.parameter.DefaultParameterHandler
*
* @param ps 表示预编译的 SQL 语句的对象。
* @param mappedStatement MappedStatement
* @param boundSql SQL
* @param parameterObject 参数对象
* @throws java.sql.SQLException 数据库异常
*/
@SuppressWarnings("unchecked")
public static void setParameters(PreparedStatement ps, MappedStatement mappedStatement, BoundSql boundSql, Object parameterObject) throws SQLException {
ErrorContext.instance().activity("setting parameters").object(mappedStatement.getParameterMap().getId());
List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
if (parameterMappings != null) {
Configuration configuration = mappedStatement.getConfiguration();
TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
MetaObject metaObject = parameterObject == null ? null :
configuration.newMetaObject(parameterObject);
for (int i = 0; i < parameterMappings.size(); i++) {
ParameterMapping parameterMapping = parameterMappings.get(i);
if (parameterMapping.getMode() != ParameterMode.OUT) {
Object value;
String propertyName = parameterMapping.getProperty();
PropertyTokenizer prop = new PropertyTokenizer(propertyName);
if (parameterObject == null) {
value = null;
} else if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
value = parameterObject;
} else if (boundSql.hasAdditionalParameter(propertyName)) {
value = boundSql.getAdditionalParameter(propertyName);
} else if (propertyName.startsWith(ForEachSqlNode.ITEM_PREFIX) && boundSql.hasAdditionalParameter(prop.getName())) {
value = boundSql.getAdditionalParameter(prop.getName());
if (value != null) {
value = configuration.newMetaObject(value).getValue(propertyName.substring(prop.getName().length()));
}
} else {
value = metaObject == null ? null : metaObject.getValue(propertyName);
}
@SuppressWarnings("rawtypes")
TypeHandler typeHandler = parameterMapping.getTypeHandler();
if (typeHandler == null) {
throw new ExecutorException("There was no TypeHandler found for parameter " + propertyName + " of statement " + mappedStatement.getId());
}
typeHandler.setParameter(ps, i + 1, value, parameterMapping.getJdbcType());
}
}
}
}
/**
* 查询总纪录数
* @param sql SQL语句
* @param connection 数据库连接
* @param mappedStatement mapped
* @param parameterObject 参数
* @param boundSql boundSql
* @return 总记录数
* @throws SQLException sql查询错误
*/
public static int getRowCount(final String sql, final Connection connection, final BoundSql boundSql,
final MappedStatement mappedStatement, final Object parameterObject) throws SQLException {
final String countSql = "select count(1) from (" + sql + ")temp_count";
Connection conn = connection;
PreparedStatement ps = null;
ResultSet rs = null;
try {
if(logger.isDebugEnabled()){
logger.debug("COUNT SQL: " + StringUtils.replaceEach(countSql, new String[]{"\n","\t"}, new String[]{" "," "}));
}
if(connection == null){
conn = mappedStatement.getConfiguration().getEnvironment().getDataSource().getConnection();
}
ps = conn.prepareStatement(countSql);
BoundSql countBS = new BoundSql(mappedStatement.getConfiguration(),
countSql,
boundSql.getParameterMappings(),
parameterObject);
SQLHelp.setParameters(ps,mappedStatement,countBS,parameterObject);
rs = ps.executeQuery();
int count = 0;
if(rs.next()){
count = rs.getInt(1);
}
return count;
}finally {
if (rs != null) {
rs.close();
}
if (ps != null) {
ps.close();
}
if (conn != null) {
conn.close();
}
}
}
/**
* 根据数据库方言,生成特定的分页sql
* @param sql Mapper中的Sql语句
* @param page 分页对象
* @param dialect 方言类型
* @return 分页SQL
*/
public static String generatePageSql(String sql, Page<Object> page, Dialect dialect) {
if (dialect.supportsLimit()) {
return dialect.getLimitString(sql, page.getFirstResult(), page.getMaxResults());
} else {
return sql;
}
}
}
拦截方法类,自定义mybatis需要实现Interceptor接口,并实现该接口的两个方法:plugin、intercept。在plugin方法中我们可以决定是否要进行拦截进而决定要返回一个什么样的目标对象。而intercept方法就是要进行拦截的时候要执行的方法。
@Intercepts({
@Signature(method = "query", type = Executor.class, args = {
MappedStatement.class, Object.class, RowBounds.class,
ResultHandler.class})
})
public class PaginationInterceptor implements Interceptor {
private final Logger logger = LoggerFactory.getLogger(getClass());
protected static final String PAGE = "page";
private Dialect dialect;
/**
* 分页拦截方法
* @param invocation
* @return
* @throws Throwable
*/
@Override
public Object intercept(Invocation invocation) throws Throwable {
final MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
//拦截需要分页的SQL
Object parameter = invocation.getArgs()[1];
BoundSql boundSql = mappedStatement.getBoundSql(parameter);
Object parameterObject = boundSql.getParameterObject();
//获取分页参数对象
Page<Object> page = null;
if(parameterObject != null){
page = convertParameter(parameterObject,page);
}
//如果设置了分页对象,则进行分页
if(page != null && page.getPageSize() != -1){
if(StringUtils.isBlank(boundSql.getSql())){
return null;
}
String origin_sql = boundSql.getSql().trim();
//设置总的记录数
page.setCount(SQLHelp.getRowCount(origin_sql,null,boundSql,mappedStatement,parameterObject));
//分页查询
String pageSql = SQLHelp.generatePageSql(origin_sql,page,dialect);
invocation.getArgs()[2] = new RowBounds(RowBounds.NO_ROW_OFFSET,RowBounds.NO_ROW_LIMIT);
BoundSql newBoundSql = new BoundSql(mappedStatement.getConfiguration(),
pageSql,
boundSql.getParameterMappings(),
boundSql.getParameterObject());
MappedStatement newMs = copyFromMappedStatement(mappedStatement,new BoundSqlSqlSource(newBoundSql));
invocation.getArgs()[0] = newMs;
}
return invocation.proceed();
}
private MappedStatement copyFromMappedStatement(MappedStatement ms,
SqlSource newSqlSource) {
MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(),
ms.getId(), newSqlSource, ms.getSqlCommandType());
builder.resource(ms.getResource());
builder.fetchSize(ms.getFetchSize());
builder.statementType(ms.getStatementType());
builder.keyGenerator(ms.getKeyGenerator());
if (ms.getKeyProperties() != null) {
for (String keyProperty : ms.getKeyProperties()) {
builder.keyProperty(keyProperty);
}
}
builder.timeout(ms.getTimeout());
builder.parameterMap(ms.getParameterMap());
builder.resultMaps(ms.getResultMaps());
builder.cache(ms.getCache());
return builder.build();
}
public static class BoundSqlSqlSource implements SqlSource {
BoundSql boundSql;
public BoundSqlSqlSource(BoundSql boundSql) {
this.boundSql = boundSql;
}
@Override
public BoundSql getBoundSql(Object parameterObject) {
return boundSql;
}
}
/**
* 对参数进行转换和检查
* @param parameterObject 参数对象
* @param page 分页对象
* @return 分页对象
* @throws NoSuchFieldException 无法找到参数
*/
@SuppressWarnings("unchecked")
protected static Page<Object> convertParameter(Object parameterObject, Page<Object> page) {
try{
if (parameterObject instanceof Page) {
return (Page<Object>) parameterObject;
} else {
return (Page<Object>) ReflectHelper.getFieldValue(parameterObject, PAGE);
}
}catch (Exception e) {
return null;
}
}
@Override
public Object plugin(Object o) {
return Plugin.wrap(o, this);
}
@Override
public void setProperties(Properties properties) {
//初始化方言实现类
dialect = new MySqlDialect();
}
}
可以看到在该类上使用了@Intercepts
注解,该注解主要是定义拦截点。该Interceptor将拦截Executor接口中参数类型为MappedStatement、Object、RowBounds和ResultHandler的query方法。
最后需要在mybatis的配置文件中注册拦截器:
<plugins>
<plugin interceptor="com.wqh.blog.handle.PaginationInterceptor" />
</plugins>