使用mybatis 提供的拦截器实现mybatis 自定义分页查询。有一定的缺陷,不能把对象当成参数 传进来。
主要实现方案:
要实现Interceptor接口,并且加上 @Intercepts 注解, 拦截器 会拦截 四个核心实现类,
Executor、StatementHandler、PameterHandler和ResultSetHandler 接口进行拦截,也就是说会对这4种对象进行代理
Executor 核心处理器:所有执行sql 语句都会执行这个类进行分发;
StatementHandler:mybatis 对 statement 进行的封装 主要用于发送sql 使用
PameterHandler:mybatis 对 pameter进行封装,主要是是使用预编译,参数
ResultSetHandler:mybatis 对于resultset的封装。对结果处理
拦截器都可以在每个方法执行之前进行处理。
开始进行分页插件开发:
首先 要进行的是: 要实现接口;Interceptor
重写三个方法:
intercept 拦截器的具体执行方法,
plugin:拦截那些方法:
setProperties:得到配置文件信息
plugin----->我们只拦截 Executor 方法 所以这样写
if (arg0 instanceof Executor) {
return Plugin.wrap(arg0, this);
} else {
return arg0;
}
只拦截 Executor 其他的方法放行
得到 对应的 配置文件信息,得到数据库类型
我们拦截这个方法 ,所以:注解这么写。参数列表写好
@Intercepts({
@Signature(method = "拦截的方法名字", type = 类的字节码文件, args = {参数列表)})
注释很清楚 不一一赘述了
查看是否存在 page对象,有page对象的 才进行拦截
要在配置文件中 记得加载插件
在springboot 中加载配置文件
全部代码如下 我在提供一个 csdn 下载地址
https://download.csdn.net/download/drsbbbl/11710549
//@Component
@Intercepts({
@Signature(method = "query", type = Executor.class, args = { MappedStatement.class,Object.class,RowBounds.class,ResultHandler.class})})
public class Test implements Interceptor {
private String sqlType;
@Override
public Object intercept(Invocation arg0) throws Throwable {
// TODO Auto-generated method stub
//得到对应的执行真正执行sql 的类。
//Object fieldValue = getFieldValue(arg0,"target");
//Object fieldValue2 = getFieldValue(fieldValue,"h");
//RoutingStatementHandler fieldValue3 = (RoutingStatementHandler) getFieldValue(fieldValue2,"target");
//Executor target = (Executor) arg0.getTarget();
//RoutingStatementHandler target1=(RoutingStatementHandler) target;
//绑定sql 获得 RoutingStatementHandler 内部维护的StatementHandler真正加载sql的类
//StatementHandler delegate=(StatementHandler) getFieldValue(fieldValue3,"delegate");
//target.getParameterHandler()
//PageHelper.startPage(0, 1)
MappedStatement mappedStatement = (MappedStatement)arg0.getArgs()[0];
//只拦截select 方法
if(SqlCommandType.SELECT==mappedStatement.getSqlCommandType()){
//等待获取参数列表
Object obj=arg0.getArgs()[1];
BoundSql boundSql = mappedStatement.getBoundSql(obj);
//得到参数列表
Object parameterObject = boundSql.getParameterObject();
//写了注解的参数列表
//List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
//查看参数对象 是不是有 对应的 分页对象
//把 参数看看是否是集合
if(parameterObject instanceof Map<?,?>){
//存在分页对象 把分页对象 数据取出来 进行分页统计
Map<String,Object> map=(Map<String, Object>) parameterObject;
Page<?> existValueInMap = isExistValueInMap(map);
if(existValueInMap!=null){
//关于sql 的全部配置全在 它里面
//拦截到的prepare方法参数是一个Connection对象
String sql = boundSql.getSql();
// mappedStatement 得到分页sql 进行查询数据总和
Connection connection = getConnect(mappedStatement);
String countSql =this.getCountSql(sql);
//获得 sql解析对象
RawSqlSource rawSqlSource = (RawSqlSource) getFieldValue(mappedStatement,"sqlSource");
StaticSqlSource staticSqlSource = (StaticSqlSource) getFieldValue(rawSqlSource, "sqlSource");
//得到参数列表 property
List<ParameterMapping> parameterMappings = (List<ParameterMapping>) getFieldValue(staticSqlSource,"parameterMappings");
if(existValueInMap.isFlag()){
//顺序property 特别重要
//进行执行查询总条数的sql 执行 并且set进 page对象
//this.setTotalRecord(existValueInMap, countSql, parameterMappings,map,connection);
this.setTotalRecord(existValueInMap, mappedStatement, connection);
}
//获取分页Sql语句
String pageSql = this.getPageSql(existValueInMap, sql);
//利用反射设置当前BoundSql对应的sql属性为我们建立好的分页Sql语句
setFieldValue(staticSqlSource, "sql", pageSql);
}
}
}
return arg0.proceed();
}
public Connection getConnect(MappedStatement mappedStatement ){
DataSource dataSource = mappedStatement.getConfiguration().getEnvironment().getDataSource();
try {
return dataSource.getConnection();
} catch (SQLException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
return null;
}
/**
*
* 验证map中是否存在特定的value 并返回
* @param map 验证的map
* @param obj 特定的value
* @return
*/
public static Page<?> isExistValueInMap(Map<String,Object> map){
//param1
if(map!=null && map.size()>0){
for (int i = 1; i <= map.size(); i++) {
Object object = map.get("param"+i);
if(object instanceof Page<?>){
return (Page<?>) object;
}
}
}
return null;
}
/**
* 根据page对象获取对应的分页查询Sql语句,这里只做了两种数据库类型,Mysql和Oracle
* 其它的数据库都 没有进行分页
*
* @param page 分页对象
* @param sql 原sql语句
* @return
*/
private String getPageSql(Page<?> page, String sql) {
StringBuffer sqlBuffer = new StringBuffer(sql);
if ("mysql".equalsIgnoreCase(sqlType)) {
return getMysqlPageSql(page, sqlBuffer);
} else if ("oracle".equalsIgnoreCase(sqlType)) {
return getOraclePageSql(page, sqlBuffer);
}
return sqlBuffer.toString();
}
/**
* 获取Mysql数据库的分页查询语句
* @param page 分页对象
* @param sqlBuffer 包含原sql语句的StringBuffer对象
* @return Mysql数据库分页语句
*/
private String getMysqlPageSql(Page<?> page, StringBuffer sqlBuffer) {
//计算第一条记录的位置,Mysql中记录的位置是从0开始的。
int offset = (page.getPageNo() - 1) * page.getPageSize();
sqlBuffer.append(" limit ").append(offset).append(",").append(page.getPageSize());
return sqlBuffer.toString();
}
/**
* 获取Oracle数据库的分页查询语句
* @param page 分页对象
* @param sqlBuffer 包含原sql语句的StringBuffer对象
* @return Oracle数据库的分页查询语句
*/
private String getOraclePageSql(Page<?> page, StringBuffer sqlBuffer) {
//计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的
int offset = (page.getPageNo() - 1) * page.getPageSize() + 1;
sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getPageSize());
sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset);
//上面的Sql语句拼接之后大概是这个样子:
//select * from (select u.*, rownum r from (select * from t_user) u where rownum < 31) where r >= 16
return sqlBuffer.toString();
}
/**
* 给当前的参数对象page设置总记录数
*
* @param page Mapper映射语句对应的参数对象
* @param mappedStatement Mapper映射语句
* @param connection 当前的数据库连接
*/
private void setTotalRecord(Page<?> page,
MappedStatement mappedStatement, Connection connection) {
//获取对应的BoundSql,这个BoundSql其实跟我们利用StatementHandler获取到的BoundSql是同一个对象。
//delegate里面的boundSql也是通过mappedStatement.getBoundSql(paramObj)方法获取到的。
BoundSql boundSql = mappedStatement.getBoundSql(page);
//获取到我们自己写在Mapper映射语句中对应的Sql语句
String sql = boundSql.getSql();
//通过查询Sql语句获取到对应的计算总记录数的sql语句
String countSql = this.getCountSql(sql);
//通过BoundSql获取对应的参数映射
List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
//利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。
BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, page);
//通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象
ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, page, countBoundSql);
//通过connection建立一个countSql对应的PreparedStatement对象。
PreparedStatement pstmt = null;
ResultSet rs = null;
try {
pstmt = connection.prepareStatement(countSql);
//通过parameterHandler给PreparedStatement对象设置参数
parameterHandler.setParameters(pstmt);
//之后就是执行获取总记录数的Sql语句和获取结果了。
rs = pstmt.executeQuery();
if (rs.next()) {
Integer totalRecord = rs.getInt(1);
//给当前的参数page对象设置总记录数
page.setPageCount(totalRecord);
}
} catch (SQLException e) {
e.printStackTrace();
} finally {
try {
if (rs != null)
rs.close();
if (pstmt != null)
pstmt.close();
} catch (SQLException e) {
e.printStackTrace();
}
}
}
/**
* 根据原Sql语句获取对应的查询总记录数的Sql语句
* @param sql
* @return
*/
private String getCountSql(String sql) {
/* select * from( select
*
from table left ...where )
*/
String lowerCase = sql.toLowerCase();
StringBuilder sb=new StringBuilder();
int indexOf = lowerCase.indexOf("select");
int indexOf2 = lowerCase.indexOf("from");
sb.append(lowerCase.substring(0, indexOf+6))
.append(" count(1) ")
.append(lowerCase.substring(indexOf2, sql.length()));
return sb.toString();
}
/**
* 只拦截 Executor 其他的放行
*
*/
@Override
public Object plugin(Object arg0) {
// TODO Auto-generated method stub
if (arg0 instanceof Executor) {
return Plugin.wrap(arg0, this);
} else {
return arg0;
}
}
@Override
public void setProperties(Properties arg0) {
// TODO Auto-generated method stub
this.sqlType=arg0.getProperty("sqlType");
}
/**
*
* @param obj
* @param fieldName
* @return
*/
private static Object getFieldValue(Object obj,String fieldName){
//处理 第一步 可能obj为空
if(obj==null){
return new RuntimeException("null ponit exception 传入的对象为空");
}
if(fieldName == null || fieldName.length()==0){
return new RuntimeException("fieldName is null ponit exception 传入的字段不能为空");
}
//获得当前对象class
Object object=null;
// Class<? extends Object> class1 = obj.getClass();
Field declaredField;
try {
declaredField=getField(obj,fieldName);
declaredField.setAccessible(true);
object = declaredField.get(obj);
} catch (SecurityException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}catch (IllegalArgumentException e) {
// TODO Auto-generated catch block
return new RuntimeException("参数非法异常 传入的对象 非法");
} catch (IllegalAccessException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
return object;
}
/**
* 利用反射设置指定对象的指定属性为指定的值
* @param obj 目标对象
* @param fieldName 目标属性
* @param fieldValue 目标值
*/
public static void setFieldValue(Object obj, String fieldName,
String fieldValue) {
Field field = getField(obj, fieldName);
if (field != null) {
try {
field.setAccessible(true);
field.set(obj, fieldValue);
} catch (IllegalArgumentException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IllegalAccessException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
/**
* 利用反射获取指定对象里面的指定属性
* @param obj 目标对象
* @param fieldName 目标属性
* @return 目标字段
*/
private static Field getField(Object obj, String fieldName) {
Field field = null;
for (Class<?> clazz=obj.getClass(); clazz != Object.class; clazz=clazz.getSuperclass()) {
try {
field = clazz.getDeclaredField(fieldName);
break;
} catch (NoSuchFieldException e) {
//这里不用做处理,子类没有该字段可能对应的父类有,都没有就返回null。
}
}
return field;
}
}
package com.neo.model;
/**
* 自动查询 分页数据的实体类
* 会自动统计分页数据 自动查询count(1)
* @author mbb
*
* @param <T>
*/
public class Page<T> {
/**
* 当前页
*/
private int pageNo=1;
/**
* 页的大小 一页的数据
*/
private int pageSize=Integer.MAX_VALUE;
/**
* 总数量
*/
private Integer pageCount;
/**
* 是否自动统计 当前总数量 默认为会加载 总数量 可以进行自行关闭
*/
private boolean flag=true;
public int getPageNo() {
return pageNo;
}
public void setPageNo(int pageNo) {
this.pageNo = pageNo;
}
public int getPageSize() {
return pageSize;
}
public void setPageSize(int pageSize) {
this.pageSize = pageSize;
}
public Integer getPageCount() {
return pageCount;
}
public void setPageCount(Integer pageCount) {
this.pageCount = pageCount;
}
public boolean isFlag() {
return flag;
}
public void setFlag(boolean flag) {
this.flag = flag;
}
}