This article explores how PageHelper implements pagination and the count query, using MySQL as the database.
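The basic idea:
- PageHelper registers an interceptor, PageInterceptor, with MyBatis to handle pagination and the count query
- PageHelper.startPage() puts the pagination parameters into a ThreadLocal
- MyBatis calls the interceptor while it executes the SQL
- The interceptor builds a count SQL from the query SQL
- It takes the pagination info out of the ThreadLocal and appends limit ?, ? to the query SQL
- It clears the ThreadLocal
- A PageInfo object is created from the returned Page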
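To make this flow concrete, here is a toy model of it that uses only the JDK. MiniPageHelper, PageParams and intercept() are hypothetical names, not PageHelper's API. For readability the sketch inlines the LIMIT values instead of using ? placeholders, and it always uses the subquery form of the count SQL.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical toy model of the PageHelper flow; not the library's real classes.
public class MiniPageHelper {
    // Pagination parameters, standing in for PageHelper's Page
    static final class PageParams {
        final int pageNum;
        final int pageSize;

        PageParams(int pageNum, int pageSize) {
            this.pageNum = pageNum;
            this.pageSize = pageSize;
        }
    }

    private static final ThreadLocal<PageParams> LOCAL_PAGE = new ThreadLocal<>();

    // Like PageHelper.startPage(): only stores the parameters for the current thread
    static void startPage(int pageNum, int pageSize) {
        LOCAL_PAGE.set(new PageParams(pageNum, pageSize));
    }

    // Like the interceptor: returns the SQL statements that would be executed,
    // then always clears the ThreadLocal
    static List<String> intercept(String sql) {
        try {
            PageParams p = LOCAL_PAGE.get();
            if (p == null) {
                return List.of(sql); // no startPage() call: run the query unchanged
            }
            List<String> executed = new ArrayList<>();
            executed.add("select count(0) from (" + sql + ") tmp_count");
            long offset = (long) (p.pageNum - 1) * p.pageSize;
            executed.add(sql + " LIMIT " + offset + ", " + p.pageSize);
            return executed;
        } finally {
            LOCAL_PAGE.remove(); // like clearPage()
        }
    }

    public static void main(String[] args) {
        startPage(2, 3);
        System.out.println(intercept("select id, name from user"));
        // The ThreadLocal was cleared, so this second query is not paginated
        System.out.println(intercept("select id, name from user"));
    }
}
```

Only the first query after startPage() is paginated, which is why startPage() should be called immediately before the query it is meant for.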
Setup
Maven
<dependency>
<groupId>com.github.pagehelper</groupId>
<artifactId>pagehelper-spring-boot-starter</artifactId>
<version>1.3.0</version>
</dependency>
<dependency>
<groupId>org.mybatis.spring.boot</groupId>
<artifactId>mybatis-spring-boot-starter</artifactId>
<version>2.1.4</version>
</dependency>
User.java
@Data
public class User {
private Integer id;
private String name;
}
UserMapper.java
public interface UserMapper {
@Select("select id, name from user")
List<User> selectUsers();
}
UserController.java
@RestController
public class UserController {
@Autowired
private UserMapper userMapper;
@GetMapping("/users")
public Object listUser(@RequestParam("pageNum") Integer pageNum,
@RequestParam("pageSize") Integer pageSize) {
PageHelper.startPage(pageNum, pageSize);
List<User> users = userMapper.selectUsers();
return PageInfo.of(users);
}
}
127.0.0.1:8080/users?pageNum=2&pageSize=3
{
"total": 26,
"list": [
{
"id": 4,
"name": "d"
},
{
"id": 5,
"name": "e"
},
{
"id": 6,
"name": "f"
}
],
"pageNum": 2,
"pageSize": 3,
"size": 3,
"startRow": 4,
"endRow": 6,
"pages": 9,
"prePage": 1,
"nextPage": 3,
"isFirstPage": false,
"isLastPage": false,
"hasPreviousPage": true,
"hasNextPage": true,
"navigatePages": 8,
"navigatepageNums": [
1,
2,
3,
4,
5,
6,
7,
8
],
"navigateFirstPage": 1,
"navigateLastPage": 8
}
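The numbers in this response can be checked by hand, assuming the user table holds 26 rows. The sketch below re-derives them: pages rounds up, and the offset is the number of rows skipped before the page. PageMath is a hypothetical name, and navigate() is my own reconstruction that reproduces navigatepageNums for this case; PageInfo's actual navigation code is not shown in this article.

```java
import java.util.Arrays;

public class PageMath {
    // Rows skipped before the requested page (the offset in LIMIT offset, size)
    static long offset(int pageNum, int pageSize) {
        return pageNum > 0 ? (long) (pageNum - 1) * pageSize : 0;
    }

    // Total number of pages, rounding up
    static int pages(long total, int pageSize) {
        return pageSize > 0 ? (int) ((total + pageSize - 1) / pageSize) : 0;
    }

    // Up to navigatePages consecutive page numbers around pageNum, clamped to [1, pages]
    static int[] navigate(int pageNum, int pages, int navigatePages) {
        int n = Math.min(pages, navigatePages);
        int start = pageNum - navigatePages / 2;
        start = Math.max(1, Math.min(start, pages - n + 1));
        int[] nums = new int[n];
        for (int i = 0; i < n; i++) {
            nums[i] = start + i;
        }
        return nums;
    }

    public static void main(String[] args) {
        int pageNum = 2, pageSize = 3;
        long total = 26;
        int pages = pages(total, pageSize);
        System.out.println("offset=" + offset(pageNum, pageSize));        // offset=3
        System.out.println("pages=" + pages);                             // pages=9
        System.out.println(Arrays.toString(navigate(pageNum, pages, 8))); // [1, 2, 3, 4, 5, 6, 7, 8]
    }
}
```

The startRow 4 and endRow 6 in the response are offset + 1 and offset + size.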
Analysis
Putting the pagination info into LOCAL_PAGE
PageHelper extends PageMethod and implements Dialect and BoundSqlInterceptor.Chain:
public class PageHelper extends PageMethod implements Dialect, BoundSqlInterceptor.Chain {
...
}
Most of the time we paginate with PageHelper.startPage(pageNum, pageSize):
public static <E> Page<E> startPage(int pageNum, int pageSize) {
return startPage(pageNum, pageSize, DEFAULT_COUNT);
}
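All startPage() overloads eventually call the overload below. It builds a Page object that stores the pagination parameters and settings, then calls setLocalPage(page) to put it into the ThreadLocal.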
/**
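* Start pagination
*
* @param pageNum page number
* @param pageSize number of rows per page
* @param count whether to run a count query
* @param reasonable reasonable pagination; null uses the default configuration
* @param pageSizeZero if true and pageSize = 0, return all rows; if false, paginate; null uses the default configuration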
*/
public static <E> Page<E> startPage(int pageNum, int pageSize, boolean count, Boolean reasonable, Boolean pageSizeZero) {
Page<E> page = new Page<E>(pageNum, pageSize, count);
page.setReasonable(reasonable);
page.setPageSizeZero(pageSizeZero);
// when orderBy() has already been called, keep its ORDER BY
Page<E> oldPage = getLocalPage();
if (oldPage != null && oldPage.isOrderByOnly()) {
page.setOrderBy(oldPage.getOrderBy());
}
// put it into the ThreadLocal
setLocalPage(page);
return page;
}
LOCAL_PAGE here is the ThreadLocal mentioned earlier that stores the Page:
public abstract class PageMethod {
// the key ThreadLocal
protected static final ThreadLocal<Page> LOCAL_PAGE = new ThreadLocal<Page>();
protected static boolean DEFAULT_COUNT = true;
protected static void setLocalPage(Page page) {
LOCAL_PAGE.set(page);
}
public static <T> Page<T> getLocalPage() {
return LOCAL_PAGE.get();
}
public static void clearPage() {
LOCAL_PAGE.remove();
}
...
}
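Because LOCAL_PAGE is a static ThreadLocal, concurrent requests each see only their own Page. The sketch below shows that isolation with a plain ThreadLocal; the int[]{pageNum, pageSize} standing in for Page and the class name ThreadLocalIsolation are simplifications of mine. The flip side is that in a thread pool a Page that is never cleared stays attached to the pooled thread and could affect a later query on it, which is why the interceptor clears it in a finally block.

```java
import java.util.concurrent.atomic.AtomicBoolean;

public class ThreadLocalIsolation {
    // Stands in for PageMethod.LOCAL_PAGE; an int[]{pageNum, pageSize} replaces Page
    static final ThreadLocal<int[]> LOCAL_PAGE = new ThreadLocal<>();

    // Checks whether a freshly started thread can see the value set in LOCAL_PAGE
    static boolean visibleInOtherThread() {
        AtomicBoolean visible = new AtomicBoolean();
        Thread t = new Thread(() -> visible.set(LOCAL_PAGE.get() != null));
        t.start();
        try {
            t.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return visible.get();
    }

    public static void main(String[] args) {
        LOCAL_PAGE.set(new int[]{2, 3});              // like setLocalPage(page)
        System.out.println(LOCAL_PAGE.get()[0]);      // 2: the same thread sees it
        System.out.println(visibleInOtherThread());   // false: other threads do not
        LOCAL_PAGE.remove();                          // like clearPage()
        System.out.println(LOCAL_PAGE.get() == null); // true
    }
}
```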
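How the pagination info is used: count and pagination
PageInterceptor implements the org.apache.ibatis.plugin.Interceptor interface. MyBatis calls this interceptor whenever the Executor's query method (method = "query") runs.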
/**
* MyBatis - generic pagination interceptor
*/
@SuppressWarnings({"rawtypes", "unchecked"})
@Intercepts(
{
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
}
)
public class PageInterceptor implements Interceptor {
// defaults to PageHelper
private volatile Dialect dialect;
// suffix appended to the statement id to name the count statement
private String countSuffix = "_COUNT";
// cache of the generated count MappedStatements
// in this example the key is com.example.pagehelper.dao.UserMapper.selectUsers_COUNT
protected Cache<String, MappedStatement> msCountMap = null;
// default Dialect implementation class
private String default_dialect_class = "com.github.pagehelper.PageHelper";
...
}
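MyBatis applies interceptors such as PageInterceptor by wrapping the Executor in a JDK dynamic proxy (org.apache.ibatis.plugin.Plugin), and the @Signature entries decide which calls are routed to intercept(). The self-contained sketch below imitates that idea with java.lang.reflect.Proxy. MiniExecutor, RealExecutor and wrap() are hypothetical and much simpler than MyBatis's Plugin; in particular, PageInterceptor builds a new BoundSql rather than rewriting a SQL string argument.

```java
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

public class ProxyInterceptorDemo {
    // Hypothetical stand-in for MyBatis's Executor
    interface MiniExecutor {
        List<String> query(String sql);

        int update(String sql);
    }

    static class RealExecutor implements MiniExecutor {
        @Override
        public List<String> query(String sql) {
            return List.of("ran: " + sql);
        }

        @Override
        public int update(String sql) {
            return 1;
        }
    }

    // Routes only query(...) through the "interceptor", like @Signature(method = "query")
    static MiniExecutor wrap(MiniExecutor target, List<String> log) {
        return (MiniExecutor) Proxy.newProxyInstance(
                MiniExecutor.class.getClassLoader(),
                new Class<?>[]{MiniExecutor.class},
                (proxy, method, args) -> {
                    if (method.getName().equals("query")) {
                        log.add("intercepted: " + args[0]);
                        // The toy interceptor rewrites the SQL before the real call
                        return method.invoke(target, args[0] + " LIMIT ?, ?");
                    }
                    return method.invoke(target, args); // everything else passes through
                });
    }

    public static void main(String[] args) {
        List<String> log = new ArrayList<>();
        MiniExecutor executor = wrap(new RealExecutor(), log);
        System.out.println(executor.query("select id, name from user"));
        System.out.println(executor.update("delete from user"));
        System.out.println(log);
    }
}
```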
Send the HTTP request 127.0.0.1:8080/users?pageNum=2&pageSize=3.
The arguments returned by invocation.getArgs() follow one of the two intercepted signatures, args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class} or args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}, which is why the code below checks if (args.length == 4).
boundSql holds the original query SQL.
PageInterceptor.intercept()
@Override
public Object intercept(Invocation invocation) throws Throwable {
try {
// 获取方法参数
Object[] args = invocation.getArgs();
MappedStatement ms = (MappedStatement) args[0];
Object parameter = args[1];
RowBounds rowBounds = (RowBounds) args[2];
ResultHandler resultHandler = (ResultHandler) args[3];
Executor executor = (Executor) invocation.getTarget();
CacheKey cacheKey;
BoundSql boundSql;
// because of how the calls are structured, this is entered only once per query
if (args.length == 4) {
// 4 arguments
// get the original query SQL
boundSql = ms.getBoundSql(parameter);
cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql);
} else {
// 6 arguments
cacheKey = (CacheKey) args[4];
boundSql = (BoundSql) args[5];
}
checkDialectExists();
// BoundSql interception hook
// in this example nothing changes; boundSql is returned as-is
if (dialect instanceof BoundSqlInterceptor.Chain) {
boundSql = ((BoundSqlInterceptor.Chain) dialect).doBoundSql(BoundSqlInterceptor.Type.ORIGINAL, boundSql, cacheKey);
}
List resultList;
// decide whether to paginate; if not, just run the query and return the result
if (!dialect.skip(ms, parameter, rowBounds)) {
// decide whether a count query is needed
if (dialect.beforeCount(ms, parameter, rowBounds)) {
// query the total count
// see PageInterceptor.count()
Long count = count(executor, ms, parameter, rowBounds, null, boundSql);
// handle the total: true continues with the paged query, false returns immediately
if (!dialect.afterCount(count, parameter, rowBounds)) {
// e.g. when the total is 0, return an empty result directly
return dialect.afterPage(new ArrayList(), parameter, rowBounds);
}
}
// run the paged query
resultList = ExecutorUtil.pageQuery(dialect, executor,
ms, parameter, rowBounds, resultHandler, boundSql, cacheKey);
} else {
// rowBounds keeps the caller's value: without the plugin, MyBatis's default in-memory pagination still works
resultList = executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql);
}
// fill the Page held in the ThreadLocal with the results and return it
return dialect.afterPage(resultList, parameter, rowBounds);
} finally {
if(dialect != null){
dialect.afterAll();
}
}
}
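The control flow of intercept() can be reduced to the sketch below: skip, count, afterCount, paged query, with the ThreadLocal cleared in finally however the method exits. InterceptFlow is a hypothetical name and the Supplier parameters stand in for the real queries. The afterCount rule used here (continue only when the total exceeds the offset of the requested page, which covers the total == 0 case in the comment above) is my assumption about the dialect, not code taken from it.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

public class InterceptFlow {
    // {pageNum, pageSize}; stands in for the Page stored by startPage()
    static final ThreadLocal<int[]> LOCAL_PAGE = new ThreadLocal<>();

    static List<String> intercept(Supplier<Long> countQuery,
                                  Supplier<List<String>> pageQuery,
                                  Supplier<List<String>> plainQuery) {
        try {
            int[] page = LOCAL_PAGE.get();
            if (page == null) {
                return plainQuery.get();            // like dialect.skip(...) returning true
            }
            long total = countQuery.get();          // like count(...)
            long offset = (long) (page[0] - 1) * page[1];
            if (total <= offset) {                  // assumed afterCount rule
                return new ArrayList<>();           // the paged query never runs
            }
            return pageQuery.get();                 // like ExecutorUtil.pageQuery(...)
        } finally {
            LOCAL_PAGE.remove();                    // like dialect.afterAll() -> clearPage()
        }
    }

    public static void main(String[] args) {
        List<String> log = new ArrayList<>();
        LOCAL_PAGE.set(new int[]{2, 3});
        List<String> result = intercept(
                () -> { log.add("count"); return 0L; },
                () -> { log.add("page"); return List.of("row"); },
                () -> { log.add("plain"); return List.of("row"); });
        System.out.println(result + " " + log);       // [] [count]
        System.out.println(LOCAL_PAGE.get() == null); // true
    }
}
```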
count
PageInterceptor.count()
private Long count(Executor executor, MappedStatement ms, Object parameter,
RowBounds rowBounds, ResultHandler resultHandler,
BoundSql boundSql) throws SQLException {
// countMsId = "com.example.pagehelper.dao.UserMapper.selectUsers_COUNT"
String countMsId = ms.getId() + countSuffix;
Long count;
// first check whether a hand-written count query exists
MappedStatement countMs = ExecutorUtil.getExistedMappedStatement(ms.getConfiguration(), countMsId);
if (countMs != null) {
// run the hand-written count query directly
count = ExecutorUtil.executeManualCount(executor, countMs, parameter, boundSql, resultHandler);
} else {
// otherwise look it up in the cache first
if (msCountMap != null) {
countMs = msCountMap.get(countMsId);
}
// not cached: create it automatically and put it in the cache
if (countMs == null) {
// create an ms from the current one whose result type is Long
countMs = MSUtils.newCountMappedStatement(ms, countMsId);
if (msCountMap != null) {
// put it in the cache
msCountMap.put(countMsId, countMs);
}
}
// run the count query
// see ExecutorUtil.executeAutoCount()
count = ExecutorUtil.executeAutoCount(this.dialect, executor, countMs, parameter,
boundSql, rowBounds, resultHandler);
}
return count;
}
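count() looks for a hand-written statement named <id>_COUNT first and only then generates and caches one. The sketch below models that lookup order with plain strings in place of MappedStatement, and with ConcurrentHashMap.computeIfAbsent in place of the get/put calls on PageHelper's own Cache (a swap I made for brevity). CountStatementCache and handWritten are hypothetical names. In practice this means a selectUsers_COUNT statement declared in the same mapper overrides the generated count SQL.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

public class CountStatementCache {
    static final String COUNT_SUFFIX = "_COUNT";

    // Hypothetical hand-written statements by id, standing in for the MyBatis Configuration
    final Map<String, String> handWritten = new HashMap<>();
    // Auto-generated count statements by id, standing in for msCountMap
    final Map<String, String> autoCountCache = new ConcurrentHashMap<>();
    // How many times a count statement had to be generated
    final AtomicInteger builds = new AtomicInteger();

    String countStatementFor(String msId, String sql) {
        String countMsId = msId + COUNT_SUFFIX;
        String manual = handWritten.get(countMsId);
        if (manual != null) {
            return manual; // a hand-written <id>_COUNT statement takes priority
        }
        // Generate once per id, then reuse the cached statement
        return autoCountCache.computeIfAbsent(countMsId, id -> {
            builds.incrementAndGet();
            return "select count(0) from (" + sql + ") tmp_count";
        });
    }

    public static void main(String[] args) {
        CountStatementCache cache = new CountStatementCache();
        String id = "com.example.pagehelper.dao.UserMapper.selectUsers";
        cache.countStatementFor(id, "select id, name from user");
        cache.countStatementFor(id, "select id, name from user");
        System.out.println(cache.builds.get()); // 1
        cache.handWritten.put(id + COUNT_SUFFIX, "select count(*) from user");
        System.out.println(cache.countStatementFor(id, "select id, name from user"));
    }
}
```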
ExecutorUtil.executeAutoCount()
public static Long executeAutoCount(Dialect dialect, Executor executor, MappedStatement countMs,
Object parameter, BoundSql boundSql,
RowBounds rowBounds, ResultHandler resultHandler) throws SQLException {
Map<String, Object> additionalParameters = getAdditionalParameter(boundSql);
// create the cache key for the count query
CacheKey countKey = executor.createCacheKey(countMs, parameter, RowBounds.DEFAULT, boundSql);
// ask the dialect for the count SQL: SELECT count(0) FROM user
// see PageHelper.getCountSql()
String countSql = dialect.getCountSql(countMs, boundSql, parameter, rowBounds, countKey);
//countKey.update(countSql);
BoundSql countBoundSql = new BoundSql(countMs.getConfiguration(), countSql, boundSql.getParameterMappings(), parameter);
// dynamic SQL may create temporary parameters, which must be copied into the new BoundSql by hand
for (String key : additionalParameters.keySet()) {
countBoundSql.setAdditionalParameter(key, additionalParameters.get(key));
}
// BoundSql interception hook
if (dialect instanceof BoundSqlInterceptor.Chain) {
countBoundSql = ((BoundSqlInterceptor.Chain) dialect).doBoundSql(BoundSqlInterceptor.Type.COUNT_SQL, countBoundSql, countKey);
}
// run the count query
Object countResultList = executor.query(countMs, parameter, RowBounds.DEFAULT, resultHandler, countKey, countBoundSql);
Long count = (Long) ((List) countResultList).get(0);
return count;
}
PageHelper.getCountSql()
@Override
public String getCountSql(MappedStatement ms, BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey countKey) {
// AbstractHelperDialect.getCountSql()
return autoDialect.getDelegate().getCountSql(ms, boundSql, parameterObject, rowBounds, countKey);
}
AbstractHelperDialect.getCountSql()
@Override
public String getCountSql(MappedStatement ms, BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey countKey) {
Page<Object> page = getLocalPage();
String countColumn = page.getCountColumn();
if (StringUtil.isNotEmpty(countColumn)) {
return countSqlParser.getSmartCountSql(boundSql.getSql(), countColumn);
}
return countSqlParser.getSmartCountSql(boundSql.getSql());
}
CountSqlParser.getSmartCountSql()
// sql = "select id, name from user";
// name = "0";
public String getSmartCountSql(String sql, String name) {
// parse the SQL
Statement stmt = null;
// special SQL that must keep its ORDER BY is marked with a comment prefix (KEEP_ORDERBY)
if(sql.indexOf(KEEP_ORDERBY) >= 0){
return getSimpleCountSql(sql, name);
}
try {
stmt = CCJSqlParserUtil.parse(sql);
} catch (Throwable e) {
// SQL that cannot be parsed falls back to the simple count statement
return getSimpleCountSql(sql, name);
}
Select select = (Select) stmt;
SelectBody selectBody = select.getSelectBody();
try {
// process the body: remove ORDER BY
processSelectBody(selectBody);
} catch (Exception e) {
// when the SQL contains GROUP BY, ORDER BY is not removed
return getSimpleCountSql(sql, name);
}
// process the WITH items: remove ORDER BY
processWithItemsList(select.getWithItemsList());
// convert to a count query
sqlToCount(select, name);
// SELECT count(0) FROM user
String result = select.toString();
return result;
}
CountSqlParser.sqlToCount()
public void sqlToCount(Select select, String name) {
SelectBody selectBody = select.getSelectBody();
// the single select item of the count query
List<SelectItem> COUNT_ITEM = new ArrayList<SelectItem>();
// count(0)
COUNT_ITEM.add(new SelectExpressionItem(new Column("count(" + name +")")));
// can the simple count form be used?
if (selectBody instanceof PlainSelect && isSimpleCount((PlainSelect) selectBody)) {
// replace id, name with count(0); the final SQL is SELECT count(0) FROM user
((PlainSelect) selectBody).setSelectItems(COUNT_ITEM);
} else {
// otherwise wrap the original query in a subquery aliased TABLE_ALIAS and count its rows
PlainSelect plainSelect = new PlainSelect();
SubSelect subSelect = new SubSelect();
subSelect.setSelectBody(selectBody);
subSelect.setAlias(TABLE_ALIAS);
plainSelect.setFromItem(subSelect);
plainSelect.setSelectItems(COUNT_ITEM);
select.setSelectBody(plainSelect);
}
}
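PageHelper relies on JSqlParser to choose between the two shapes above. The sketch below imitates only that decision, with a regular expression plus a few keyword checks standing in for isSimpleCount(). It is far cruder than a real parser (for example, it does not remove ORDER BY), and NaiveCountSql is a hypothetical name.

```java
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class NaiveCountSql {
    // select <items> from <rest>, case-insensitive, spanning newlines
    private static final Pattern SELECT_FROM =
            Pattern.compile("(?is)^\\s*select\\s+(.+?)\\s+from\\s+(.+)$");

    static String toCountSql(String sql, String name) {
        String lower = sql.toLowerCase(Locale.ROOT);
        Matcher m = SELECT_FROM.matcher(sql);
        boolean simple = m.matches()
                && !lower.contains("distinct")
                && !lower.contains("group by")
                && !lower.contains(" union ")
                && !m.group(1).contains("(");   // no functions or aggregates in the select list
        if (simple) {
            // Like setSelectItems(COUNT_ITEM): keep FROM ... and replace the select list
            return "SELECT count(" + name + ") FROM " + m.group(2);
        }
        // Like the subquery branch or getSimpleCountSql(): count the rows of the original query
        return "select count(" + name + ") from (" + sql + ") tmp_count";
    }

    public static void main(String[] args) {
        System.out.println(toCountSql("select id, name from user", "0"));
        // SELECT count(0) FROM user
        System.out.println(toCountSql("select distinct name from user", "0"));
        // select count(0) from (select distinct name from user) tmp_count
    }
}
```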
The other form of count SQL, CountSqlParser.getSimpleCountSql(), uses the original SQL as a subquery and wraps it as select count(0) from (subQuery) tmp_count:
public String getSimpleCountSql(final String sql) {
return getSimpleCountSql(sql, "0");
}
public String getSimpleCountSql(final String sql, String name) {
StringBuilder stringBuilder = new StringBuilder(sql.length() + 40);
stringBuilder.append("select count(");
stringBuilder.append(name);
stringBuilder.append(") from (");
stringBuilder.append(sql);
stringBuilder.append(") tmp_count");
return stringBuilder.toString();
}
Pagination
ExecutorUtil.pageQuery()
public static <E> List<E> pageQuery(Dialect dialect, Executor executor, MappedStatement ms, Object parameter,
RowBounds rowBounds, ResultHandler resultHandler,
BoundSql boundSql, CacheKey cacheKey) throws SQLException {
// decide whether a paged query is needed
if (dialect.beforePage(ms, parameter, rowBounds)) {
// cache key for the paged query
CacheKey pageKey = cacheKey;
// process the parameter object
parameter = dialect.processParameterObject(ms, parameter, boundSql, pageKey);
// ask the dialect for the paged SQL; this is the key step, where LIMIT is added
// pageSql = select id, name from user LIMIT ?, ?
String pageSql = dialect.getPageSql(ms, boundSql, parameter, rowBounds, pageKey);
BoundSql pageBoundSql = new BoundSql(ms.getConfiguration(), pageSql, boundSql.getParameterMappings(), parameter);
Map<String, Object> additionalParameters = getAdditionalParameter(boundSql);
// copy the dynamic parameters
for (String key : additionalParameters.keySet()) {
pageBoundSql.setAdditionalParameter(key, additionalParameters.get(key));
}
// BoundSql interception hook
if (dialect instanceof BoundSqlInterceptor.Chain) {
pageBoundSql = ((BoundSqlInterceptor.Chain) dialect).doBoundSql(BoundSqlInterceptor.Type.PAGE_SQL, pageBoundSql, pageKey);
}
// run the paged query
return executor.query(ms, parameter, RowBounds.DEFAULT, resultHandler, pageKey, pageBoundSql);
} else {
// when not paginating, skip in-memory pagination as well
return executor.query(ms, parameter, RowBounds.DEFAULT, resultHandler, cacheKey, boundSql);
}
}
AbstractHelperDialect.beforePage
public boolean beforePage(MappedStatement ms, Object parameterObject, RowBounds rowBounds) {
Page page = getLocalPage();
if (page.isOrderByOnly() || page.getPageSize() > 0) {
return true;
}
return false;
}
AbstractHelperDialect.getPageSql
public String getPageSql(MappedStatement ms, BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey pageKey) {
// sql = select id, name from user
String sql = boundSql.getSql();
Page page = getLocalPage();
// support ORDER BY
String orderBy = page.getOrderBy();
if (StringUtil.isNotEmpty(orderBy)) {
pageKey.update(orderBy);
sql = OrderByParser.converToOrderBySql(sql, orderBy);
}
if (page.isOrderByOnly()) {
return sql;
}
return getPageSql(sql, page, pageKey);
}
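MySqlDialect.getPageSql (MySQL's implementation of the abstract getPageSql(String, Page, CacheKey) declared in AbstractHelperDialect)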
// sql = select id, name from user
public String getPageSql(String sql, Page page, CacheKey pageKey) {
StringBuilder sqlBuilder = new StringBuilder(sql.length() + 14);
sqlBuilder.append(sql);
// append LIMIT
if (page.getStartRow() == 0) {
sqlBuilder.append("\n LIMIT ? ");
} else {
sqlBuilder.append("\n LIMIT ?, ? ");
}
// sql = select id, name from user LIMIT ?, ?
return sqlBuilder.toString();
}
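In MySQL, LIMIT offset, row_count skips offset rows, and LIMIT row_count on its own is the same as LIMIT 0, row_count, which is why the first page needs only one placeholder. The sketch pairs each page SQL with the values its placeholders would receive. The binding order shown (startRow, then pageSize) follows MySQL's LIMIT syntax and is my assumption, since this article does not show how PageHelper binds those values; MySqlLimit is a hypothetical name.

```java
import java.util.List;

public class MySqlLimit {
    // Same branching as getPageSql above: one placeholder on the first page, two otherwise
    static String pageSql(String sql, long startRow) {
        return startRow == 0 ? sql + "\n LIMIT ? " : sql + "\n LIMIT ?, ? ";
    }

    // Values for the placeholders, in order (assumed binding, see the lead-in)
    static List<Long> limitValues(long startRow, int pageSize) {
        return startRow == 0 ? List.of((long) pageSize) : List.of(startRow, (long) pageSize);
    }

    public static void main(String[] args) {
        String sql = "select id, name from user";
        int pageSize = 3;
        for (int pageNum = 1; pageNum <= 2; pageNum++) {
            long startRow = (long) (pageNum - 1) * pageSize;
            System.out.println(pageSql(sql, startRow).replace('\n', ' ')
                    + " -> " + limitValues(startRow, pageSize));
        }
    }
}
```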
AbstractHelperDialect.afterPage
public Object afterPage(List pageList, Object parameterObject, RowBounds rowBounds) {
Page page = getLocalPage();
// not a PageHelper pagination: return the results as they are
if (page == null) {
return pageList;
}
// add the results to the Page held in the ThreadLocal
page.addAll(pageList);
if (!page.isCount()) {
page.setTotal(-1);
} else if ((page.getPageSizeZero() != null && page.getPageSizeZero()) && page.getPageSize() == 0) {
page.setTotal(pageList.size());
} else if (page.isOrderByOnly()) {
page.setTotal(pageList.size());
}
return page;
}
PageHelper.afterAll
public void afterAll() {
// this method runs even when nothing was paginated, so check for null
AbstractHelperDialect delegate = autoDialect.getDelegate();
if (delegate != null) {
delegate.afterAll();
autoDialect.clearDelegate();
}
// clear the pagination info from the ThreadLocal
clearPage();
}
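Creating the PageInfo object from the Page
As the figure above shows, the actual type of users is not a plain List but Page, the very object stored in the ThreadLocal. PageHelper's Page extends ArrayList and adds a set of pagination-related fields and methods: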
public class Page<E> extends ArrayList<E> implements Closeable {
private static final long serialVersionUID = 1L;
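/**
* Page number, starting from 1
*/
private int pageNum;
/**
* Page size
*/
private int pageSize;
/**
* Start row (0-based offset of the first row)
*/
private long startRow;
/**
* End row
*/
private long endRow;
/**
* Total number of rows
*/
private long total;
/**
* Total number of pages
*/
private int pages;
/**
* Whether to run the count query
*/
private boolean count = true;
/**
* Reasonable pagination (out-of-range page numbers are clamped to the first or last page)
*/
private Boolean reasonable;
/**
* When true and pageSize is 0 (or RowBounds limit = 0), no pagination is done and all rows are returned
*/
private Boolean pageSizeZero;
/**
* Column used in the count query
*/
private String countColumn;
/**
* ORDER BY clause
*/
private String orderBy;
/**
* Only add ORDER BY, without pagination
*/
private boolean orderByOnly;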
...
}
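Because Page extends ArrayList, the mapper can keep its List<User> return type while the result still carries pagination metadata, and downstream code recovers it with an instanceof check, as PageSerializable and PageInfo do. A minimal sketch of the pattern; MiniPage and totalOf() are hypothetical and model only two fields.

```java
import java.util.ArrayList;
import java.util.List;

public class MiniPageDemo {
    // A list that also remembers pagination metadata, like Page<E> extends ArrayList<E>
    static class MiniPage<E> extends ArrayList<E> {
        private static final long serialVersionUID = 1L;
        final int pageNum;
        long total;

        MiniPage(int pageNum) {
            this.pageNum = pageNum;
        }
    }

    // Like PageSerializable: use the real total for a page, otherwise the list size
    static long totalOf(List<?> list) {
        if (list instanceof MiniPage) {
            return ((MiniPage<?>) list).total;
        }
        return list.size();
    }

    public static void main(String[] args) {
        MiniPage<String> page = new MiniPage<>(2);
        page.addAll(List.of("d", "e", "f")); // like page.addAll(pageList) in afterPage
        page.total = 26;                     // set from the count query
        List<String> users = page;           // callers only see a List
        System.out.println(users.size() + " " + totalOf(users)); // 3 26
        System.out.println(totalOf(List.of("x", "y")));          // 2
    }
}
```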
new PageInfo()
PageSerializable
public PageSerializable(List<T> list) {
// the page results
this.list = list;
// total number of rows
if(list instanceof Page){
this.total = ((Page)list).getTotal();
} else {
this.total = list.size();
}
}
PageInfo
public PageInfo(List<T> list, int navigatePages) {
// the PageSerializable(list) constructor shown above
super(list);
if (list instanceof Page) {
Page page = (Page) list;
// current page
this.pageNum = page.getPageNum();
// page size
this.pageSize = page.getPageSize();
// total pages
this.pages = page.getPages();
// number of rows on the current page
this.size = page.size();
// start and end row numbers
// results come after startRow (a 0-based offset), so the 1-based start row is startRow + 1
if (this.size == 0) {
this.startRow = 0;
this.endRow = 0;
} else {
this.startRow = page.getStartRow() + 1;
// compute the actual endRow (the last page may be partial)
this.endRow = this.startRow - 1 + this.size;
}
} else if (list instanceof Collection) {
this.pageNum = 1;
this.pageSize = list.size();
this.pages = this.pageSize > 0 ? 1 : 0;
this.size = list.size();
this.startRow = 0;
this.endRow = list.size() > 0 ? list.size() - 1 : 0;
}
if (list instanceof Collection) {
calcByNavigatePages(navigatePages);
}
}
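The startRow/endRow conversion above is easiest to follow with numbers. A sketch of just that calculation (PageInfoRows and rows() are hypothetical names): Page.startRow is a 0-based offset, so PageInfo adds 1, and endRow is derived from the number of rows actually returned, which is what makes a short last page come out right. With 26 rows and pageSize 3, page 2 has offset 3 and 3 rows, and page 9 has offset 24 and 2 rows.

```java
import java.util.Arrays;

public class PageInfoRows {
    // {startRow, endRow} as computed in the PageInfo constructor from Page.startRow and size
    static long[] rows(long pageStartRow, int size) {
        if (size == 0) {
            return new long[]{0, 0};
        }
        long startRow = pageStartRow + 1;  // 0-based offset -> 1-based row number
        long endRow = startRow - 1 + size; // based on the rows actually returned
        return new long[]{startRow, endRow};
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(rows(3, 3)));  // [4, 6]
        System.out.println(Arrays.toString(rows(24, 2))); // [25, 26]
    }
}
```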