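MyBatis pagination for composite (join) queries
About MyBatis pagination
When developing with MyBatis, pagination is usually handled by a MyBatis pagination interceptor. The typical logic is this: after intercepting a query method, the interceptor first rewrites the query to compute the total count using the same conditions and parameters, and then appends the pagination syntax of the target database. This works well for single-table and one-to-one queries, but for one-to-many or many-to-many queries it falls short and usually produces wrong results. This article addresses that problem; I hope it is useful.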
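To make the default approach concrete, here is a simplified sketch of the two SQL rewrites such an interceptor performs. It is plain string manipulation that assumes MySQL-style `LIMIT offset,size` syntax; the class and method names are hypothetical, and the real mybatis-plus interceptor is more careful (for example, it picks the pagination clause for each database).

```java
/**
 * Hypothetical sketch of a generic pagination interceptor's two rewrites:
 * a COUNT query with the same conditions, and a MySQL-style LIMIT clause.
 */
final class DefaultPagingSketch {

    private DefaultPagingSketch() {
    }

    /** Wraps the original query so the total is computed with the same conditions. */
    static String countSql(String originalSql) {
        return "SELECT COUNT(1) FROM (" + originalSql + ") tmp_count";
    }

    /** Appends a MySQL-style "LIMIT offset,size" clause; current is 1-based. */
    static String pageSql(String originalSql, long current, long size) {
        if (current < 1 || size < 1) {
            throw new IllegalArgumentException("current and size must be >= 1");
        }
        long offset = (current - 1) * size;
        return originalSql + " LIMIT " + offset + "," + size;
    }

    /** Number of pages needed for the given total (ceiling division). */
    static long pages(long total, long size) {
        return (total + size - 1) / size;
    }
}
```

For example, `pageSql("SELECT * FROM user", 2, 10)` yields `SELECT * FROM user LIMIT 10,10`. Both rewrites operate on rows of the result set, which is exactly what goes wrong once the query joins a one-to-many table.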
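Table design and entity design
User table
Column | Description | Primary key |
---|---|---|
id | Primary key | Yes |
username | Login name | – |
password | Password | – |
Role table
Column | Description | Primary key |
---|---|---|
id | Primary key | Yes |
role_name | Role name | – |
User-role mapping table
Column | Description | Primary key |
---|---|---|
id | Primary key | Yes |
role_id | Role (role.id) | – |
user_id | User (user.id) | – |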
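The corresponding entity classes are User and Role.
The requirement is to return a list of users together with the roles each user has. For this, the following return object, UserVo, is defined: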
```java
import lombok.Data;

import java.util.List;

@Data
public class UserVo {
    private String userId;
    private User user;
    // All roles of this user (one-to-many)
    private List<Role> roles;
}
```
The query method is:
```java
public IPage<UserVo> getUsers() {
    ......
}
```
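The problem
With the design above, the default pagination produces pages in which some users are missing part of their role data, or in which fewer users appear than requested. The cause is that the paging clause counts rows of the joined result, not users: a user with three roles occupies three rows, so a page can end in the middle of that user's roles, and the total counts joined rows rather than users. To fix this, I first needed to understand how the existing pagination works. Reading the mybatis-plus source shows that pagination is done by PaginationInterceptor, which intercepts the query statement and appends the pagination clause for the detected database.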
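The failure can be reproduced without a database. The sketch below simulates the rows of `user LEFT JOIN user_role LEFT JOIN role` as an in-memory list (the sample data and all names are hypothetical) and applies row-based paging the way a `LIMIT` clause would.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical simulation of row-based paging over a one-to-many join. */
final class JoinPagingProblem {

    /** One joined row: a user id and one of that user's role names. */
    static final class Row {
        final String userId;
        final String roleName;

        Row(String userId, String roleName) {
            this.userId = userId;
            this.roleName = roleName;
        }
    }

    private JoinPagingProblem() {
    }

    /** Sample join result: user 1 has three roles, users 2 and 3 have one each. */
    static List<Row> joinedRows() {
        List<Row> rows = new ArrayList<>();
        rows.add(new Row("1", "admin"));
        rows.add(new Row("1", "editor"));
        rows.add(new Row("1", "viewer"));
        rows.add(new Row("2", "viewer"));
        rows.add(new Row("3", "editor"));
        return rows;
    }

    /** What COUNT over the joined query returns: joined rows, not users. */
    static long naiveTotal(List<Row> rows) {
        return rows.size();
    }

    /** Applies LIMIT-style paging to the joined rows, then groups roles by user. */
    static Map<String, List<String>> naivePage(List<Row> rows, int current, int size) {
        int from = Math.min((current - 1) * size, rows.size());
        int to = Math.min(from + size, rows.size());
        Map<String, List<String>> page = new LinkedHashMap<>();
        for (Row row : rows.subList(from, to)) {
            page.computeIfAbsent(row.userId, k -> new ArrayList<>()).add(row.roleName);
        }
        return page;
    }
}
```

With a page size of 2, page 1 contains only user 1, with roles `admin` and `editor`; `viewer` spills onto page 2 together with user 2; and the total is 5 although there are only 3 users.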
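So how should the SQL of a one-to-many query like this one be changed? My approach is to paginate the main table first (the User table in this example), then take the primary keys returned for that page and rewrite the query SQL by adding an IN condition on them. This requires two changes: first, change how the SQL that computes the total count is generated; second, rewrite the query statement so that it pages correctly.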
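The two-step idea can be checked the same way. In this hypothetical in-memory sketch, step 1 pages only the main table's ids (what `SELECT id FROM user ... LIMIT` would return) and step 2 keeps the joined rows whose user id is in that page, which is what the added `IN` condition does in SQL. The names and data are illustrative; the total is simply the number of main-table ids.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical simulation of "page the main table, then filter the join by id". */
final class TwoStepPagingSketch {

    /** One joined row; roleName is null for a user without roles (LEFT JOIN). */
    static final class Row {
        final String userId;
        final String roleName;

        Row(String userId, String roleName) {
            this.userId = userId;
            this.roleName = roleName;
        }
    }

    private TwoStepPagingSketch() {
    }

    /** Sample join result: user 1 has three roles, users 2 and 3 have one each. */
    static List<Row> sampleRows() {
        List<Row> rows = new ArrayList<>();
        rows.add(new Row("1", "admin"));
        rows.add(new Row("1", "editor"));
        rows.add(new Row("1", "viewer"));
        rows.add(new Row("2", "viewer"));
        rows.add(new Row("3", "editor"));
        return rows;
    }

    /** Step 1: page the main-table ids only. */
    static List<String> pageIds(List<String> mainIds, int current, int size) {
        int from = Math.min((current - 1) * size, mainIds.size());
        int to = Math.min(from + size, mainIds.size());
        return new ArrayList<>(mainIds.subList(from, to));
    }

    /** Step 2: keep joined rows whose user id is in the page (the IN filter) and group them. */
    static Map<String, List<String>> fetchPage(List<Row> joinedRows, List<String> pageIds) {
        Set<String> wanted = new HashSet<>(pageIds);
        Map<String, List<String>> page = new LinkedHashMap<>();
        for (String id : pageIds) {
            page.put(id, new ArrayList<>()); // every user of the page appears, even without roles
        }
        for (Row row : joinedRows) {
            if (wanted.contains(row.userId) && row.roleName != null) {
                page.get(row.userId).add(row.roleName);
            }
        }
        return page;
    }
}
```

With page size 2, page 1 now holds users 1 and 2 with all of their roles (`admin`, `editor`, `viewer` for user 1), page 2 holds user 3, and the total is 3.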
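Custom total-count calculation
The COUNT parser below is based on mybatis-plus's JsqlParserCountOptimize. The change is the LEFT JOIN check: when the first join of the query is a LEFT JOIN, the joins are removed, so the total is the number of matching rows in the main table rather than the number of joined rows.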
```java
package com.cecdat.common.aop;

import com.baomidou.mybatisplus.core.parser.ISqlParser;
import com.baomidou.mybatisplus.core.parser.SqlInfo;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.select.*;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import org.apache.ibatis.reflection.MetaObject;

import java.util.ArrayList;
import java.util.List;

/**
 * COUNT SQL parser: when the first join is a LEFT JOIN, the joins are removed
 * so that only the rows of the main table are counted.
 */
public class CountSqlInterceptor implements ISqlParser {

    private static final List<SelectItem> COUNT_SELECT_ITEM = countSelectItem();

    /**
     * Logger
     */
    private final Log logger = LogFactory.getLog(CountSqlInterceptor.class);

    /**
     * Builds the JSqlParser select item for COUNT(1)
     */
    private static List<SelectItem> countSelectItem() {
        Function function = new Function();
        function.setName("COUNT");
        List<Expression> expressions = new ArrayList<>();
        expressions.add(new LongValue(1));
        ExpressionList expressionList = new ExpressionList();
        expressionList.setExpressions(expressions);
        function.setParameters(expressionList);
        List<SelectItem> selectItems = new ArrayList<>();
        selectItems.add(new SelectExpressionItem(function));
        return selectItems;
    }

    @Override
    public SqlInfo parser(MetaObject metaObject, String sql) {
        if (logger.isDebugEnabled()) {
            logger.debug("CountSqlInterceptor sql=" + sql);
        }
        SqlInfo sqlInfo = SqlInfo.newInstance();
        try {
            Select selectStatement = (Select) CCJSqlParserUtil.parse(sql);
            PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody();
            Distinct distinct = plainSelect.getDistinct();
            List<Expression> groupBy = plainSelect.getGroupByColumnReferences();
            List<OrderByElement> orderBy = plainSelect.getOrderByElements();
            List<Join> joins = plainSelect.getJoins();
            // Drop ORDER BY from the count query, unless the query has a GROUP BY
            if (CollectionUtils.isEmpty(groupBy) && CollectionUtils.isNotEmpty(orderBy)) {
                plainSelect.setOrderByElements(null);
                sqlInfo.setOrderBy(false);
            }
            // #95 Github, selectItems contains #{} ${}, which will be translated to ?, and it may be in a function: power(#{myInt},2)
            for (SelectItem item : plainSelect.getSelectItems()) {
                if (item.toString().contains(StringPool.QUESTION_MARK)) {
                    return sqlInfo.setSql(SqlParserUtils.getOriginalCountSql(selectStatement.toString()));
                }
            }
            // Do not optimize queries that contain DISTINCT or GROUP BY
            if (distinct != null || CollectionUtils.isNotEmpty(groupBy)) {
                return sqlInfo.setSql(SqlParserUtils.getOriginalCountSql(selectStatement.toString()));
            }
            /*
             * Key step: if the first join is a LEFT JOIN, the FROM table is the main table.
             * Remove the joins and count the rows of the main table only.
             */
            if (CollectionUtils.isNotEmpty(joins) && joins.get(0).isLeft()) {
                plainSelect.setJoins(null);
            }
            // Replace the select list with COUNT(1)
            plainSelect.setSelectItems(COUNT_SELECT_ITEM);
            return sqlInfo.setSql(selectStatement.toString());
        } catch (Throwable e) {
            // The SQL cannot be optimized: fall back to wrapping the original SQL
            return sqlInfo.setSql(SqlParserUtils.getOriginalCountSql(sql));
        }
    }
}
```
Custom pagination interceptor
```java
package com.cecdat.common.aop;

import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.core.MybatisDefaultParameterHandler;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.core.parser.ISqlParser;
import com.baomidou.mybatisplus.core.parser.SqlInfo;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.extension.plugins.PaginationInterceptor;
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectFactory;
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectModel;
import com.baomidou.mybatisplus.extension.toolkit.JdbcUtils;
import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils;
import lombok.Setter;
import lombok.experimental.Accessors;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.NullValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.select.*;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.Configuration;

import java.lang.reflect.InvocationTargetException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

@Setter
@Accessors(chain = true)
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class MyPaginationInterceptor extends PaginationInterceptor {
    /**
     * COUNT SQL parser
     */
    private ISqlParser countSqlParser;
    /**
     * When the requested page exceeds the total, go back to the first page
     */
    private boolean overflow = false;
    /**
     * Maximum page size (500); a value below 0, such as -1, means no limit
     */
    private long limit = 500L;
    /**
     * Dialect type
     */
    private String dialectType;
    /**
     * Dialect implementation class
     */
    private String dialectClazz;

    @Override
    @SuppressWarnings("unchecked")
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = PluginUtils.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        // SQL parsing
        this.sqlParser(metaObject);
        // Only handle SELECT statements; stored procedures are skipped
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        if (SqlCommandType.SELECT != mappedStatement.getSqlCommandType()
                || StatementType.CALLABLE == mappedStatement.getStatementType()) {
            return invocation.proceed();
        }
        // Look for an IPage among the mapper method parameters
        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
        Object paramObj = boundSql.getParameterObject();
        IPage<?> page = null;
        if (paramObj instanceof IPage) {
            page = (IPage<?>) paramObj;
        } else if (paramObj instanceof Map) {
            for (Object arg : ((Map<?, ?>) paramObj).values()) {
                if (arg instanceof IPage) {
                    page = (IPage<?>) arg;
                    break;
                }
            }
        }
        // No paging needed: no page object, or size < 0 means return the whole result set
        if (null == page || page.getSize() < 0) {
            return invocation.proceed();
        }
        // Enforce the maximum page size
        if (limit > 0 && limit <= page.getSize()) {
            page.setSize(limit);
        }
        String originalSql = boundSql.getSql();
        Connection connection = (Connection) invocation.getArgs()[0];
        DbType dbType = StringUtils.isNotEmpty(dialectType) ? DbType.getDbType(dialectType)
                : JdbcUtils.getDbType(connection.getMetaData().getURL());
        boolean orderBy = true;
        if (page.isSearchCount()) {
            SqlInfo sqlInfo = SqlParserUtils.getOptimizeCountSql(page.optimizeCountSql(), countSqlParser, originalSql);
            orderBy = sqlInfo.isOrderBy();
            this.queryTotal(overflow, sqlInfo.getSql(), mappedStatement, boundSql, page, connection);
            if (page.getTotal() <= 0) {
                return null;
            }
        }
        String buildSql = concatOrderBy(originalSql, page, orderBy);
        Configuration configuration = mappedStatement.getConfiguration();
        List<ParameterMapping> mappings = new ArrayList<>(boundSql.getParameterMappings());
        Map<String, Object> additionalParameters = (Map<String, Object>) metaObject.getValue("delegate.boundSql.additionalParameters");
        Select selectStatement = (Select) CCJSqlParserUtil.parse(originalSql);
        if (selectStatement.getSelectBody() instanceof PlainSelect) {
            PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody();
            List<Join> joins = plainSelect.getJoins();
            // A LEFT JOIN as the first join marks a one-to-many or many-to-many query
            if (joins != null && !joins.isEmpty() && joins.get(0).isLeft()) {
                // Step 1: page the main table only and fetch the ids of the current page
                String idPageSql = this.getCountIdSql(plainSelect);
                DialectModel idModel = DialectFactory.buildPaginationSql(page, idPageSql, dbType, dialectClazz);
                idModel.consumers(mappings, configuration, additionalParameters);
                idPageSql = idModel.getDialectSql();
                BoundSql idBoundSql = new BoundSql(configuration, idPageSql, mappings, boundSql.getParameterObject());
                additionalParameters.forEach(idBoundSql::setAdditionalParameter);
                List<String> ids = this.queryIds(idPageSql, mappedStatement, idBoundSql, connection);
                // Step 2: run the original join query restricted to those ids
                return specialPaging(metaObject, originalSql, ids, invocation, boundSql);
            }
        }
        DialectModel model = DialectFactory.buildPaginationSql(page, buildSql, dbType, dialectClazz);
        model.consumers(mappings, configuration, additionalParameters);
        metaObject.setValue("delegate.boundSql.sql", model.getDialectSql());
        metaObject.setValue("delegate.boundSql.parameterMappings", mappings);
        return invocation.proceed();
    }

    /**
     * Special paging for join queries: restricts the original SQL to the given main-table ids
     */
    private Object specialPaging(MetaObject metaObject, String originalSql, List<String> ids, Invocation invocation, BoundSql boundSql)
            throws JSQLParserException, InvocationTargetException, IllegalAccessException {
        Select selectStatement = (Select) CCJSqlParserUtil.parse(originalSql);
        PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody();
        FromItem fromItem = plainSelect.getFromItem();
        String alias = fromItem.getAlias() == null ? null : fromItem.getAlias().getName();
        // Build "<main table or alias>.id IN (...)"
        List<Expression> expressions = new ArrayList<>();
        ids.forEach(id -> expressions.add(new StringValue(id)));
        if (expressions.isEmpty()) {
            // "IN ()" is invalid SQL; "IN (NULL)" matches no rows
            expressions.add(new NullValue());
        }
        ExpressionList list = new ExpressionList();
        list.setExpressions(expressions);
        Column idColumn = new Column(alias == null ? fromItem + ".id" : alias + ".id");
        InExpression inExpression = new InExpression(idColumn, list);
        // Keep the original WHERE, parenthesized so that a top-level OR is not split by the new AND
        Expression where = plainSelect.getWhere();
        plainSelect.setWhere(where == null ? inExpression : new AndExpression(new Parenthesis(where), inExpression));
        metaObject.setValue("delegate.boundSql.sql", plainSelect.toString());
        metaObject.setValue("delegate.boundSql.parameterMappings", new ArrayList<>(boundSql.getParameterMappings()));
        return invocation.proceed();
    }

    /**
     * Turns the join query into a main-table-only query that selects the primary key.
     * Assumes the main table has an "id" column and, when the query has no ORDER BY,
     * a "create_date" column used as the default sort. Mutates the given PlainSelect.
     */
    public String getCountIdSql(PlainSelect plainSelect) {
        plainSelect.setJoins(null);
        List<SelectItem> items = new ArrayList<>();
        items.add(new SelectExpressionItem(new Column("id")));
        plainSelect.setSelectItems(items);
        List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
        if (orderByElements == null || orderByElements.isEmpty()) {
            OrderByElement orderByElement = new OrderByElement();
            orderByElement.setAsc(false);
            orderByElement.setExpression(new Column("create_date"));
            List<OrderByElement> defaultOrder = new ArrayList<>();
            defaultOrder.add(orderByElement);
            plainSelect.setOrderByElements(defaultOrder);
        }
        return plainSelect.toString();
    }

    /**
     * Runs the paged id query and returns the main-table ids of the current page
     * @param sql             paged id SQL
     * @param mappedStatement MappedStatement
     * @param boundSql        BoundSql carrying the parameters of the id query
     * @param connection      Connection
     */
    protected List<String> queryIds(String sql, MappedStatement mappedStatement, BoundSql boundSql, Connection connection) {
        List<String> ids = new ArrayList<>();
        try (PreparedStatement statement = connection.prepareStatement(sql)) {
            DefaultParameterHandler parameterHandler = new MybatisDefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), boundSql);
            parameterHandler.setParameters(statement);
            try (ResultSet resultSet = statement.executeQuery()) {
                while (resultSet.next()) {
                    ids.add(resultSet.getString(1));
                }
            }
            return ids;
        } catch (Exception e) {
            throw ExceptionUtils.mpe("Error: Method queryIds execution error of sql : \n %s \n", e, sql);
        }
    }
}
```
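The condition that the join branch appends to the WHERE clause is easy to get subtly wrong, so here it is rebuilt as plain strings. This is an illustration only, not JSqlParser code, and the class and method names are hypothetical. Three details matter: wrap the existing WHERE in parentheses so a top-level `OR` is not split by the new `AND`; skip the `AND` entirely when the query has no WHERE; and replace an empty id list with `IN (NULL)`, since `IN ()` is not valid SQL.

```java
import java.util.List;
import java.util.stream.Collectors;

/** Hypothetical string-based sketch of the "main table id IN (...)" filter. */
final class IdFilterSketch {

    private IdFilterSketch() {
    }

    /** Builds "<aliasOrTable>.id IN ('1', '2')", or "IN (NULL)" when there are no ids. */
    static String idInCondition(String aliasOrTable, List<String> ids) {
        String values = ids.isEmpty()
                ? "NULL" // IN () is invalid SQL; IN (NULL) matches no rows
                : ids.stream()
                        .map(id -> "'" + id.replace("'", "''") + "'") // quote as string literals
                        .collect(Collectors.joining(", "));
        return aliasOrTable + ".id IN (" + values + ")";
    }

    /** ANDs the id filter onto the original WHERE, keeping the original as one group. */
    static String combineWhere(String originalWhere, String idCondition) {
        if (originalWhere == null || originalWhere.trim().isEmpty()) {
            return idCondition;
        }
        return "(" + originalWhere + ") AND " + idCondition;
    }
}
```

Without the parentheses, `u.status = 1 OR u.status = 2 AND u.id IN ('1')` is evaluated as `u.status = 1 OR (u.status = 2 AND u.id IN ('1'))`, which lets rows outside the current page through.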
Finally, register the custom interceptor in the startup (configuration) class in place of the default pagination interceptor:
```java
/**
 * Pagination configuration: replaces the default PaginationInterceptor
 */
@Bean
public PaginationInterceptor paginationInterceptor() {
    CountSqlInterceptor countSqlInterceptor = new CountSqlInterceptor();
    MyPaginationInterceptor paginationInterceptor = new MyPaginationInterceptor();
    paginationInterceptor.setCountSqlParser(countSqlInterceptor);
    return paginationInterceptor;
}
```
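Final notes
This approach assumes that every table uses a column named id as its primary key and that the query is written with LEFT JOIN, with the main table in the FROM clause. When the query has no ORDER BY, the main table is also expected to have a create_date column, which the id query uses as its default sort. WHERE conditions (and any ORDER BY) should refer only to main-table columns, because the joins are removed from both the count query and the id query.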