MyBatis 自定义分页插件

因为项目中经常用到分页,而我又不太满意网上流行的分页插件,所以自己手写了一个(生成的分页 SQL 基于 ROWNUM,目前只适用于 Oracle)。详细说明以后再完善。

/**
 * One page of query results: the requested page position, the total number of
 * matching rows, and the rows belonging to this page.
 *
 * @param <T> element type of the page rows
 */
@JsonSerialize(using = TmsQueryJsonSerialize.class, typing = JsonSerialize.Typing.DYNAMIC)
public interface Page<T> {

    /** @return the requested page number */
    int getPageNo();

    /** @return the requested number of rows per page */
    int getPageSize();

    /** @return the total number of rows matched by the un-paged query */
    int getTotalCount();

    /** @return the rows of this page */
    List<T> getData();

    /** @param totalCount total number of rows matched by the un-paged query */
    void setTotalCount(int totalCount);

    /** @param data the rows of this page */
    void setData(List<T> data);
}

package com.shangqiao56;

import java.util.List;

/**
 * A {@link Page} that is also a {@link List} of its rows, so that it can be returned
 * where MyBatis requires a {@code List} result (the {@code ResultSetHandler}).
 *
 * @param <T> element type of the page rows
 */
public interface PageList<T> extends Page<T>, List<T> {}

/**
 * Wraps a {@link TmsPage} so that it also implements {@link List}: MyBatis'
 * {@code ResultSetHandler} must return a {@code List}, so the page is returned in this form.
 *
 * <p>Page metadata (page number, size, total count) delegates to the wrapped page; every
 * {@code List} operation, including {@code equals}/{@code hashCode}, delegates to
 * {@link #getData()}.
 *
 * @param <T> element type of the page rows
 */
public class TmsPageWraper<T> implements PageList<T> {
    private TmsPage<T> page;

    /**
     * Creates a wrapper around the given page.
     *
     * @param tmsPage page to wrap; must not be {@code null}
     * @param <E>     element type of the page rows
     * @return a wrapper delegating to {@code tmsPage}
     * @throws NullPointerException if {@code tmsPage} is {@code null}
     */
    public static <E> TmsPageWraper<E> wrap(TmsPage<E> tmsPage) {
        TmsPageWraper<E> wraper = new TmsPageWraper<>();
        // Fail fast here instead of with an NPE on the first delegated call.
        wraper.page = java.util.Objects.requireNonNull(tmsPage, "tmsPage");
        return wraper;
    }

    @Override
    public int getTotalCount() {
        return page.getTotalCount();
    }

    @Override
    public void setTotalCount(int totalCount) {
        this.page.setTotalCount(totalCount);
    }

    @Override
    public int getPageSize() {
        return page.getPageSize();
    }

    @Override
    public int getPageNo() {
        return page.getPageNo();
    }

    @Override
    public List<T> getData() {
        return page.getData();
    }

    @Override
    public void setData(List<T> data) {
        page.setData(data);
    }

    @Override
    public int size() {
        return getData().size();
    }

    @Override
    public boolean isEmpty() {
        return getData().isEmpty();
    }

    @Override
    public boolean contains(Object o) {
        return getData().contains(o);
    }

    @Override
    public Iterator<T> iterator() {
        // Iterate the same list that size()/get() read, like every other List method here.
        return getData().iterator();
    }

    @Override
    public Object[] toArray() {
        return getData().toArray();
    }

    @Override
    public <T1> T1[] toArray(T1[] a) {
        return getData().toArray(a);
    }

    @Override
    public boolean add(T t) {
        return getData().add(t);
    }

    @Override
    public boolean remove(Object o) {
        return getData().remove(o);
    }

    @Override
    public boolean containsAll(Collection<?> c) {
        return getData().containsAll(c);
    }

    @Override
    public boolean addAll(Collection<? extends T> c) {
        return getData().addAll(c);
    }

    @Override
    public boolean addAll(int index, Collection<? extends T> c) {
        return getData().addAll(index, c);
    }

    @Override
    public boolean removeAll(Collection<?> c) {
        return getData().removeAll(c);
    }

    @Override
    public boolean retainAll(Collection<?> c) {
        return getData().retainAll(c);
    }

    @Override
    public void clear() {
        getData().clear();
    }

    @Override
    public T get(int index) {
        return getData().get(index);
    }

    @Override
    public T set(int index, T element) {
        return getData().set(index, element);
    }

    @Override
    public void add(int index, T element) {
        getData().add(index, element);
    }

    @Override
    public T remove(int index) {
        return getData().remove(index);
    }

    @Override
    public int indexOf(Object o) {
        return getData().indexOf(o);
    }

    @Override
    public int lastIndexOf(Object o) {
        return getData().lastIndexOf(o);
    }

    @Override
    public ListIterator<T> listIterator() {
        return getData().listIterator();
    }

    @Override
    public ListIterator<T> listIterator(int index) {
        // Must honour the starting position (and its bounds check) per the List contract.
        return getData().listIterator(index);
    }

    @Override
    public List<T> subList(int fromIndex, int toIndex) {
        return getData().subList(fromIndex, toIndex);
    }

    /** Equal to any {@link List} with the same elements in the same order, per the List contract. */
    @Override
    public boolean equals(Object o) {
        return o == this || getData().equals(o);
    }

    /** Hash code of the data list, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return getData().hashCode();
    }
}

import com.shangqiao56.Page;
import com.shangqiao56.tms.TmsPage;
import org.apache.ibatis.executor.ErrorContext;
import org.apache.ibatis.executor.resultset.ResultSetHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.plugin.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.*;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/**
 * MyBatis plugin that turns any query whose parameter is (or contains) a {@link TmsPage}
 * into a paged query.
 *
 * <p>On {@code StatementHandler.prepare} it rewrites the SQL into a {@code ROWNUM} page
 * query, runs a count query with the same parameters to fill the page's total count, and
 * remembers the page in a thread-local. On {@code ResultSetHandler.handleResultSets} it
 * stores the fetched rows into that page and returns the page itself (as a
 * {@link TmsPageWraper}, which is a {@code List}).
 */
@Intercepts( {
        @Signature(method = "handleResultSets", type = ResultSetHandler.class, args = {Statement.class}),
        @Signature(method = "prepare", type = StatementHandler.class, args = {Connection.class, Integer.class})})
public class PageInterceptor implements Interceptor {

    private static final Logger log = LoggerFactory.getLogger(PageInterceptor.class);

    /** Page of the statement prepared on this thread, waiting for its result set to be handled. */
    private static final ThreadLocal<TmsPageWraper<?>> pageContext = new ThreadLocal<>();

    /**
     * ROWNUM paging template: {@code %s} is the original SQL, followed by the last row
     * (inclusive, {@code ROWNUM <=}) and the first row (exclusive, {@code RN >}).
     */
    private static final String PAGE_SQL = " SELECT * FROM (SELECT A.*, ROWNUM RN FROM (%s) A WHERE  ROWNUM <= %d) WHERE RN > %d";

    /**
     * Count template. Wrapping the whole query (as {@link #PAGE_SQL} does) keeps DISTINCT,
     * GROUP BY, UNION and any {@code ?} placeholders in the select list intact, so the count
     * matches the query and its parameters bind in the same order.
     */
    private static final String COUNT_SQL = "SELECT COUNT(1) FROM (%s) TMP_COUNT";

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object target = invocation.getTarget();
        if (target instanceof StatementHandler) {
            return interceptPrepare(invocation, (StatementHandler) target);
        }
        if (target instanceof ResultSetHandler) {
            return interceptResultSets(invocation);
        }
        // plugin() only wraps the two handler types; anything else runs untouched.
        return invocation.proceed();
    }

    /**
     * Handles {@code StatementHandler.prepare}: paged statements are replaced by the page
     * query (after filling the total count), other statements proceed unchanged.
     */
    private Object interceptPrepare(Invocation invocation, StatementHandler statementHandler) throws Throwable {
        // A paged query that failed, or whose rows did not go through handleResultSets, may have
        // left its page here; drop it so it cannot be attached to this statement's rows.
        pageContext.remove();

        Object queryParams = statementHandler.getParameterHandler().getParameterObject();
        TmsPage<?> page = parsePage(queryParams);
        if (page == null) {
            return invocation.proceed();
        }

        Connection connection = (Connection) invocation.getArgs()[0];
        // null is treated as "no limit", i.e. 0 for Statement#setQueryTimeout.
        Integer timeoutArg = (Integer) invocation.getArgs()[1];
        int timeout = (timeoutArg != null) ? timeoutArg : 0;

        String sql = statementHandler.getBoundSql().getSql();

        Statement pageStatement = prepare(connection, sql, timeout, page);
        try {
            fillTotalCount(connection, statementHandler, sql, timeout, page);
        } catch (Throwable e) {
            // MyBatis never receives pageStatement when we fail here, so nobody else closes it.
            try {
                pageStatement.close();
            } catch (SQLException closeError) {
                e.addSuppressed(closeError);
            }
            throw e;
        }

        // Only publish the page once the statement is fully prepared.
        pageContext.set(TmsPageWraper.wrap(page));
        return pageStatement;
    }

    /**
     * Handles {@code ResultSetHandler.handleResultSets}: when a paged statement was prepared on
     * this thread, its rows are stored into the page and the page is returned instead of the list.
     */
    private Object interceptResultSets(Invocation invocation) throws Throwable {
        TmsPageWraper<?> pageWraper = pageContext.get();
        // Consume the page before proceeding so that queries run while handling these rows
        // (e.g. nested selects) never see it.
        pageContext.remove();

        Object result = invocation.proceed();
        if (pageWraper == null) {
            return result;
        }
        attachData(pageWraper, (List<?>) result);
        return pageWraper;
    }

    /** Stores the MyBatis result rows into the page; MyBatis builds them for the page's element type. */
    @SuppressWarnings("unchecked")
    private static <E> void attachData(TmsPageWraper<E> pageWraper, List<?> rows) {
        pageWraper.setData((List<E>) rows);
    }

    /**
     * Runs the count query with the same parameters as the page query and stores the result
     * on the page. Both JDBC resources are always closed.
     */
    private void fillTotalCount(Connection connection, StatementHandler statementHandler, String sql,
                                int timeout, TmsPage<?> page) throws SQLException {
        String countSql = buildQueryCount(sql);
        log.debug("generate a query count sql : {}", countSql);
        try (PreparedStatement countStatement = connection.prepareStatement(countSql)) {
            statementHandler.parameterize(countStatement);
            countStatement.setQueryTimeout(timeout);
            try (ResultSet rs = countStatement.executeQuery()) {
                if (rs.next()) {
                    page.setTotalCount(rs.getInt(1));
                }
            }
        }
    }

    /**
     * Prepares the page query for {@code querySql}.
     *
     * <p>TODO: this skips much of what MyBatis' {@code PreparedStatementHandler} does when it
     * prepares a statement (see that class).
     */
    private Statement prepare(Connection connection, String querySql, int timeout, TmsPage<?> page) throws SQLException {
        String sql = buildPageSql(querySql, page);

        ErrorContext.instance().sql(sql);

        Statement statement = connection.prepareStatement(sql);
        statement.setQueryTimeout(timeout);
        return statement;
    }

    /**
     * Finds the page in the statement parameter: either the parameter itself, or a value of a
     * parameter map (the last matching value wins).
     *
     * @return the page, or {@code null} when the query is not paged
     */
    private TmsPage<?> parsePage(Object parameter) {
        // Check the concrete type: other Page implementations are not TmsPage and would fail the cast.
        if (parameter instanceof TmsPage) {
            return (TmsPage<?>) parameter;
        }
        if (!(parameter instanceof Map)) {
            return null;
        }
        log.debug("query params {}", parameter);
        TmsPage<?> page = null;
        for (Object value : ((Map<?, ?>) parameter).values()) {
            if (value instanceof TmsPage) {
                page = (TmsPage<?>) value;
            }
        }
        return page;
    }

    /** Wraps {@code sql} into the ROWNUM page query for {@code page}'s row range. */
    protected String buildPageSql(String sql, TmsPage<?> page) {
        return String.format(PAGE_SQL, sql, page.getEndRow(), page.getStartRow());
    }

    /** Builds the query counting all rows {@code sql} would return. */
    private String buildQueryCount(String sql) {
        return String.format(COUNT_SQL, sql);
    }

    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler || target instanceof ResultSetHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    @Override
    public void setProperties(Properties properties) {
        // No configurable properties.
    }
}

/**
 * Spring Boot configuration that builds the MyBatis {@link SqlSessionFactory} and
 * {@link SqlSessionTemplate} from {@link MybatisProperties}, registering every
 * {@link Interceptor} bean (e.g. the pagination plugin) as a MyBatis plugin, and scanning
 * the auto-configuration packages for {@code @Mapper} interfaces.
 *
 * <p>NOTE(review): this appears to be adapted from mybatis-spring-boot-autoconfigure's
 * {@code MybatisAutoConfiguration}; confirm it is kept in sync when upgrading that library.
 */
@Configuration
@ConditionalOnClass({SqlSessionFactory.class, SqlSessionFactoryBean.class})
@ConditionalOnSingleCandidate(DataSource.class)
@EnableConfigurationProperties({MybatisProperties.class})
@AutoConfigureAfter({DataSourceAutoConfiguration.class})
public class TmsMybatisAutoConfiguration implements InitializingBean {
    private static final Logger logger = LoggerFactory.getLogger(TmsMybatisAutoConfiguration.class);
    private final MybatisProperties properties;
    private final Interceptor[] interceptors;
    private final ResourceLoader resourceLoader;
    private final DatabaseIdProvider databaseIdProvider;
    private final List<ConfigurationCustomizer> configurationCustomizers;

    public TmsMybatisAutoConfiguration(MybatisProperties properties,
                                       ObjectProvider<Interceptor[]> interceptorsProvider,
                                       ResourceLoader resourceLoader,
                                       ObjectProvider<DatabaseIdProvider> databaseIdProvider,
                                       ObjectProvider<List<ConfigurationCustomizer>> configurationCustomizersProvider) {
        this.properties = properties;
        // Optional collaborators: each may be null, which the code below checks before use.
        this.interceptors = interceptorsProvider.getIfAvailable();
        this.resourceLoader = resourceLoader;
        this.databaseIdProvider = databaseIdProvider.getIfAvailable();
        this.configurationCustomizers = configurationCustomizersProvider.getIfAvailable();
    }

    @Override
    public void afterPropertiesSet() {
        checkConfigFileExists();
    }

    /** Fails startup when config-location checking is enabled and the configured file does not exist. */
    private void checkConfigFileExists() {
        if (this.properties.isCheckConfigLocation() && StringUtils.hasText(this.properties.getConfigLocation())) {
            Resource resource = this.resourceLoader.getResource(this.properties.getConfigLocation());
            Assert.state(resource.exists(), "Cannot find config location: " + resource + " (please add config file or check your Mybatis configuration)");
        }
    }

    /** Builds the session factory from the data source, the MyBatis properties and the plugin beans. */
    @Bean
    @ConditionalOnMissingBean
    public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception {
        SqlSessionFactoryBean factory = new SqlSessionFactoryBean();
        factory.setDataSource(dataSource);
        factory.setVfs(SpringBootVFS.class);
        if (StringUtils.hasText(this.properties.getConfigLocation())) {
            factory.setConfigLocation(this.resourceLoader.getResource(this.properties.getConfigLocation()));
        }

        applyConfiguration(factory);
        if (this.properties.getConfigurationProperties() != null) {
            factory.setConfigurationProperties(this.properties.getConfigurationProperties());
        }

        if (!ObjectUtils.isEmpty(this.interceptors)) {
            factory.setPlugins(this.interceptors);
        }

        if (this.databaseIdProvider != null) {
            factory.setDatabaseIdProvider(this.databaseIdProvider);
        }

        if (StringUtils.hasLength(this.properties.getTypeAliasesPackage())) {
            factory.setTypeAliasesPackage(this.properties.getTypeAliasesPackage());
        }

        if (this.properties.getTypeAliasesSuperType() != null) {
            factory.setTypeAliasesSuperType(this.properties.getTypeAliasesSuperType());
        }

        if (StringUtils.hasLength(this.properties.getTypeHandlersPackage())) {
            factory.setTypeHandlersPackage(this.properties.getTypeHandlersPackage());
        }

        // Resolve the mapper resources once instead of scanning the classpath twice.
        Resource[] mapperLocations = this.properties.resolveMapperLocations();
        if (!ObjectUtils.isEmpty(mapperLocations)) {
            factory.setMapperLocations(mapperLocations);
        }

        return factory.getObject();
    }

    /**
     * Sets the MyBatis {@code Configuration} on the factory, creating a default one when neither
     * a configuration object nor a config file is set, and applying all customizers to it.
     */
    private void applyConfiguration(SqlSessionFactoryBean factory) {
        org.apache.ibatis.session.Configuration configuration = this.properties.getConfiguration();
        if (configuration == null && !StringUtils.hasText(this.properties.getConfigLocation())) {
            configuration = new org.apache.ibatis.session.Configuration();
        }

        if (configuration != null && !CollectionUtils.isEmpty(this.configurationCustomizers)) {
            for (ConfigurationCustomizer customizer : this.configurationCustomizers) {
                customizer.customize(configuration);
            }
        }

        factory.setConfiguration(configuration);
    }

    /** Session template using the configured executor type, or the template's default when none is set. */
    @Bean
    @ConditionalOnMissingBean
    public SqlSessionTemplate sqlSessionTemplate(SqlSessionFactory sqlSessionFactory) {
        ExecutorType executorType = this.properties.getExecutorType();
        return executorType != null ? new SqlSessionTemplate(sqlSessionFactory, executorType) : new SqlSessionTemplate(sqlSessionFactory);
    }

    /** Registers the {@code @Mapper} scanner when no {@link MapperFactoryBean} is defined yet. */
    @Configuration
    @Import({TmsMybatisAutoConfiguration.AutoConfiguredMapperScannerRegistrar.class})
    @ConditionalOnMissingBean({MapperFactoryBean.class})
    public static class MapperScannerRegistrarNotFoundConfiguration implements InitializingBean {

        @Override
        public void afterPropertiesSet() {
            TmsMybatisAutoConfiguration.logger.debug("No {} found.", MapperFactoryBean.class.getName());
        }
    }

    /** Scans the Spring Boot auto-configuration packages for interfaces annotated with {@code @Mapper}. */
    public static class AutoConfiguredMapperScannerRegistrar implements BeanFactoryAware, ImportBeanDefinitionRegistrar, ResourceLoaderAware {
        private BeanFactory beanFactory;
        private ResourceLoader resourceLoader;

        @Override
        public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
            if (!AutoConfigurationPackages.has(this.beanFactory)) {
                TmsMybatisAutoConfiguration.logger.debug("Could not determine auto-configuration package, automatic mapper scanning disabled.");
                return;
            }

            TmsMybatisAutoConfiguration.logger.debug("Searching for mappers annotated with @Mapper");
            List<String> packages = AutoConfigurationPackages.get(this.beanFactory);
            if (TmsMybatisAutoConfiguration.logger.isDebugEnabled()) {
                packages.forEach(pkg -> TmsMybatisAutoConfiguration.logger.debug("Using auto-configuration base package '{}'", pkg));
            }

            ClassPathMapperScanner scanner = new ClassPathMapperScanner(registry);
            if (this.resourceLoader != null) {
                scanner.setResourceLoader(this.resourceLoader);
            }

            scanner.setAnnotationClass(Mapper.class);
            scanner.registerFilters();
            scanner.doScan(StringUtils.toStringArray(packages));
        }

        @Override
        public void setBeanFactory(BeanFactory beanFactory) {
            this.beanFactory = beanFactory;
        }

        @Override
        public void setResourceLoader(ResourceLoader resourceLoader) {
            this.resourceLoader = resourceLoader;
        }
    }
}

使用方法

        // Pass the TmsPage as a query parameter; PageInterceptor detects it and rewrites the
        // SQL into a ROWNUM page query. The constructor arguments are presumably
        // (page number, page size); confirm against TmsPage.
        TmsPage page = new TmsPage(1,10);
        // The returned list is the page itself (a TmsPageWraper holding this page's rows);
        // the interceptor's count query has also filled page.getTotalCount().
        List list =carDao.queryByType(page,"大货车");
        list.forEach(System.out::println);
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值