Sharding-JDBC 5.1.0: Monthly Table Sharding and Dynamic Table Creation

1. Maven dependencies

        <dependency>
            <groupId>org.apache.shardingsphere</groupId>
            <artifactId>shardingsphere-jdbc-core-spring-boot-starter</artifactId>
            <version>5.1.0</version>
        </dependency>
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>druid</artifactId>
            <version>1.2.8</version>
        </dependency>
        <dependency>
            <groupId>org.apache.tomcat</groupId>
            <artifactId>tomcat-dbcp</artifactId>
            <version>10.0.16</version>
        </dependency>

2. YAML configuration (application.yml)

spring:
  shardingsphere:
    datasource:
      names: master
      master:
        type: com.alibaba.druid.pool.DruidDataSource
        driver-class-name: org.postgresql.Driver
        url: jdbc:postgresql://127.0.0.1:5432/production_dev?serverTimezone=UTC&characterEncoding=utf-8&stringtype=unspecified
        username: postgres
        password: 123456
    rules:
      sharding:
        sharding-algorithms:
          month-sharding-algorithm:
            props:
              strategy: standard
              algorithmClassName: com.base.shading.CreateTimeShardingAlgorithm
            type: CLASS_BASED
        tables:
          work_procedure:
#            actual-data-nodes: master.work_procedure_${2023..2050}_0${1..9},master.work_procedure_${2023..2050}_${10..12}
            actual-data-nodes: master.$->{com.base.shading.ShardingAlgorithmTool.cacheTableNames()}
            table-strategy:
              standard:
                sharding-column: create_time
                sharding-algorithm-name: month-sharding-algorithm
          work_result:
#            actual-data-nodes: master.work_result_${2023..2050}_0${1..9},master.work_result_${2023..2050}_${10..12}
            actual-data-nodes: master.$->{com.base.shading.ShardingAlgorithmTool.cacheTableNames()}
            table-strategy:
              standard:
                sharding-column: create_time
                sharding-algorithm-name: month-sharding-algorithm
        binding-tables:
          - work_procedure, work_result
    props:
      sql-show: true

3. Sharding algorithm class: CreateTimeShardingAlgorithm.java

package com.base.shading;

import com.google.common.collect.Range;
import org.apache.shardingsphere.sharding.api.sharding.standard.PreciseShardingValue;
import org.apache.shardingsphere.sharding.api.sharding.standard.RangeShardingValue;
import org.apache.shardingsphere.sharding.api.sharding.standard.StandardShardingAlgorithm;

import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.*;

public class CreateTimeShardingAlgorithm implements StandardShardingAlgorithm<Date> {

    private static final DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    /**
     * Precise sharding.
     * @param collection all available target table names for the sharded data source
     * @param preciseShardingValue sharding value: logicTableName is the logic table, columnName the sharding column,
     *                             value the sharding-column value parsed from the SQL
     * @return actual table name
     */
    @Override
    public String doSharding(Collection<String> collection, PreciseShardingValue<Date> preciseShardingValue) {
        Object value = preciseShardingValue.getValue();
        String tableSuffix;
        if (value instanceof Date) {
            LocalDate localDate = ((Date) value).toInstant().atZone(ZoneId.systemDefault()).toLocalDate();
            tableSuffix = localDate.format(DateTimeFormatter.ofPattern("yyyy_MM"));
        } else {
            // The value may also arrive as a string, e.g. "2023-08-01 12:00:00"
            String column = (String) value;
            tableSuffix = LocalDateTime.parse(column, formatter).format(DateTimeFormatter.ofPattern("yyyy_MM"));
        }
        String logicTableName = preciseShardingValue.getLogicTableName();
        String actualTableName = logicTableName.concat("_").concat(tableSuffix);
        if (!collection.contains(actualTableName)) {
            collection.add(actualTableName);
        }
        return ShardingAlgorithmTool.shardingTablesCheckAndCreatAndReturn(logicTableName, actualTableName);
    }

    /**
     * Range sharding.
     * @param collection all available target table names for the sharded data source
     * @param rangeShardingValue sharding range
     * @return actual table names
     */
    @Override
    public Collection<String> doSharding(Collection<String> collection, RangeShardingValue<Date> rangeShardingValue) {
        // Logic table name
        String logicTableName = rangeShardingValue.getLogicTableName();
        // Range endpoints
        Range<Date> valueRange = rangeShardingValue.getValueRange();
        // Start and end time
        LocalDateTime start;
        LocalDateTime end;
        Object lowerEndpoint = valueRange.lowerEndpoint();
        Object upperEndpoint = valueRange.upperEndpoint();
        if (lowerEndpoint instanceof String) {
            String lower = (String) lowerEndpoint;
            String upper = (String) upperEndpoint;
            start = LocalDateTime.parse(lower, formatter);
            end = LocalDateTime.parse(upper, formatter);
        } else {
            start = valueRange.lowerEndpoint().toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime();
            end = valueRange.upperEndpoint().toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime();
        }
        if (end.isAfter(LocalDateTime.now())) {
            end = LocalDateTime.now();
        }
        // Tables covered by the query range
        Set<String> queryRangeTables = extracted(logicTableName, start, end);
        // Tables that actually exist in the database (from the cache)
        HashSet<String> tableNameSet = ShardingAlgorithmTool.cacheTableNames();
        if (collection.size() != tableNameSet.size()) {
            collection.clear();
            collection.addAll(tableNameSet);
        }
        // Tables to route to: the intersection of the two sets
        ArrayList<String> tables = new ArrayList<>(tableNameSet);
        tables.retainAll(queryRangeTables);
        return tables;
    }

    @Override
    public String getType() {
        // Not used for a CLASS_BASED algorithm configured via algorithmClassName
        return null;
    }

    @Override
    public Properties getProps() {
        return null;
    }

    @Override
    public void init() {
    }

    /**
     * Compute the actual table names covered by a time range.
     *
     * @param logicTableName logic table name
     * @param lowerEndpoint  range start
     * @param upperEndpoint  range end
     * @return set of actual table names
     */
    private Set<String> extracted(String logicTableName, LocalDateTime lowerEndpoint, LocalDateTime upperEndpoint) {
        Set<String> rangeTable = new HashSet<>();
        while (lowerEndpoint.isBefore(upperEndpoint)) {
            rangeTable.add(getTableNameByDate(lowerEndpoint, logicTableName));
            lowerEndpoint = lowerEndpoint.plusMonths(1);
        }
        // Include the table for the end of the range
        rangeTable.add(getTableNameByDate(upperEndpoint, logicTableName));
        return rangeTable;
    }

    /**
     * Build an actual table name from a date.
     * @param dateTime date
     * @param logicTableName logic table name
     * @return actual table name
     */
    private String getTableNameByDate(LocalDateTime dateTime, String logicTableName) {
        String tableSuffix = dateTime.format(DateTimeFormatter.ofPattern("yyyy_MM"));
        return logicTableName.concat("_").concat(tableSuffix);
    }
}
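
The article never shows a calling side. Below is a minimal usage sketch; the WorkResult entity, mapper and service are illustrative placeholders rather than code from the original project, and in a real project they would live in separate public files. Because the WHERE clause carries create_time, ShardingSphere hands the query to the range method of CreateTimeShardingAlgorithm above and routes it only to the monthly tables inside the range.

package com.base.demo;  // hypothetical package, for illustration only

import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.Date;
import java.util.List;

// Maps to the logic table "work_result"; ShardingSphere rewrites it to work_result_yyyy_MM
class WorkResult {
    private Long id;
    private Date createTime;
    // getters and setters omitted for brevity
}

@Mapper
interface WorkResultMapper extends BaseMapper<WorkResult> {
}

@Service
public class WorkResultQueryDemo {

    @Autowired
    private WorkResultMapper workResultMapper;

    // The BETWEEN condition on the sharding column triggers range sharding,
    // so only the monthly tables between "from" and "to" are queried
    public List<WorkResult> findBetween(Date from, Date to) {
        QueryWrapper<WorkResult> wrapper = new QueryWrapper<>();
        wrapper.between("create_time", from, to);
        return workResultMapper.selectList(wrapper);
    }
}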

4. Sharding tool class: ShardingAlgorithmTool.java

package com.base.shading;

import com.base.event.UnisicBaseEvent;
import com.base.utils.ApplicationContextUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.shardingsphere.driver.jdbc.core.datasource.ShardingSphereDataSource;

import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.sql.*;
import java.time.LocalDate;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.shardingsphere.infra.config.RuleConfiguration;
import org.apache.shardingsphere.mode.manager.ContextManager;
import org.apache.shardingsphere.sharding.algorithm.config.AlgorithmProvidedShardingRuleConfiguration;
import org.apache.shardingsphere.sharding.api.config.ShardingRuleConfiguration;
import org.apache.shardingsphere.sharding.api.config.rule.ShardingTableRuleConfiguration;
import org.apache.shardingsphere.sharding.api.config.strategy.sharding.ShardingStrategyConfiguration;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component;

@Slf4j
@Component
public class ShardingAlgorithmTool {

    private static final String logicDb = "logic_db";
    @Resource
    private ShardingSphereDataSource shardingSphereDataSource;

    @Autowired
    private ApplicationContext applicationContext;
    private static ApplicationContext context;

    private static  ShardingSphereDataSource shardingDataSource;
    private static final HashSet<String> tableNameCache = new HashSet<>();
    private static final List<String> ShardingTableNames = Arrays.asList("work_procedure","work_manage","work_procedure_file","work_result");

    // At startup, actual-data-nodes must resolve to at least one real table; after startup,
    // ShardingTablesLoadRunner clears this cache and re-populates it from the database.
    static  {
        ShardingTableNames.forEach(item->{
            String tableSuffix = LocalDate.now().format(DateTimeFormatter.ofPattern("yyyy_MM"));
            String actualTableName = item.concat("_").concat(tableSuffix);
            tableNameCache.add(actualTableName);
        });
    }

    @PostConstruct
    public void init() {
        shardingDataSource = shardingSphereDataSource;
        context = applicationContext;
    }
    /**
     * Read all sharded table names from the current schema.
     * @return table name list
     */
    public static List<String> getAllTableNameBySchema() {
        List<String> tableNames = new ArrayList<>();
        String sql = "SELECT tablename FROM pg_tables WHERE schemaname = 'public' AND tablename ~ '_\\d{4}_\\d{2}$';";
        DataSource dataSource = ApplicationContextUtil.getBean(DataSource.class);
        try (Connection connection = dataSource.getConnection();
             Statement statement = connection.createStatement();
             ResultSet rs = statement.executeQuery(sql)) {
            while (rs.next()) {
                tableNames.add(rs.getString(1));
            }
        } catch (SQLException e) {
            log.error("SQLException: " + e.getMessage());
            throw new RuntimeException(e);
        }
        return tableNames;
    }


    /**
     * Check whether the actual table resolved by the sharding algorithm exists; create it if not.
     *
     * @param logicTableName  logic table name (table prefix)
     * @param actualTableName actual table name
     * @return actual table name that is guaranteed to exist in the database
     */
    public static String shardingTablesCheckAndCreatAndReturn(String logicTableName, String actualTableName) {
        synchronized (logicTableName.intern()) {
            // Already cached: the table exists, return it directly
            if (tableNameCache.contains(actualTableName)) {
                return actualTableName;
            }
            // DDL to create the table from the logic table's structure
            String createTableSql = "CREATE TABLE " + actualTableName + " (LIKE " + logicTableName + " );";
            // DDL to add the primary key
            String createPrimaryKeySql = "ALTER TABLE " + actualTableName + " ADD CONSTRAINT " + actualTableName + "_pk PRIMARY KEY (id);";
            // Not cached: create the table, publish an event to create its indexes, then reload the cache
            DataSource dataSource = ApplicationContextUtil.getBean(DataSource.class);
            try (Connection connection = dataSource.getConnection();
                 Statement statement = connection.createStatement()) {
                try {
                    statement.executeUpdate(createTableSql);
                } catch (SQLException e) {
                    log.error("create-table-SQLException: " + e.getMessage());
                    throw new RuntimeException(e);
                }
                try {
                    statement.executeUpdate(createPrimaryKeySql);
                } catch (SQLException e) {
                    log.error("create-primaryKey-SQLException: " + e.getMessage());
                    throw new RuntimeException(e);
                }
                ArrayList<String> names = new ArrayList<>(2);
                names.add(logicTableName);
                names.add(actualTableName);
                context.publishEvent(new UnisicBaseEvent<>(names, ConstantInterface.CREATE_INDEX));
            } catch (SQLException e) {
                log.error("SQLException: " + e.getMessage());
                throw new RuntimeException(e);
            }

            // Reload the cache (and refresh the actual data nodes of the sharding rule)
            tableNameCacheReload(true);
        }
        return actualTableName;
    }

    /**
     * Reload the table-name cache; when flag is true, also refresh the sharding rule's actual data nodes.
     */
    public static void tableNameCacheReload(boolean flag) {
        // Read all actual table names from the database
        List<String> tableNameList = getAllTableNameBySchema();
        // Drop the old cache and rebuild it
        ShardingAlgorithmTool.tableNameCache.clear();
        ShardingAlgorithmTool.tableNameCache.addAll(tableNameList);
        if (flag) {
            reloadShardRuleActualDataNodes(shardingDataSource, logicDb);
        }
    }

    /**
     * Get the cached table names.
     * @return cached actual table names
     */
    public static HashSet<String> cacheTableNames() {
        return tableNameCache;
    }

    /**
     * Generate table names from an inline (row) expression, e.g. master.work_result_${2023..2050}_0${1..9}.
     * Note: this implementation assumes the expression contains exactly two ${start..end} ranges.
     * @param expression inline expression from actual-data-nodes
     * @return generated table names
     */
    public static List<String> generateTableNames(String expression) {
        List<String> tableNames = new ArrayList<>();
        List<String> returnTableNames = new ArrayList<>();
        List<String> extractedValues = extractValues(expression);
        if (expression.contains(".")) {
            // Strip the data source prefix, e.g. "master."
            expression = expression.substring(expression.indexOf(".") + 1);
        }
        String replacedExpression = expression.replaceAll("\\$\\{.*?}", "#");

        for (int i = 0; i < extractedValues.size(); i++) {
            String s = extractedValues.get(i).replaceAll("\\.", "#");
            String[] split = s.split("##");
            int start = Integer.parseInt(split[0]);
            int end = Integer.parseInt(split[1]);
            if (i == 0) {
                for (int j = start; j <= end; j++) {
                    tableNames.add(replacedExpression.replaceFirst("#", String.valueOf(j)));
                }
            } else {
                for (String tableName : tableNames) {
                    for (int j = start; j <= end; j++) {
                        returnTableNames.add(tableName.replaceFirst("#", String.valueOf(j)));
                    }
                }
            }
        }
        return returnTableNames;
    }

    /**
     * Extract the ${...} range values from an inline expression.
     * @param expression inline expression
     * @return extracted range strings, e.g. "2023..2050"
     */
    private static List<String> extractValues(String expression) {
        List<String> extractedValues = new ArrayList<>();
        Pattern pattern = Pattern.compile("\\$\\{([^}]*)\\}");
        Matcher matcher = pattern.matcher(expression);
        while (matcher.find()) {
            extractedValues.add(matcher.group(1));
        }
        return extractedValues;
    }

    private static void reloadShardRuleActualDataNodes(ShardingSphereDataSource dataSource, String schemaName) {
        // Context manager
        ContextManager contextManager = dataSource.getContextManager();
        // Current rule configurations
        Collection<RuleConfiguration> oldRuleConfigList = dataSource.getContextManager()
                .getMetaDataContexts()
                .getMetaData(schemaName)
                .getRuleMetaData()
                .getConfigurations();
        for (RuleConfiguration config : oldRuleConfigList) {
            Collection<ShardingTableRuleConfiguration> tables = ((AlgorithmProvidedShardingRuleConfiguration) config).getTables();
            tables.forEach(logicTable -> {
                String logicTableName = logicTable.getLogicTable();
                if (ShardingTableNames.contains(logicTableName)) {
                    setActualDataNodes(logicTable, "master.$->{com.base.shading.ShardingAlgorithmTool.cacheTableNames()}");
                }
            });
        }
        contextManager.alterRuleConfiguration(logicDb, oldRuleConfigList);
//        Collection<RuleConfiguration> newRuleConfigList = new LinkedList<>();
//        for (RuleConfiguration oldRuleConfig : oldRuleConfigList) {
//            if (oldRuleConfig instanceof AlgorithmProvidedShardingRuleConfiguration) {
//
//                // Algorithm provided sharding rule configuration
//                AlgorithmProvidedShardingRuleConfiguration oldAlgorithmConfig = (AlgorithmProvidedShardingRuleConfiguration) oldRuleConfig;
//                AlgorithmProvidedShardingRuleConfiguration newAlgorithmConfig = new AlgorithmProvidedShardingRuleConfiguration();
//
//                // Sharding table rule configuration Collection
//                Collection<ShardingTableRuleConfiguration> newTableRuleConfigList = new LinkedList<>();
//                Collection<ShardingTableRuleConfiguration> oldTableRuleConfigList = oldAlgorithmConfig.getTables();
//
//                oldTableRuleConfigList.forEach(oldTableRuleConfig -> {
//                    if (tableNameCache.contains(oldTableRuleConfig.getLogicTable())) {
//                        ShardingTableRuleConfiguration newTableRuleConfig = new ShardingTableRuleConfiguration(oldTableRuleConfig.getLogicTable(), "master.$->{com.unisic.base.shading.ShardingAlgorithmTool.cacheTableNames()}");
//                        newTableRuleConfig.setTableShardingStrategy(oldTableRuleConfig.getTableShardingStrategy());
//                        newTableRuleConfig.setDatabaseShardingStrategy(oldTableRuleConfig.getDatabaseShardingStrategy());
//                        newTableRuleConfig.setKeyGenerateStrategy(oldTableRuleConfig.getKeyGenerateStrategy());
//
//                        newTableRuleConfigList.add(newTableRuleConfig);
//                    } else {
//                        newTableRuleConfigList.add(oldTableRuleConfig);
//                    }
//                });
//
//                newAlgorithmConfig.setTables(newTableRuleConfigList);
//                newAlgorithmConfig.setAutoTables(oldAlgorithmConfig.getAutoTables());
//                newAlgorithmConfig.setBindingTableGroups(oldAlgorithmConfig.getBindingTableGroups());
//                newAlgorithmConfig.setBroadcastTables(oldAlgorithmConfig.getBroadcastTables());
//                newAlgorithmConfig.setDefaultDatabaseShardingStrategy(oldAlgorithmConfig.getDefaultDatabaseShardingStrategy());
//                newAlgorithmConfig.setDefaultTableShardingStrategy(oldAlgorithmConfig.getDefaultTableShardingStrategy());
//                newAlgorithmConfig.setDefaultKeyGenerateStrategy(oldAlgorithmConfig.getDefaultKeyGenerateStrategy());
//                newAlgorithmConfig.setDefaultShardingColumn(oldAlgorithmConfig.getDefaultShardingColumn());
//                newAlgorithmConfig.setShardingAlgorithms(oldAlgorithmConfig.getShardingAlgorithms());
//                newAlgorithmConfig.setKeyGenerators(oldAlgorithmConfig.getKeyGenerators());
//
//                newRuleConfigList.add(newAlgorithmConfig);
//            }
//        }
//
//        // update context
//        contextManager.alterRuleConfiguration(schemaName, newRuleConfigList);

    }


    private static void setActualDataNodes(ShardingTableRuleConfiguration ruleConfig, String actualDataNodes) {
        try {
            // ShardingTableRuleConfiguration exposes no setter for actualDataNodes, so write the field reflectively
            Field field = ShardingTableRuleConfiguration.class.getDeclaredField("actualDataNodes");
            field.setAccessible(true);
            field.set(ruleConfig, actualDataNodes);
        } catch (NoSuchFieldException | IllegalAccessException e) {
            log.error(e.getMessage());
        }
    }
}
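
ShardingAlgorithmTool publishes a UnisicBaseEvent carrying the ConstantInterface.CREATE_INDEX topic, but neither class is shown in the article. A minimal sketch of what they might look like, assuming a plain topic-plus-payload event (the field and accessor names are guesses, chosen to satisfy the publishEvent call above and the @EventListener condition in section 10):

// file: com/base/shading/ConstantInterface.java
package com.base.shading;

public interface ConstantInterface {
    // Topic used to trigger index creation on a newly created monthly table
    String CREATE_INDEX = "CREATE_INDEX";
}

// file: com/base/event/UnisicBaseEvent.java
package com.base.event;

// Simple event object: a payload plus a topic string that listeners filter on
public class UnisicBaseEvent<T> {

    private final T data;
    private final String topic;

    public UnisicBaseEvent(T data, String topic) {
        this.data = data;
        this.topic = topic;
    }

    public T getData() {
        return data;
    }

    public String getTopic() {
        return topic;
    }
}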

5. Cache initialization class: ShardingTablesLoadRunner.java

package com.base.shading;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.YamlPropertiesFactoryBean;
import org.springframework.boot.CommandLineRunner;
import org.springframework.core.annotation.Order;
import org.springframework.core.env.Environment;
import org.springframework.core.io.ClassPathResource;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;

import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

@Order(value = 1) // lower values run earlier
@Component
public class ShardingTablesLoadRunner implements CommandLineRunner {

    @Autowired
    private Environment environment;

    private static LocalDateTime firstTableLocalDateTime;

    private static final List<String> ShardingTableNames = Arrays.asList("work_procedure","work_manage","work_procedure_file","work_result");


    @Override
    public void run(String... args) {
        checkAndCreatTable();
    }

    // Runs at 23:00 every day; on the last day of the month it pre-creates next month's tables.
    // Requires scheduling to be enabled on a configuration class (see the sketch after this class).
    @Scheduled(cron = "0 0 23 * * ?")
    public void scheduled() {
        final Calendar c = Calendar.getInstance();
        // c.get(Calendar.DATE): today's day of month; c.getActualMaximum(Calendar.DATE): last day of this month
        if (c.get(Calendar.DATE) == c.getActualMaximum(Calendar.DATE)) {
            checkAndCreatTable();
        }
    }

    /**
     * Compare the expected actual tables with what exists in the database and create any missing ones.
     */
    public void checkAndCreatTable() {
        // Drop the old cache and re-cache the actual tables that exist in the database
        ShardingAlgorithmTool.tableNameCacheReload(false);
        // Tables that exist in the database
        HashSet<String> hasTables = ShardingAlgorithmTool.cacheTableNames();
        // All tables expected from the first sharded month up to next month
        HashSet<String> logicTables = new HashSet<>();
        // Some sharded tables already exist
        if (hasTables.size() > 0) {
            if (firstTableLocalDateTime == null) {
                List<String> tableList = hasTables.stream().map(ele -> ele.substring(ele.length() - 7)).distinct().sorted().collect(Collectors.toList());
                if (tableList.size() > 0) {
                    String[] s = tableList.get(0).split("_");
                    firstTableLocalDateTime = LocalDateTime.of(Integer.parseInt(s[0]), Integer.parseInt(s[1]), 1, 0, 0, 0);
                }
            }
            if (firstTableLocalDateTime != null) {
                LocalDateTime firstTime = firstTableLocalDateTime;
                while (firstTime.isBefore(LocalDateTime.now().plusMonths(1))) {
                    String tableSuffix = firstTime.format(DateTimeFormatter.ofPattern("yyyy_MM"));
                    ShardingTableNames.forEach(ele -> logicTables.add(ele.concat("_").concat(tableSuffix)));
                    firstTime = firstTime.plusMonths(1);
                }
            }
        }
        // No sharded tables yet: expect this month's tables
        if (hasTables.size() == 0) {
            ShardingTableNames.forEach(item -> {
                String tableSuffix = LocalDate.now().format(DateTimeFormatter.ofPattern("yyyy_MM"));
                String actualTableName = item.concat("_").concat(tableSuffix);
                logicTables.add(actualTableName);
            });
        }
        // Difference: expected tables that do not exist yet
        logicTables.removeAll(hasTables);
        if (!logicTables.isEmpty()) {
            // Create the missing tables (the "_yyyy_MM" suffix is 8 characters long)
            logicTables.forEach(ele -> ShardingAlgorithmTool.shardingTablesCheckAndCreatAndReturn(ele.substring(0, ele.length() - 8), ele));
        }
    }

    /**
     * Alternative approach: read expressions such as
     * actual-data-nodes: master.work_manage_${2023..2050}_0${1..9},master.work_manage_${2023..2050}_${10..12}
     * from the configuration file at startup and create all of the resulting tables.
     */
    private void generatable() {
        // Generate the table names from the configured expressions
        List<String> genTableNameList = generateTableNameList();
        // Cache the sharded tables that already exist
        ShardingAlgorithmTool.tableNameCacheReload(false);
        // Create any table that does not exist yet
        genTableNameList.forEach(ele -> {
            String[] s = ele.split("_");
            StringBuilder stringBuilder = new StringBuilder();
            Pattern pattern = Pattern.compile("^(\\d+)(.*)");
            // Keep only the non-numeric parts to recover the logic table name
            for (String value : s) {
                Matcher matcher = pattern.matcher(value);
                if (!matcher.matches()) {
                    stringBuilder.append(value).append("_");
                }
            }
            String logicTableName = stringBuilder.toString();
            if (logicTableName.endsWith("_")) {
                logicTableName = logicTableName.substring(0, logicTableName.length() - 1);
            }
            ShardingAlgorithmTool.shardingTablesCheckAndCreatAndReturn(logicTableName, ele);
        });
    }

    /**
     * Generate the sharded table names configured in the YAML file.
     * @return table names derived from the configured actual-data-nodes expressions
     */
    private List<String> generateTableNameList() {
        List<String> propertyList = new ArrayList<>();
        List<String> expressList = new ArrayList<>();
        List<String> genTableNameList = new ArrayList<>();
        YamlPropertiesFactoryBean factoryBean = new YamlPropertiesFactoryBean();
        // Assumes a single active profile and a matching application-<profile>.yml on the classpath
        String[] activeProfiles = environment.getActiveProfiles();
        factoryBean.setResources(new ClassPathResource("application-" + activeProfiles[0] + ".yml"));
        Properties properties = factoryBean.getObject();
        assert properties != null;
        properties.keySet().forEach(key -> {
            if (key.toString().contains("actual-data-nodes")) {
                propertyList.add((String) properties.get(key));
            }
        });
        // Collect the comma-separated expressions from the configured values
        propertyList.forEach(ele -> {
            if (ele.contains(",")) {
                expressList.addAll(Arrays.asList(ele.split(",")));
            }
        });
        // Generate the table names from each expression
        expressList.forEach(ele -> genTableNameList.addAll(ShardingAlgorithmTool.generateTableNames(ele)));
        return genTableNameList;
    }
}
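
The cron task above only fires if scheduling is enabled, and the @Async index-creation listener in section 10 likewise needs async support; no such configuration class appears in the article. If nothing in the project already declares these, a minimal sketch would be:

package com.base.config;  // hypothetical package; put it wherever your configuration classes live

import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.annotation.EnableScheduling;

// Enables @Scheduled (the 23:00 cron check) and @Async (the CREATE_INDEX event listener)
@Configuration
@EnableScheduling
@EnableAsync
public class SchedulingConfig {
}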

6. Spring utility class: ApplicationContextUtil.java

package com.base.utils;

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.env.Environment;
import org.springframework.stereotype.Component;

@Component
@Slf4j
public class ApplicationContextUtil implements ApplicationContextAware {

    private static ApplicationContext applicationContext = null;

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        if(ApplicationContextUtil.applicationContext == null) {
            ApplicationContextUtil.applicationContext = applicationContext;
        }
    }

    public static ApplicationContext getApplicationContext() {
        return ApplicationContextUtil.applicationContext;
    }

    public static <T> T getBean(Class<T> clazz) {
        return applicationContext.getBean(clazz);
    }

    public static <T> T getBean(String name, Class<T> clazz) {
        return applicationContext.getBean(name, clazz);
    }

    public static String getProperty(String key) {
        return applicationContext.getBean(Environment.class).getProperty(key);
    }
    public static Object getBean(String name){
        return applicationContext.getBean(name);
    }
}

7. Notes

(1) The custom sharding algorithm class configured via algorithmClassName is only invoked when the WHERE clause contains the sharding-column configured for the table.

(2) In hand-written SQL, avoid wrapping the table name or the sharding column in double quotes.

(3) IDEA may highlight some of these YAML properties in red; this can be ignored, the application still starts normally.

(4) SQL that does not carry the sharding key is routed by broadcast, i.e. sent to every actual table.

(5) The sharded tables referenced by the configuration must exist before querying, otherwise a "table does not exist" error is thrown.

(6) If a selected column is a database reserved word (e.g. path), parsing fails with org.apache.shardingsphere.sql.parser.exception.SQLParsingException: You have an error in your SQL syntax. Annotating the field with @TableField("\"path\"") to force double quotes solves this (see the entity sketch after this list).

(7) Either every actual table behind a logic table must exist, or use a MyBatis-Plus interceptor to force the sharding column into the SQL so that it goes through the custom sharding logic (see section 8).

(8) For multi-table queries, ORDER BY, GROUP BY, count(*) and similar operations only work when the tables that actually exist in the database exactly match the set configured in actual-data-nodes.

(9) At startup, the set produced by actual-data-nodes must not be empty.
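
For note (6), here is a minimal entity sketch showing how to quote a reserved column with MyBatis-Plus; the class and field are illustrative, not taken from the original project:

package com.base.entity;  // hypothetical package, for illustration only

import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName;

@TableName("work_procedure_file")
public class WorkProcedureFile {

    private Long id;

    // "path" trips the SQL parser as a reserved word, so force double quotes in the generated SQL
    @TableField("\"path\"")
    private String path;

    // getters and setters omitted for brevity
}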

8. Custom MyBatis-Plus interceptor: CustomInterceptor.java

package com.base.mybatis.config;

import com.base.shading.ShardingAlgorithmTool;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.sql.SQLException;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashSet;
import java.util.Properties;

import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.factory.DefaultObjectFactory;
import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

@Intercepts({
        @Signature(method = "query", type = Executor.class, args = { MappedStatement.class, Object.class, RowBounds.class,
                ResultHandler.class}),
        @Signature(method = "update",type = Executor.class,  args = {MappedStatement.class, Object.class})
})
public class CustomInterceptor implements Interceptor {

    // Cached lower bound for create_time, derived once from the earliest existing sharded table
    private static String start = "";
    private static final List<String> tableNames = Arrays.asList("work_procedure", "work_manage", "work_procedure_file", "work_result");

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // Get the SQL about to be executed
        String sql = getSqlByInvocation(invocation);
        if (StringUtils.isBlank(sql)) {
            return invocation.proceed();
        }
        boolean flag = false;
        String tableName = "";
        String temSql = sql.toLowerCase();
        // Check whether the statement touches a sharded table but carries no create_time condition
        for (String name : tableNames) {
            if (temSql.contains(name)) {
                String sqlCondition = "";
                if (temSql.contains("where")) {
                    sqlCondition = temSql.substring(temSql.indexOf("where"));
                }
                // "order by create_time" does not count as a sharding condition
                sqlCondition = sqlCondition.replace("order by create_time", "");
                if (!sqlCondition.contains("create_time")) {
                    flag = true;
                    tableName = name;
                    break;
                }
            }
        }

        if (flag) {
            String sql2Reset = "";
            final String express = tableName + "_\\d{4}_\\d{2}$";
            String end = LocalDateTime.now().withHour(23).withMinute(59).withSecond(59).format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"));
            LocalDateTime localDateTime = LocalDateTime.now().withDayOfMonth(1).withHour(0).withMinute(0).withSecond(0);
            if (start.equals("")) {
                // Use the earliest existing sharded table for this logic table as the lower bound
                HashSet<String> tableNameSet = ShardingAlgorithmTool.cacheTableNames();
                List<String> tableList = tableNameSet.stream().filter(element -> Pattern.matches(express, element)).sorted().collect(Collectors.toList());
                if (tableList.size() > 0) {
                    String replace = tableList.get(0).replace(tableName + "_", "");
                    String[] s = replace.split("_");
                    localDateTime = LocalDateTime.of(Integer.parseInt(s[0]), Integer.parseInt(s[1]), 1, 0, 0, 0);
                }
                start = localDateTime.format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"));
            }

            // Force a create_time range onto the statement so it goes through the custom sharding logic
            if (sql.contains("where")) {
                sql2Reset = sql.replace("where", "where (create_time between '" + start + "' and '" + end + "' ) and");
            } else if (sql.contains("WHERE")) {
                sql2Reset = sql.replace("WHERE", "WHERE (create_time between '" + start + "' and '" + end + "' ) and");
            }
            if (!sql2Reset.equals("")) {
                // Reset the wrapped SQL back into the invocation
                resetSql2Invocation(invocation, sql2Reset);
            }
        }

        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // Interceptor properties could be handled here if needed
    }

    /**
     * Extract the SQL from the invocation.
     * @param invocation MyBatis invocation
     * @return SQL string
     */
    private String getSqlByInvocation(Invocation invocation) {
        final Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object parameterObject = args[1];
        BoundSql boundSql = ms.getBoundSql(parameterObject);
        return boundSql.getSql();
    }

    /**
     * Wrap the SQL and reset it into the invocation.
     * @param invocation MyBatis invocation
     * @param sql rewritten SQL
     * @throws SQLException declared for the signature; not thrown here
     */
    private void resetSql2Invocation(Invocation invocation, String sql) throws SQLException {
        final Object[] args = invocation.getArgs();
        MappedStatement statement = (MappedStatement) args[0];
        Object parameterObject = args[1];
        BoundSql boundSql = statement.getBoundSql(parameterObject);
        MappedStatement newStatement = newMappedStatement(statement, new BoundSqlSqlSource(boundSql));
        MetaObject msObject =  MetaObject.forObject(newStatement, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(),new DefaultReflectorFactory());
        msObject.setValue("sqlSource.boundSql.sql", sql);
        args[0] = newStatement;
    }
    private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder =
                new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length != 0) {
            StringBuilder keyProperties = new StringBuilder();
            for (String keyProperty : ms.getKeyProperties()) {
                keyProperties.append(keyProperty).append(",");
            }
            keyProperties.delete(keyProperties.length() - 1, keyProperties.length());
            builder.keyProperty(keyProperties.toString());
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());

        return builder.build();
    }

    private String getOperateType(Invocation invocation) {
        final Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        SqlCommandType commandType = ms.getSqlCommandType();
        if (commandType == SqlCommandType.SELECT) {
            return "select";
        }
        if (commandType == SqlCommandType.INSERT) {
            return "insert";
        }
        if (commandType == SqlCommandType.UPDATE) {
            return "update";
        }
        if (commandType == SqlCommandType.DELETE) {
            return "delete";
        }
        return null;
    }
    // Inner helper class that wraps an existing BoundSql as a SqlSource
    class BoundSqlSqlSource implements SqlSource {
        private BoundSql boundSql;
        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }
        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }

}

9. MyBatis-Plus configuration: MybatisPlusConfig.java

package com.base.mybatis.config;

import cn.hutool.core.util.ArrayUtil;
import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import com.baomidou.mybatisplus.extension.plugins.handler.TenantLineHandler;
import com.baomidou.mybatisplus.extension.plugins.inner.OptimisticLockerInnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor;
import com.base.property.TenantLineProperty;
import com.base.utils.MachineUtil;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.NullValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.Arrays;
import java.util.List;

@Configuration
public class MybatisPlusConfig {

    @Autowired
    private TenantLineProperty tenantLineProperty;

    @Autowired
    private MachineUtil machineUtil;

    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        // Multi-tenant plugin
        interceptor.addInnerInterceptor(new TenantLineInnerInterceptor(new TenantLineHandler() {
            @Override
            public Expression getTenantId() {
                String code = machineUtil.getMachineCode();
                return StringUtils.isEmpty(code) ? new NullValue() : new StringValue(code);
            }

            // This is a default method; the default false means every table gets the tenant condition appended
            @Override
            public boolean ignoreTable(String tableName) {
                String[] ignoreTables = tenantLineProperty.getIgnoreTables();
                final boolean notEmpty = !ArrayUtil.isEmpty(ignoreTables);
                final boolean contains = Arrays.stream(ignoreTables).anyMatch(i -> i.equals(tableName.replace("\"", "").replace("`", "")));
                return notEmpty && contains;
            }

            @Override
            public String getTenantIdColumn() {
                return "machine_code";
            }

            @Override
            public boolean ignoreInsert(List<Column> columns, String tenantIdColumn) {
                return columns.stream().map(Column::getColumnName).anyMatch(i -> i.equalsIgnoreCase(tenantIdColumn));
            }
        }));
        // Pagination plugin. L1/L2 caching follows the MyBatis rules; set MybatisConfiguration#useDeprecatedExecutor = false
        // to avoid cache problems (that property will be removed together with the old plugin).
        interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.POSTGRE_SQL));
        // Optimistic lock plugin
        interceptor.addInnerInterceptor(new OptimisticLockerInnerInterceptor());
        return interceptor;
    }

    // Register the custom interceptor from section 8 so MyBatis picks it up
    @Bean
    public CustomInterceptor mybatisSqlInterceptor() {
        return new CustomInterceptor();
    }

}
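
MybatisPlusConfig injects TenantLineProperty and MachineUtil, which come from the author's project and are not shown. A rough sketch of both, assuming TenantLineProperty is just a @ConfigurationProperties holder for the tables that skip the tenant condition and MachineUtil returns some stable machine identifier (the prefix, property name and environment variable are guesses):

// file: com/base/property/TenantLineProperty.java
package com.base.property;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

// Hypothetical: binds e.g. "tenant-line.ignore-tables: sys_dict,sys_config" from the configuration file
@Component
@ConfigurationProperties(prefix = "tenant-line")
public class TenantLineProperty {

    // Tables that should not get the machine_code condition appended
    private String[] ignoreTables = new String[0];

    public String[] getIgnoreTables() {
        return ignoreTables;
    }

    public void setIgnoreTables(String[] ignoreTables) {
        this.ignoreTables = ignoreTables;
    }
}

// file: com/base/utils/MachineUtil.java
package com.base.utils;

import org.springframework.stereotype.Component;

// Hypothetical: any source of a stable per-machine code works as the tenant id
@Component
public class MachineUtil {

    public String getMachineCode() {
        return System.getenv().getOrDefault("MACHINE_CODE", "");
    }
}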

10. Index creation class: CreateTableIndex.java

package com.base.shading;

import com.base.event.UnisicBaseEvent;
import com.base.utils.TimeUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.event.EventListener;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;


import java.sql.*;
import java.util.*;

@Slf4j
@Component
public class CreateTableIndex {

    private static Connection connection;
    @Value("${spring.shardingsphere.datasource.master.url}")
    public String url;

    @Value("${spring.shardingsphere.datasource.master.username}")
    public String username;

    @Value("${spring.shardingsphere.datasource.master.password}")
    public String password;

    // Copies the indexes defined on the logic table onto the newly created actual table.
    // @Async requires async support to be enabled (see the configuration sketch in section 5).
    @Async
    @EventListener(classes = {UnisicBaseEvent.class}, condition = "T(com.base.shading.ConstantInterface).CREATE_INDEX.equals(#event.topic)")
    public synchronized void procedureFinish(UnisicBaseEvent<List<String>> event) {
        TimeUtil.delay(3);
        List<String> tableList = event.getData();
        String logicTableName = tableList.get(0);
        String actualTableName = tableList.get(1);
        String sqlQuery = "SELECT \n" +
                "    i.relname AS index_name,\n" +
                "    a.attname AS column_name\n" +
                "FROM \n" +
                "    pg_index idx\n" +
                "JOIN \n" +
                "    pg_class t ON t.oid = idx.indrelid\n" +
                "JOIN \n" +
                "    pg_class i ON i.oid = idx.indexrelid\n" +
                "JOIN \n" +
                "    pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(idx.indkey)\n" +
                "WHERE \n" +
                "    t.relname = '" + logicTableName +"';";
        // Map of index name -> indexed columns on the logic table
        Map<String, List<String>> indexMap = new HashMap<>();
        try {
//            DataSource dataSource = ApplicationContextUtil.getBean(DataSource.class);
//            Connection connection = dataSource.getConnection();
//            Statement statement = connection.createStatement();
            if (connection == null || connection.isClosed()) {
                connection = DriverManager.getConnection(url, username, password);
            }
            Statement statement = connection.createStatement();
            ResultSet rs = statement.executeQuery(sqlQuery);
            while (rs.next()) {
                String index_name = rs.getString(1);
                String column_name = rs.getString(2);
                if(indexMap.containsKey(index_name)){
                    List<String> list = indexMap.get(index_name);
                    list.add(column_name);
                }else{
                    ArrayList<String> list = new ArrayList<>();
                    list.add(column_name);
                    indexMap.put(index_name,list);
                }
            }
            if(indexMap.size() > 0){
                indexMap.keySet().forEach(element->{
                    List<String> list = indexMap.get(element);
                    String columnName = String.join(",", list);
                    String indexName = actualTableName + "_" + String.join("_", list);
                    String indexSql = "CREATE INDEX "+indexName +" ON "+ actualTableName +" ("+columnName+");";
                    try {
                        statement.executeUpdate(indexSql);
                    } catch (SQLException e) {
                        log.info("create-index-SQLException: " + e.getMessage());
                        throw new RuntimeException(e);
                    }
                });
            }
        } catch (SQLException e) {
            log.info("SQLException: " + e.getMessage());
            throw new RuntimeException(e);
        }
    }
}
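
CreateTableIndex calls TimeUtil.delay(3), a helper from the author's utils package that is not shown. A minimal sketch, assuming it simply sleeps for the given number of seconds so the freshly created table is visible before the index statements run:

package com.base.utils;

import java.util.concurrent.TimeUnit;

public class TimeUtil {

    // Sleep for the given number of seconds, restoring the interrupt flag if interrupted
    public static void delay(long seconds) {
        try {
            TimeUnit.SECONDS.sleep(seconds);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}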
