参考:url
环境:spring boot
@Configuration
public class MyBatisConfiguration {
private final Log LOG = LogFactory.getLog(MyBatisConfiguration.class);
@Bean
public PageInterceptor pageInterceptor() {
LOG.info("注册MyBatis分页插件pageInterceptor");
PageInterceptor pageInterceptor = new PageInterceptor();
Properties properties = new Properties();
properties.setProperty("databaseType", "mysql");
pageInterceptor.setProperties(properties);
return pageInterceptor;
}
}
以上代码是将分页插件注册到spring中
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Properties;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import co.dc.saas.core.po.Page;
/**
*
* 分页拦截器,用于拦截需要进行分页查询的操作,然后对其进行分页处理。
* 利用拦截器实现Mybatis分页的原理:
* 要利用JDBC对数据库进行操作就必须要有一个对应的Statement对象,Mybatis在执行Sql语句前就会产生一个包含Sql语句的Statement对象,而且对应的Sql语句
* 是在Statement之前产生的,所以我们就可以在它生成Statement之前对用来生成Statement的Sql语句下手。在Mybatis中Statement语句是通过RoutingStatementHandler对象的
* prepare方法生成的。所以利用拦截器实现Mybatis分页的一个思路就是拦截StatementHandler接口的prepare方法,然后在拦截器方法中把Sql语句改成对应的分页查询Sql语句,之后再调用
* StatementHandler对象的prepare方法,即调用invocation.proceed()。
* 对于分页而言,在拦截器里面我们还需要做的一个操作就是统计满足当前条件的记录一共有多少,这是通过获取到了原始的Sql语句后,把它改为对应的统计语句再利用Mybatis封装好的参数和设
* 置参数的功能把Sql语句中的参数进行替换,之后再执行查询记录数的Sql语句进行总记录数的统计。
*
*/
@Intercepts( {
@Signature(method = "prepare", type = StatementHandler.class, args = {Connection.class,Integer.class}) })
public class PageInterceptor implements Interceptor {
private String databaseType;//数据库类型,不同的数据库有不同的分页方法
/**
* 拦截后要执行的方法
*/
public Object intercept(Invocation invocation) throws Throwable {
//对于StatementHandler其实只有两个实现类,一个是RoutingStatementHandler,另一个是抽象类BaseStatementHandler,
//BaseStatementHandler有三个子类,分别是SimpleStatementHandler,PreparedStatementHandler和CallableStatementHandler,
//SimpleStatementHandler是用于处理Statement的,PreparedStatementHandler是处理PreparedStatement的,而CallableStatementHandler是
//处理CallableStatement的。Mybatis在进行Sql语句处理的时候都是建立的RoutingStatementHandler,而在RoutingStatementHandler里面拥有一个
//StatementHandler类型的delegate属性,RoutingStatementHandler会依据Statement的不同建立对应的BaseStatementHandler,即SimpleStatementHandler、
//PreparedStatementHandler或CallableStatementHandler,在RoutingStatementHandler里面所有StatementHandler接口方法的实现都是调用的delegate对应的方法。
//我们在PageInterceptor类上已经用@Signature标记了该Interceptor只拦截StatementHandler接口的prepare方法,又因为Mybatis只有在建立RoutingStatementHandler的时候
//是通过Interceptor的plugin方法进行包裹的,所以我们这里拦截到的目标对象肯定是RoutingStatementHandler对象。
RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
//通过反射获取到当前RoutingStatementHandler对象的delegate属性
StatementHandler delegate = (StatementHandler)ReflectUtil.getFieldValue(handler, "delegate");
//获取到当前StatementHandler的 boundSql,这里不管是调用handler.getBoundSql()还是直接调用delegate.getBoundSql()结果是一样的,因为之前已经说过了
//RoutingStatementHandler实现的所有StatementHandler接口方法里面都是调用的delegate对应的方法。
BoundSql boundSql = delegate.getBoundSql();
//拿到当前绑定Sql的参数对象,就是我们在调用对应的Mapper映射语句时所传入的参数对象
Object obj = boundSql.getParameterObject();
//这里我们简单的通过传入的是Page对象就认定它是需要进行分页操作的。
if (obj instanceof Page<?>) {
Page<?> page = (Page<?>) obj;
//通过反射获取delegate父类BaseStatementHandler的mappedStatement属性
MappedStatement mappedStatement = (MappedStatement)ReflectUtil.getFieldValue(delegate, "mappedStatement");
//拦截到的prepare方法参数是一个Connection对象
Connection connection = (Connection)invocation.getArgs()[0];
//获取当前要执行的Sql语句,也就是我们直接在Mapper映射语句中写的Sql语句
String sql = boundSql.getSql();
//给当前的page参数对象设置总记录数
setTotalRecord(page,
mappedStatement, connection);
//获取分页Sql语句
String pageSql = getPageSql(page, sql);
//利用反射设置当前BoundSql对应的sql属性为我们建立好的分页Sql语句
ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
}
return invocation.proceed();
}
/**
* 拦截器对应的封装原始对象的方法
*/
public Object plugin(Object target) {
return Plugin.wrap(target, this);
}
/**
* 设置注册拦截器时设定的属性
*/
public void setProperties(Properties properties) {
this.databaseType = properties.getProperty("databaseType");
}
/**
* 根据page对象获取对应的分页查询Sql语句,这里只做了两种数据库类型,Mysql和Oracle
* 其它的数据库都 没有进行分页
*
* @param page 分页对象
* @param sql 原sql语句
* @return
*/
private String getPageSql(Page<?> page, String sql) {
StringBuffer sqlBuffer = new StringBuffer(sql);
if ("mysql".equalsIgnoreCase(databaseType)) {
return getMysqlPageSql(page, sqlBuffer);
} else if ("oracle".equalsIgnoreCase(databaseType)) {
return getOraclePageSql(page, sqlBuffer);
}
return sqlBuffer.toString();
}
/**
* 获取Mysql数据库的分页查询语句
* @param page 分页对象
* @param sqlBuffer 包含原sql语句的StringBuffer对象
* @return Mysql数据库分页语句
*/
private String getMysqlPageSql(Page<?> page, StringBuffer sqlBuffer) {
//计算第一条记录的位置,Mysql中记录的位置是从0开始的。
int offset = (page.getPageNo() - 1) * page.getPageSize();
sqlBuffer.append(" limit ").append(offset).append(",").append(page.getPageSize());
return sqlBuffer.toString();
}
/**
* 获取Oracle数据库的分页查询语句
* @param page 分页对象
* @param sqlBuffer 包含原sql语句的StringBuffer对象
* @return Oracle数据库的分页查询语句
*/
private String getOraclePageSql(Page<?> page, StringBuffer sqlBuffer) {
//计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的
int offset = (page.getPageNo() - 1) * page.getPageSize() + 1;
sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getPageSize());
sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset);
//上面的Sql语句拼接之后大概是这个样子:
//select * from (select u.*, rownum r from (select * from t_user) u where rownum < 31) where r >= 16
return sqlBuffer.toString();
}
/**
* 给当前的参数对象page设置总记录数
*
* @param page Mapper映射语句对应的参数对象
* @param mappedStatement Mapper映射语句
* @param connection 当前的数据库连接
*/
private void setTotalRecord(Page<?> page,
MappedStatement mappedStatement, Connection connection) {
//获取对应的BoundSql,这个BoundSql其实跟我们利用StatementHandler获取到的BoundSql是同一个对象。
//delegate里面的boundSql也是通过mappedStatement.getBoundSql(paramObj)方法获取到的。
BoundSql boundSql = mappedStatement.getBoundSql(page);
//获取到我们自己写在Mapper映射语句中对应的Sql语句
String sql = boundSql.getSql();
//通过查询Sql语句获取到对应的计算总记录数的sql语句
String countSql = getCountSql(sql);
//通过BoundSql获取对应的参数映射
List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
//利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。
BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, page);
//通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象
ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, page, countBoundSql);
//通过connection建立一个countSql对应的PreparedStatement对象。
PreparedStatement pstmt = null;
ResultSet rs = null;
try {
pstmt = connection.prepareStatement(countSql);
//通过parameterHandler给PreparedStatement对象设置参数
parameterHandler.setParameters(pstmt);
//之后就是执行获取总记录数的Sql语句和获取结果了。
rs = pstmt.executeQuery();
if (rs.next()) {
int totalRecord = rs.getInt(1);
//给当前的参数page对象设置总记录数
page.setTotalRecord(totalRecord);
}
} catch (SQLException e) {
e.printStackTrace();
} finally {
try {
if (rs != null)
rs.close();
if (pstmt != null)
pstmt.close();
} catch (SQLException e) {
e.printStackTrace();
}
}
}
/**
* 根据原Sql语句获取对应的查询总记录数的Sql语句
* @param sql
* @return
*/
private static String getCountSql(String sql) {
// 将sql中的from统一转小写
if(sql.indexOf("FROM") != -1) {
sql = sql.replace("FROM", "from");
}
int index = sql.indexOf("from");
return "select count(*) " + sql.substring(index);
}
// public static void main(String[] args) {
// String sql ="select * \n from *****";
// System.out.println(sql);
// System.out.println(getCountSql(sql));
// }
/**
* 利用反射进行操作的一个工具类
*
*/
private static class ReflectUtil {
/**
* 利用反射获取指定对象的指定属性
* @param obj 目标对象
* @param fieldName 目标属性
* @return 目标属性的值
*/
public static Object getFieldValue(Object obj, String fieldName) {
Object result = null;
Field field = ReflectUtil.getField(obj, fieldName);
if (field != null) {
field.setAccessible(true);
try {
result = field.get(obj);
} catch (IllegalArgumentException e) {
e.printStackTrace();
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
return result;
}
/**
* 利用反射获取指定对象里面的指定属性
* @param obj 目标对象
* @param fieldName 目标属性
* @return 目标字段
*/
private static Field getField(Object obj, String fieldName) {
Field field = null;
for (Class<?> clazz=obj.getClass(); clazz != Object.class; clazz=clazz.getSuperclass()) {
try {
field = clazz.getDeclaredField(fieldName);
break;
} catch (NoSuchFieldException e) {
//这里不用做处理,子类没有该字段可能对应的父类有,都没有就返回null。
}
}
return field;
}
/**
* 利用反射设置指定对象的指定属性为指定的值
* @param obj 目标对象
* @param fieldName 目标属性
* @param fieldValue 目标值
*/
public static void setFieldValue(Object obj, String fieldName,
String fieldValue) {
Field field = ReflectUtil.getField(obj, fieldName);
if (field != null) {
try {
field.setAccessible(true);
field.set(obj, fieldValue);
} catch (IllegalArgumentException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IllegalAccessException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
}
}
以上代码是分页插件,直接拷贝就能用
public class Page<T> {
private int pageNo = 1;// 页码,默认是第一页
private int pageSize = 10;// 每页显示的记录数,默认是10
private int totalRecord;// 总记录数
private int totalPage;// 总页数
private int resultSize;// 当前页数量
private List<T> results;// 对应的当前页记录
private T param; // 实际请求参数
public int getPageNo() {
return pageNo;
}
public void setPageNo(int pageNo) {
this.pageNo = pageNo;
}
public int getPageSize() {
return pageSize;
}
public void setPageSize(int pageSize) {
this.pageSize = pageSize;
}
public int getTotalRecord() {
return totalRecord;
}
public void setTotalRecord(int totalRecord) {
this.totalRecord = totalRecord;
// 在设置总记录的时候计算出对应的总页数,在下面的三目运算中加法拥有更高的优先级,所以最后可以不加括号。
int totalPage = totalRecord % pageSize == 0 ? totalRecord / pageSize : totalRecord / pageSize + 1;
this.setTotalPage(totalPage);
}
public int getTotalPage() {
return totalPage;
}
public void setTotalPage(int totalPage) {
this.totalPage = totalPage;
}
public List<T> getResults() {
return results;
}
public void setResults(List<T> results) {
this.results = results;
}
public T getParam() {
return param;
}
public void setParam(T param) {
this.param = param;
}
public int getResultSize() {
return resultSize;
}
public void setResultSize(int resultSize) {
this.resultSize = resultSize;
}
@Override
public String toString() {
StringBuilder builder = new StringBuilder();
builder.append("Page [pageNo=").append(pageNo).append(", pageSize=").append(pageSize).append(", results=")
.append(results).append(", totalPage=").append(totalPage).append(", totalRecord=").append(totalRecord)
.append("]");
return builder.toString();
}
}
以上代码是分页插件中使用到的Page对象
下面展示如何访问
dao
public interface BaseDao<T> {
/**
* 获取单条数据
*
* @param id
* @return
*/
T get(String id);
/**
* 查询列表数据
* @param entity
* @return
*/
List<T> findList(T entity);
/**
* 分页查询
* @param page
* @return
*/
List<T> findByPage(Page<T> page);
/**
* 插入数据
*
* @param entity
* @return
*/
int insert(T entity);
/**
* 批量插入
*
* @return
*/
int batchInsert(@Param("vos") List<T> vos);
/**
* 根据多个ID查询
*
* @param ids
* @return
*/
List<T> queryByIds(@Param("ids") List<String> ids);
/**
* 更新数据
*
* @param entity
* @return
*/
int update(T entity);
/**
* 删除数据
*
* @param entity
* @return
*/
int delete(String id);
}
pojo
public abstract class PagePo extends BasePo {
/**
*
*/
private static final long serialVersionUID = -6782372722568765384L;
protected Integer pageNo;
protected Integer pageSize;
public Integer getPageNo() {
return pageNo;
}
public void setPageNo(Integer pageNo) throws PageException {
if (pageNo == null || pageNo <= 0) {
throw new PageException("页码为空或页码错误");
}
this.pageNo = pageNo;
}
public Integer getPageSize() {
return pageSize;
}
public void setPageSize(Integer pageSize) throws PageException {
if (pageSize == null || pageSize <= 0 || pageSize > 100) {
throw new PageException("每页显示数量为空或被限制");
}
this.pageSize = pageSize;
}
}
public abstract class BasePo implements Serializable {
/**
*
*/
private static final long serialVersionUID = 3477156094573872538L;
protected String id;
// 租户ID
protected String tenantId;
// 组织ID
protected String organizationId;
// 部门ID
protected String departmentId;
// 创建时间
protected Date createDate;
// 创建人
protected String createBy;
// 修改时间
protected Date updateDate;
// 修改人
protected String updateBy;
// 备注
protected String remarks;
/**
* 设置生成的ID类型。
*
* @return
*/
public abstract IdType setIdType();
/**
* 插入之前执行方法,需要手动调用
*/
public void preInsert() {
IdType idType = setIdType();
if (IdType.UUID.equals(idType)) {
setId(IdGen.uuid());
} else if (IdType.AUTO.equals(idType)) {
// 使用自增长不需要设置主键
}
this.updateDate = new Date();
this.createDate = this.updateDate;
}
/**
* 更新之前执行方法,需要手动调用
*/
public void preUpdate() {
this.updateDate = new Date();
}
/**
* 判断是否是新记录
*/
public boolean isNewRecord() {
return StringUtils.isEmpty(this.id);
}
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getTenantId() {
return tenantId;
}
public void setTenantId(String tenantId) {
this.tenantId = tenantId;
}
public String getOrganizationId() {
return organizationId;
}
public void setOrganizationId(String organizationId) {
this.organizationId = organizationId;
}
public String getDepartmentId() {
return departmentId;
}
public void setDepartmentId(String departmentId) {
this.departmentId = departmentId;
}
public Date getCreateDate() {
return createDate;
}
public void setCreateDate(Date createDate) {
this.createDate = createDate;
}
public String getCreateBy() {
return createBy;
}
public void setCreateBy(String createBy) {
this.createBy = createBy;
}
public Date getUpdateDate() {
return updateDate;
}
public void setUpdateDate(Date updateDate) {
this.updateDate = updateDate;
}
public String getUpdateBy() {
return updateBy;
}
public void setUpdateBy(String updateBy) {
this.updateBy = updateBy;
}
public String getRemarks() {
return remarks;
}
public void setRemarks(String remarks) {
this.remarks = remarks;
}
}
service
public abstract class CurdService<M extends BaseDao<T>, T extends PagePo> {
/**
* 持久层对象
*/
@Autowired
protected M mapper;
/**
* 获取单条数据
*
* @param id
* @return
*/
public T get(String id) {
return this.mapper.get(id);
}
/**
* 查询列表数据
*
* @param entity
* @return
*/
public List<T> findList(T entity) {
return this.mapper.findList(entity);
}
/**
* 分页查询
*
* @param param
* @return
*/
public Page<T> findByPage(T param) {
Page<T> page = new Page<>();
if (param.getPageNo() != null) {
page.setPageNo(param.getPageNo());
}
if (param.getPageSize() != null) {
page.setPageSize(param.getPageSize());
}
// 设置查询参数
page.setParam(param);
List<T> list = this.mapper.findByPage(page);
if (list == null) {
list = new ArrayList<>();
}
page.setResults(list);
page.setResultSize(list.size());
return page;
}
/**
* 保存数据(插入或更新)
*
* @param entity
*/
public void save(T entity) {
try {
// 对使用了@SpecialCharatorFilter标签的字段进行特殊字符转义
specialCharatorFilter(entity);
} catch (Exception e) {
e.printStackTrace();
}
if (entity.isNewRecord()) {
entity.preInsert();
mapper.insert(entity);
} else {
entity.preUpdate();
mapper.update(entity);
}
}
/**
* 批量保存
*
* @param entityList
*/
public void batchSave(List<T> entityList) {
mapper.batchInsert(entityList);
}
/**
* 根据ID查询列表
*
* @param ids
* @return
*/
public List<T> queryByIds(List<String> ids) {
return mapper.queryByIds(ids);
}
/**
* 删除数据
*
* @param entity
*/
public int delete(String id) {
return mapper.delete(id);
}
/**
* 特殊字符过滤
*
* @param entity
* @throws IllegalAccessException
* @throws IllegalArgumentException
*/
private void specialCharatorFilter(T entity) throws IllegalArgumentException, IllegalAccessException {
// 获取当前类所有字段
Field[] fields = getAllFields(entity);
for (Field f : fields) {
if (f.isAnnotationPresent(SpecialCharatorFilter.class)) {
f.setAccessible(true);
Object value = f.get(entity);
if (!StringUtils.isEmpty(value)) {
f.set(entity, SpecialCharacterUtil.specialCharacterToText((String) value));
}
}
}
}
/**
* 获取类的所有属性,包括父类
*
* @param entity
* @return
*/
private Field[] getAllFields(T entity) {
Class<?> clazz = entity.getClass();
List<Field> fieldList = new ArrayList<>();
while (clazz != null) {
fieldList.addAll(new ArrayList<>(Arrays.asList(clazz.getDeclaredFields())));
clazz = clazz.getSuperclass();
}
Field[] fields = new Field[fieldList.size()];
fieldList.toArray(fields);
return fields;
}
}
controller
就不用说了吧
dao对应的xml
<select id="findByPage" resultType="co.dc.saas.code.po.CodeVariablePo">
SELECT
*
FROM code_center.code_variable
<where>
<if test="param.dataType != null">
AND data_type = #{param.dataType}
</if>
<if test="param.paramType != null and param.paramType != ''">
AND param_type = #{param.paramType}
</if>
<if test="param.paramName != null and param.paramName != ''">
AND param_name LIKE concat('%',#{param.paramName},'%')
</if>
</where>
ORDER BY sort_number ASC
</select>
注意:<if test判断语句中需要使用param点实际参数对象属性的方式,因为分页插件中传过来的参数是page对象,param是page对象的属性。