1、引入Maven 依赖
<!-- ShardingSphere-JDBC Spring Boot starter -->
<dependency>
    <groupId>org.apache.shardingsphere</groupId>
    <artifactId>shardingsphere-jdbc-core-spring-boot-starter</artifactId>
    <version>5.1.0</version>
</dependency>

<!-- Druid connection pool (the data source type configured in the yml below) -->
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>druid</artifactId>
    <version>1.2.8</version>
</dependency>

<!-- Tomcat DBCP -->
<dependency>
    <groupId>org.apache.tomcat</groupId>
    <artifactId>tomcat-dbcp</artifactId>
    <version>10.0.16</version>
</dependency>
2、yml 配置文件
# ShardingSphere-JDBC configuration: a single data source ("master") whose
# work_procedure / work_result tables are split into monthly tables by create_time.
spring:
  shardingsphere:
    datasource:
      names: master
      master:
        type: com.alibaba.druid.pool.DruidDataSource
        driver-class-name: org.postgresql.Driver
        url: jdbc:postgresql://127.0.0.1:5432/production_dev?serverTimezone=UTC&characterEncoding=utf-8&stringtype=unspecified
        username: postgres
        password: 123456
    rules:
      sharding:
        sharding-algorithms:
          # Custom class-based algorithm that maps create_time to <logicTable>_yyyy_MM.
          month-sharding-algorithm:
            props:
              strategy: standard
              algorithmClassName: com.base.shading.CreateTimeShardingAlgorithm
            type: CLASS_BASED
        tables:
          work_procedure:
            # Static alternative that enumerates every month up front:
            # actual-data-nodes: master.work_procedure_${2023..2050}_0${1..9},master.work_procedure_${2023..2050}_${10..12}
            # Dynamic variant: the node list is whatever the table-name cache currently holds.
            actual-data-nodes: master.$->{com.base.shading.ShardingAlgorithmTool.cacheTableNames()}
            table-strategy:
              standard:
                sharding-column: create_time
                sharding-algorithm-name: month-sharding-algorithm
          work_result:
            # Static alternative that enumerates every month up front:
            # actual-data-nodes: master.work_result_${2023..2050}_0${1..9},master.work_result_${2023..2050}_${10..12}
            actual-data-nodes: master.$->{com.base.shading.ShardingAlgorithmTool.cacheTableNames()}
            table-strategy:
              standard:
                sharding-column: create_time
                sharding-algorithm-name: month-sharding-algorithm
        bindingTables:
          - work_procedure, work_result
    props:
      # Log the logical and the actual (routed) SQL.
      sql-show: true
3、分片算法类 CreateTimeShardingAlgorithm.java
package com.base.shading;
import com.google.common.collect.Range;
import org.apache.shardingsphere.sharding.api.sharding.standard.PreciseShardingValue;
import org.apache.shardingsphere.sharding.api.sharding.standard.RangeShardingValue;
import org.apache.shardingsphere.sharding.api.sharding.standard.StandardShardingAlgorithm;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.*;
/**
 * Standard sharding algorithm that routes by month: a row whose sharding column falls in a
 * given month belongs to the physical table {@code <logicTable>_yyyy_MM}.
 *
 * <p>The generic type is {@link Date}, but sharding values are also accepted as
 * {@code "yyyy-MM-dd HH:mm:ss"} text, {@link LocalDateTime} or {@link LocalDate}; every value is
 * normalised through {@link #toLocalDateTime(Object)}.
 */
public class CreateTimeShardingAlgorithm implements StandardShardingAlgorithm<Date> {

    /** Format expected when a sharding value arrives as text. */
    private static final DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    /** Suffix of the physical monthly tables. */
    private static final DateTimeFormatter TABLE_SUFFIX_FORMATTER = DateTimeFormatter.ofPattern("yyyy_MM");

    /** Length of a {@code yyyy_MM} suffix. */
    private static final int TABLE_SUFFIX_LENGTH = 7;

    /**
     * Precise sharding (e.g. {@code create_time = ?} or an INSERT).
     *
     * @param collection           the available physical table names for the logic table; the
     *                             computed table is added to it when it is missing
     * @param preciseShardingValue logic table name, sharding column and the value taken from the SQL
     * @return the physical table name, created on demand by
     *         {@link ShardingAlgorithmTool#shardingTablesCheckAndCreatAndReturn(String, String)}
     * @throws IllegalArgumentException if the sharding value has an unsupported type
     */
    @Override
    public String doSharding(Collection<String> collection, PreciseShardingValue<Date> preciseShardingValue) {
        // Read the value untyped: at runtime it is not necessarily a Date.
        Object value = preciseShardingValue.getValue();
        String logicTableName = preciseShardingValue.getLogicTableName();
        String actualTableName = getTableNameByDate(toLocalDateTime(value), logicTableName);
        if (!collection.contains(actualTableName)) {
            collection.add(actualTableName);
        }
        return ShardingAlgorithmTool.shardingTablesCheckAndCreatAndReturn(logicTableName, actualTableName);
    }

    /**
     * Range sharding (e.g. {@code create_time between ? and ?}, {@code create_time > ?}).
     *
     * <p>The upper end is clamped to the current time. A missing upper bound is treated as
     * "now"; a missing lower bound selects every cached monthly table of the logic table up to
     * the end month.
     *
     * @param collection         the available physical table names; resynchronised with the table
     *                           cache when their sizes differ
     * @param rangeShardingValue logic table name and the value range taken from the SQL
     * @return the cached physical tables that fall inside the range
     * @throws IllegalArgumentException if an endpoint has an unsupported type
     */
    @Override
    public Collection<String> doSharding(Collection<String> collection, RangeShardingValue<Date> rangeShardingValue) {
        String logicTableName = rangeShardingValue.getLogicTableName();
        Range<Date> valueRange = rangeShardingValue.getValueRange();

        // Upper end of the range, never later than the current time.
        LocalDateTime now = LocalDateTime.now();
        LocalDateTime end = now;
        if (valueRange.hasUpperBound()) {
            Object upperEndpoint = valueRange.upperEndpoint();
            LocalDateTime upper = toLocalDateTime(upperEndpoint);
            if (!upper.isAfter(now)) {
                end = upper;
            }
        }

        // Tables that currently exist in the database (cached).
        HashSet<String> tableNameSet = ShardingAlgorithmTool.cacheTableNames();
        if (collection.size() != tableNameSet.size()) {
            collection.clear();
            collection.addAll(tableNameSet);
        }

        List<String> tables = new ArrayList<>(tableNameSet);
        if (valueRange.hasLowerBound()) {
            Object lowerEndpoint = valueRange.lowerEndpoint();
            tables.retainAll(extracted(logicTableName, toLocalDateTime(lowerEndpoint), end));
        } else {
            // Open lower end: keep every monthly table of this logic table up to the end month.
            String prefix = logicTableName.concat("_");
            String lastSuffix = end.format(TABLE_SUFFIX_FORMATTER);
            tables.removeIf(name -> !isMonthlyTableUpTo(name, prefix, lastSuffix));
        }
        return tables;
    }

    @Override
    public String getType() {
        return null;
    }

    @Override
    public Properties getProps() {
        return null;
    }

    @Override
    public void init() {
    }

    /**
     * Normalises a sharding value to a {@link LocalDateTime} in the system default zone.
     *
     * @param value the raw sharding value
     * @return the value as a local date-time
     * @throws IllegalArgumentException if the value is {@code null} or of an unsupported type
     */
    private static LocalDateTime toLocalDateTime(Object value) {
        if (value instanceof LocalDateTime) {
            return (LocalDateTime) value;
        }
        if (value instanceof LocalDate) {
            return ((LocalDate) value).atStartOfDay();
        }
        if (value instanceof Date) {
            // java.sql.Date#toInstant() throws UnsupportedOperationException, so convert via epoch millis.
            return java.time.Instant.ofEpochMilli(((Date) value).getTime())
                    .atZone(ZoneId.systemDefault())
                    .toLocalDateTime();
        }
        if (value instanceof String) {
            return LocalDateTime.parse((String) value, formatter);
        }
        throw new IllegalArgumentException("Unsupported sharding value type: "
                + (value == null ? "null" : value.getClass().getName()));
    }

    /**
     * Tells whether {@code name} is {@code prefix + yyyy_MM} with a month not later than
     * {@code lastSuffix}. {@code yyyy_MM} strings sort chronologically, so a plain string
     * comparison is enough.
     */
    private static boolean isMonthlyTableUpTo(String name, String prefix, String lastSuffix) {
        return name.startsWith(prefix)
                && name.length() == prefix.length() + TABLE_SUFFIX_LENGTH
                && name.substring(prefix.length()).compareTo(lastSuffix) <= 0;
    }

    /**
     * Computes the physical table names covered by a time range, one per month.
     *
     * @param logicTableName logic table name
     * @param lowerEndpoint  start of the range
     * @param upperEndpoint  end of the range; its month is always included
     * @return physical table names
     */
    private Set<String> extracted(String logicTableName, LocalDateTime lowerEndpoint, LocalDateTime upperEndpoint) {
        Set<String> rangeTable = new HashSet<>();
        LocalDateTime cursor = lowerEndpoint;
        while (cursor.isBefore(upperEndpoint)) {
            rangeTable.add(getTableNameByDate(cursor, logicTableName));
            cursor = cursor.plusMonths(1);
        }
        // The stepping above can stop inside the final month, so add it explicitly.
        rangeTable.add(getTableNameByDate(upperEndpoint, logicTableName));
        return rangeTable;
    }

    /**
     * Builds the physical table name for a date.
     *
     * @param dateTime       date inside the target month
     * @param logicTableName logic table name
     * @return {@code <logicTableName>_yyyy_MM}
     */
    private String getTableNameByDate(LocalDateTime dateTime, String logicTableName) {
        return logicTableName.concat("_").concat(dateTime.format(TABLE_SUFFIX_FORMATTER));
    }
}
4、 分片工具类 ShardingAlgorithmTool.java
package com.base.shading;
import com.base.event.UnisicBaseEvent;
import com.base.utils.ApplicationContextUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.shardingsphere.driver.jdbc.core.datasource.ShardingSphereDataSource;
import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.sql.*;
import java.time.LocalDate;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.shardingsphere.infra.config.RuleConfiguration;
import org.apache.shardingsphere.mode.manager.ContextManager;
import org.apache.shardingsphere.sharding.algorithm.config.AlgorithmProvidedShardingRuleConfiguration;
import org.apache.shardingsphere.sharding.api.config.ShardingRuleConfiguration;
import org.apache.shardingsphere.sharding.api.config.rule.ShardingTableRuleConfiguration;
import org.apache.shardingsphere.sharding.api.config.strategy.sharding.ShardingStrategyConfiguration;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.interceptor.TransactionAspectSupport;
import static com.unisic.base.shading.ConstantInterface.CREATE_INDEX;
@Slf4j
@Component
public class ShardingAlgorithmTool {
private static final String logicDb = "logic_db";
@Resource
private ShardingSphereDataSource shardingSphereDataSource;
@Autowired
private ApplicationContext applicationContext;
private static ApplicationContext context;
private static ShardingSphereDataSource shardingDataSource;
private static final HashSet<String> tableNameCache = new HashSet<>();
private static final List<String> ShardingTableNames = Arrays.asList("work_procedure","work_manage","work_procedure_file","work_result");
// 启动时,实际表中要有值,启动后,在ShardingTablesLoadRunner中先清空在缓存
static {
ShardingTableNames.forEach(item->{
String tableSuffix = LocalDate.now().format(DateTimeFormatter.ofPattern("yyyy_MM"));
String actualTableName = item.concat("_").concat(tableSuffix);
tableNameCache.add(actualTableName);
});
}
@PostConstruct
public void init() {
shardingDataSource = shardingSphereDataSource;
context = applicationContext;
}
/**
* 获取所有表名
* @return 表名集合
*/
public static List<String> getAllTableNameBySchema() {
List<String> tableNames = new ArrayList<>();
String sql = "SELECT tablename FROM pg_tables WHERE schemaname = 'public' AND tablename ~ '_\\d{4}_\\d{2}$';";
DataSource dataSource = ApplicationContextUtil.getBean(DataSource.class);
try (Connection connection = dataSource.getConnection()) {
Statement statement = connection.createStatement();
try (ResultSet rs = statement.executeQuery(sql)) {
while (rs.next()) {
String actualTableName = rs.getString(1);
tableNames.add(actualTableName);
}
}
} catch (SQLException e) {
log.info("SQLException: " + e.getMessage());
throw new RuntimeException(e);
}
return tableNames;
}
/**
*
* 判断 分表获取的表名是否存在 不存在则自动建表
*
* @param logicTableName 逻辑表名(表头)
* @param actualTableName 真实表名
* @return 确认存在于数据库中的真实表名
*/
public static String shardingTablesCheckAndCreatAndReturn(String logicTableName, String actualTableName) {
synchronized (logicTableName.intern()) {
// 缓存中有此表 返回
if (tableNameCache.contains(actualTableName)) {
return actualTableName;
}
// 建表sql
String createTableSql="CREATE TABLE "+ actualTableName +" (LIKE "+ logicTableName +" );";
// 主键sql
String createPrimaryKeySql="ALTER TABLE " + actualTableName + " ADD CONSTRAINT "+ actualTableName +"_pk PRIMARY KEY (id);";
// 缓存中无此表,则建表,创建索引并添加缓存
DataSource dataSource = ApplicationContextUtil.getBean(DataSource.class);
try {
Connection connection = dataSource.getConnection();
Statement statement = connection.createStatement();
try {
statement.executeUpdate(createTableSql);
} catch (SQLException e) {
log.info("create-table-SQLException: " + e.getMessage());
throw new RuntimeException(e);
}
try {
statement.executeUpdate(createPrimaryKeySql);
} catch (SQLException e) {
log.info("create-primaryKey-SQLException: " + e.getMessage());
throw new RuntimeException(e);
}
ArrayList<String> names = new ArrayList<>(2);
names.add(logicTableName);
names.add(actualTableName);
context.publishEvent(new UnisicBaseEvent<>(names,ConstantInterface.CREATE_INDEX));
} catch (SQLException e) {
log.info("SQLException: " + e.getMessage());
throw new RuntimeException(e);
}
// 缓存重载
tableNameCacheReload(true);
}
return actualTableName;
}
/**
* 缓存重载方法
*/
public static void tableNameCacheReload(boolean flag) {
// 读取数据库中所有表名
List<String> tableNameList = getAllTableNameBySchema();
// 删除旧的缓存(如果存在)
ShardingAlgorithmTool.tableNameCache.clear();
// 写入新的缓存
ShardingAlgorithmTool.tableNameCache.addAll(tableNameList);
if(flag){
reloadShardRuleActualDataNodes(shardingDataSource,logicDb);
}
}
/**
* 获取缓存中的表名
* @return
*/
public static HashSet<String> cacheTableNames() {
return tableNameCache;
}
/**
* 根据行表达式生成表名
* @param expression
* @return
*/
public static List<String> generateTableNames(String expression) {
List<String> tableNames = new ArrayList<>();
List<String> returnTableNames = new ArrayList<>();
List<String> extractedValues = extractValues(expression);
if(expression.contains(".")){
expression = expression.substring(expression.indexOf(".")+1);
}
String replacedExpression = expression.replaceAll("\\$\\{.*?}", "#");
for(int i=0; i < extractedValues.size(); i++ ){
String s = extractedValues.get(i).replaceAll("\\.", "#");
String[] split = s.split("##");
int start = Integer.parseInt(split[0]);
int end = Integer.parseInt(split[1]);
if(i == 0){
for(int j=start; j <= end; j++){
String replace = replacedExpression.replaceFirst("#", String.valueOf(j));
tableNames.add(replace);
}
}else{
for(int k=0; k < tableNames.size(); k++){
for(int j=start; j <= end; j++){
String s1 = tableNames.get(k).replaceFirst("#", String.valueOf(j));
returnTableNames.add(s1);
}
}
}
}
return returnTableNames;
}
/**
* 提取行表达式中的${}中范围值
* @param expression 行表达式
* @return
*/
private static List<String> extractValues(String expression) {
List<String> extractedValues = new ArrayList<>();
Pattern pattern = Pattern.compile("\\$\\{([^}]*)\\}");
Matcher matcher = pattern.matcher(expression);
while (matcher.find()) {
String extractedValue = matcher.group(1);
extractedValues.add(extractedValue);
}
return extractedValues;
}
private static void reloadShardRuleActualDataNodes(ShardingSphereDataSource dataSource, String schemaName) {
// Context manager.
org.apache.shardingsphere.mode.manager.ContextManager contextManager = dataSource.getContextManager();
// Rule configuration.
Collection<RuleConfiguration> oldRuleConfigList = dataSource.getContextManager()
.getMetaDataContexts()
.getMetaData(schemaName)
.getRuleMetaData()
.getConfigurations();
for (RuleConfiguration config : oldRuleConfigList) {
Collection<ShardingTableRuleConfiguration> tables = ((AlgorithmProvidedShardingRuleConfiguration) config).getTables();
tables.forEach(logicTable->{
String logicTableName = logicTable.getLogicTable();
if(ShardingTableNames.contains(logicTableName)) {
setActualDataNodes(logicTable,"master.$->{com.unisic.base.shading.ShardingAlgorithmTool.cacheTableNames()}");
}
});
}
contextManager.alterRuleConfiguration(logicDb, oldRuleConfigList);
// Collection<RuleConfiguration> newRuleConfigList = new LinkedList<>();
// for (RuleConfiguration oldRuleConfig : oldRuleConfigList) {
// if (oldRuleConfig instanceof AlgorithmProvidedShardingRuleConfiguration) {
//
// // Algorithm provided sharding rule configuration
// AlgorithmProvidedShardingRuleConfiguration oldAlgorithmConfig = (AlgorithmProvidedShardingRuleConfiguration) oldRuleConfig;
// AlgorithmProvidedShardingRuleConfiguration newAlgorithmConfig = new AlgorithmProvidedShardingRuleConfiguration();
//
// // Sharding table rule configuration Collection
// Collection<ShardingTableRuleConfiguration> newTableRuleConfigList = new LinkedList<>();
// Collection<ShardingTableRuleConfiguration> oldTableRuleConfigList = oldAlgorithmConfig.getTables();
//
// oldTableRuleConfigList.forEach(oldTableRuleConfig -> {
// if (tableNameCache.contains(oldTableRuleConfig.getLogicTable())) {
// ShardingTableRuleConfiguration newTableRuleConfig = new ShardingTableRuleConfiguration(oldTableRuleConfig.getLogicTable(), "master.$->{com.unisic.base.shading.ShardingAlgorithmTool.cacheTableNames()}");
// newTableRuleConfig.setTableShardingStrategy(oldTableRuleConfig.getTableShardingStrategy());
// newTableRuleConfig.setDatabaseShardingStrategy(oldTableRuleConfig.getDatabaseShardingStrategy());
// newTableRuleConfig.setKeyGenerateStrategy(oldTableRuleConfig.getKeyGenerateStrategy());
//
// newTableRuleConfigList.add(newTableRuleConfig);
// } else {
// newTableRuleConfigList.add(oldTableRuleConfig);
// }
// });
//
// newAlgorithmConfig.setTables(newTableRuleConfigList);
// newAlgorithmConfig.setAutoTables(oldAlgorithmConfig.getAutoTables());
// newAlgorithmConfig.setBindingTableGroups(oldAlgorithmConfig.getBindingTableGroups());
// newAlgorithmConfig.setBroadcastTables(oldAlgorithmConfig.getBroadcastTables());
// newAlgorithmConfig.setDefaultDatabaseShardingStrategy(oldAlgorithmConfig.getDefaultDatabaseShardingStrategy());
// newAlgorithmConfig.setDefaultTableShardingStrategy(oldAlgorithmConfig.getDefaultTableShardingStrategy());
// newAlgorithmConfig.setDefaultKeyGenerateStrategy(oldAlgorithmConfig.getDefaultKeyGenerateStrategy());
// newAlgorithmConfig.setDefaultShardingColumn(oldAlgorithmConfig.getDefaultShardingColumn());
// newAlgorithmConfig.setShardingAlgorithms(oldAlgorithmConfig.getShardingAlgorithms());
// newAlgorithmConfig.setKeyGenerators(oldAlgorithmConfig.getKeyGenerators());
//
// newRuleConfigList.add(newAlgorithmConfig);
// }
// }
//
// // update context
// contextManager.alterRuleConfiguration(schemaName, newRuleConfigList);
}
private static void setActualDataNodes(ShardingTableRuleConfiguration ruleConfig, String actualDataNodes){
try {
Field field = ShardingTableRuleConfiguration.class.getDeclaredField("actualDataNodes");
field.setAccessible(true);
field.set(ruleConfig, actualDataNodes);
} catch (NoSuchFieldException | IllegalAccessException e) {
log.error(e.getMessage());
}
}
}
5. 初始化缓存类 ShardingTablesLoadRunner.java
package com.base.shading;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.YamlPropertiesFactoryBean;
import org.springframework.boot.CommandLineRunner;
import org.springframework.core.annotation.Order;
import org.springframework.core.env.Environment;
import org.springframework.core.io.ClassPathResource;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
 * Loads the sharding-table cache at startup and creates any missing monthly tables; repeats
 * the check on the last day of every month so the next month's tables exist in advance.
 */
@Order(value = 1) // lower values run earlier
@Component
public class ShardingTablesLoadRunner implements CommandLineRunner {
    private static final DateTimeFormatter TABLE_SUFFIX_FORMATTER = DateTimeFormatter.ofPattern("yyyy_MM");
    /** A name segment that starts with digits, i.e. a year/month part rather than part of the logic name. */
    private static final Pattern NUMERIC_SEGMENT = Pattern.compile("^(\\d+)(.*)");

    @Autowired
    private Environment environment;
    /** First day of the month of the earliest existing sharding table; resolved once. */
    private static LocalDateTime firstTableLocalDateTime;
    private static final List<String> ShardingTableNames = Arrays.asList("work_procedure","work_manage","work_procedure_file","work_result");

    @Override
    public void run(String... args) {
        checkAndCreatTable();
    }

    /** Runs every day at 23:00 and triggers the table check on the last day of the month. */
    @Scheduled(cron = "0 0 23 * * ?")
    public void Scheduled(){
        LocalDate today = LocalDate.now();
        if (today.getDayOfMonth() == today.lengthOfMonth()) {
            checkAndCreatTable();
        }
    }

    /**
     * Compares the tables that should exist with the ones that do and creates the missing ones.
     *
     * <p>Expected tables: every month from the earliest existing sharding table up to one month
     * from now, for every logic table; or only the current month when no sharding table exists yet.
     */
    public void checkAndCreatTable(){
        // Drop the old cache and reload the tables that really exist in the database.
        ShardingAlgorithmTool.tableNameCacheReload(false);
        HashSet<String> hasTables = ShardingAlgorithmTool.cacheTableNames();
        HashSet<String> logicTables = new HashSet<>();

        if (hasTables.isEmpty()) {
            // No sharding table yet: start with the current month.
            String tableSuffix = LocalDate.now().format(TABLE_SUFFIX_FORMATTER);
            ShardingTableNames.forEach(item -> logicTables.add(item.concat("_").concat(tableSuffix)));
        } else {
            if (firstTableLocalDateTime == null) {
                // yyyy_MM suffixes sort chronologically, so the smallest one is the earliest month.
                String earliest = hasTables.stream()
                        .map(name -> name.substring(name.length() - 7))
                        .min(Comparator.naturalOrder())
                        .orElse(null);
                if (earliest != null) {
                    String[] yearMonth = earliest.split("_");
                    firstTableLocalDateTime = LocalDateTime.of(Integer.parseInt(yearMonth[0]), Integer.parseInt(yearMonth[1]), 1, 0, 0, 0);
                }
            }
            if (firstTableLocalDateTime != null) {
                LocalDateTime limit = LocalDateTime.now().plusMonths(1);
                for (LocalDateTime month = firstTableLocalDateTime; month.isBefore(limit); month = month.plusMonths(1)) {
                    String tableSuffix = month.format(TABLE_SUFFIX_FORMATTER);
                    ShardingTableNames.forEach(ele -> logicTables.add(ele.concat("_").concat(tableSuffix)));
                }
            }
        }

        // Expected minus existing = tables to create. The logic name is the physical name
        // without its 8-character "_yyyy_MM" tail.
        logicTables.removeAll(hasTables);
        logicTables.forEach(ele -> ShardingAlgorithmTool.shardingTablesCheckAndCreatAndReturn(ele.substring(0, ele.length() - 8), ele));
    }

    /**
     * Alternative start-up path: creates the tables enumerated by static row expressions in the
     * yml, e.g.
     * {@code actual-data-nodes: master.work_manage_${2023..2050}_0${1..9},master.work_manage_${2023..2050}_${10..12}}.
     */
    private void generatable(){
        List<String> genTableNameList = generateTableNameList();
        // Cache the tables that already exist.
        ShardingAlgorithmTool.tableNameCacheReload(false);
        genTableNameList.forEach(ele -> {
            // The logic table name is the physical name without its numeric (year/month) segments.
            StringBuilder logicName = new StringBuilder();
            for (String segment : ele.split("_")) {
                if (!NUMERIC_SEGMENT.matcher(segment).matches()) {
                    logicName.append(segment).append("_");
                }
            }
            String logicTableName = logicName.toString();
            if (logicTableName.endsWith("_")) {
                logicTableName = logicTableName.substring(0, logicTableName.length() - 1);
            }
            ShardingAlgorithmTool.shardingTablesCheckAndCreatAndReturn(logicTableName, ele);
        });
    }

    /**
     * Reads the {@code actual-data-nodes} expressions from the active profile's yml
     * ({@code application-<profile>.yml}, or {@code application.yml} when no profile is active)
     * and expands them into table names. Only comma-separated expression lists are expanded.
     *
     * @return table names configured in the yml
     */
    private List<String> generateTableNameList(){
        List<String> genTableNameList = new ArrayList<>();
        String[] activeProfiles = environment.getActiveProfiles();
        String resourceName = activeProfiles.length > 0
                ? "application-" + activeProfiles[0] + ".yml"
                : "application.yml";
        YamlPropertiesFactoryBean factoryBean = new YamlPropertiesFactoryBean();
        factoryBean.setResources(new ClassPathResource(resourceName));
        Properties properties = factoryBean.getObject();
        if (properties == null) {
            return genTableNameList;
        }

        List<String> expressList = new ArrayList<>();
        properties.keySet().forEach(key -> {
            if (key.toString().contains("actual-data-nodes")) {
                String value = (String) properties.get(key);
                if (value.contains(",")) {
                    expressList.addAll(Arrays.asList(value.split(",")));
                }
            }
        });
        expressList.forEach(ele -> genTableNameList.addAll(ShardingAlgorithmTool.generateTableNames(ele)));
        return genTableNameList;
    }
}
6、Spring工具类 ApplicationContextUtil.java
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.env.Environment;
import org.springframework.stereotype.Component;
/**
 * Holds the Spring {@link ApplicationContext} in a static field so that static helpers can
 * look up beans and properties.
 */
@Component
public class ApplicationContextUtil implements ApplicationContextAware {
    private static ApplicationContext applicationContext = null;

    /** Stores the context the first time Spring supplies it; later calls are ignored. */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        if (ApplicationContextUtil.applicationContext == null) {
            ApplicationContextUtil.applicationContext = applicationContext;
        }
    }

    /**
     * @return the stored context, or {@code null} if it has not been set yet
     */
    public static ApplicationContext getApplicationContext() {
        return ApplicationContextUtil.applicationContext;
    }

    /**
     * Looks up a bean by type.
     *
     * @param clazz bean type
     * @return the matching bean
     */
    public static <T> T getBean(Class<T> clazz) {
        return applicationContext.getBean(clazz);
    }

    /**
     * Looks up a bean by name and type.
     *
     * @param name  bean name
     * @param clazz bean type
     * @return the matching bean
     */
    public static <T> T getBean(String name, Class<T> clazz) {
        return applicationContext.getBean(name, clazz);
    }

    /**
     * Reads a property from the Spring {@link Environment}.
     *
     * @param key property key
     * @return the property value as returned by {@link Environment#getProperty(String)}
     */
    public static String getProperty(String key) {
        return applicationContext.getBean(Environment.class).getProperty(key);
    }

    /**
     * Looks up a bean by name.
     *
     * @param name bean name
     * @return the matching bean
     */
    public static Object getBean(String name){
        return applicationContext.getBean(name);
    }
}
7、注意
(1)SQL 中 where 后面的查询条件必须包含分表规则中配置的 sharding-column 字段,才会进入 algorithmClassName 指定的自定义分片算法类。
(2)自定义SQL中表名和sharding-column字段最好不要用双引号包括。
(3)IDEA中yml配置文件中有些属性显示红色,也不要紧,可以正常启动。
(4)对于不携带分片键的 SQL,则采取广播路由的方式。
(5)查询前,需要把配置文件中的分表创建出来,不然会报表不存在的错误。
(6)select 后面的查询字段,出现数据库保留字段,会报语法错误:
org.apache.shardingsphere.sql.parser.exception.SQLParsingException: You have an error in
your SQL syntax,如path,使用注解@TableField("\"path\"")加上双引号,可解决。
(7)逻辑分表对应的表要全部存在,或者使用Mybatis-plus拦截器拦截,强制加上分片字段,使其进入自定义分表逻辑。
(8) 多张表查询时,数据库中实际有的表与配置的 actual-data-nodes 的集合,完全一样才能使用 order by, group by ,count(*)等操作。
(9)启动时,actual-data-nodes的结果集合不能为空。
8、Mybatis-plus自定义拦截器
package com.base.mybatis.config;
import com.base.shading.ShardingAlgorithmTool;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.sql.SQLException;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashSet;
import java.util.Properties;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.factory.DefaultObjectFactory;
import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
 * MyBatis interceptor that forces statements on the sharded tables through the custom sharding
 * algorithm: when a statement touches a sharded table and its WHERE clause has no
 * {@code create_time} condition, a {@code create_time between <first shard month> and
 * <end of today>} condition is inserted after every WHERE keyword.
 *
 * <p>NOTE(review): the detection is textual, not a SQL parse; statements without a WHERE
 * clause are left untouched.
 */
@Intercepts({
        @Signature(method = "query", type = Executor.class, args = { MappedStatement.class, Object.class, RowBounds.class,
                ResultHandler.class}),
        @Signature(method = "update",type = Executor.class, args = {MappedStatement.class, Object.class})
})
public class CustomInterceptor implements Interceptor {
    private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
    /** The WHERE keyword as a whole word, in any letter case. */
    private static final Pattern WHERE_KEYWORD = Pattern.compile("\\bwhere\\b", Pattern.CASE_INSENSITIVE);
    private static final List<String> tableNames = Arrays.asList("work_procedure","work_manage","work_procedure_file","work_result");

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        String sql = getSqlByInvocation(invocation);
        if (StringUtils.isBlank(sql)) {
            return invocation.proceed();
        }
        String tableName = findTableWithoutShardingCondition(sql.toLowerCase());
        if (tableName != null) {
            String start = resolveStart(tableName);
            String end = LocalDateTime.now().withHour(23).withMinute(59).withSecond(59).format(DATE_TIME_FORMATTER);
            // "$0" keeps the keyword exactly as written; the appended text contains no '$' or '\'.
            String sql2Reset = WHERE_KEYWORD.matcher(sql)
                    .replaceAll("$0 (create_time between '" + start + "' and '" + end + "' ) and");
            if (!sql2Reset.equals(sql)) {
                resetSql2Invocation(invocation, sql2Reset);
            }
        }
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // No configurable properties.
    }

    /**
     * Finds the sharded table a statement refers to when its WHERE clause lacks a
     * {@code create_time} condition ({@code order by create_time} does not count).
     *
     * @param lowerSql the statement in lower case
     * @return the longest matching logic table name, or {@code null} when the statement has no
     *         WHERE clause, already filters on {@code create_time}, or touches no sharded table
     */
    private static String findTableWithoutShardingCondition(String lowerSql) {
        int whereIndex = lowerSql.indexOf("where");
        if (whereIndex < 0) {
            return null;
        }
        String condition = lowerSql.substring(whereIndex).replace("order by create_time", "");
        if (condition.contains("create_time")) {
            return null;
        }
        String match = null;
        for (String name : tableNames) {
            // Prefer the longest name so "work_procedure" does not shadow "work_procedure_file".
            if (lowerSql.contains(name) && (match == null || name.length() > match.length())) {
                match = name;
            }
        }
        return match;
    }

    /**
     * Determines the lower end of the inserted time window for a logic table.
     *
     * @param tableName logic table name
     * @return first day of the month of the earliest cached shard of that table, or the first
     *         day of the current month when no shard is cached, as {@code yyyy-MM-dd HH:mm:ss}
     */
    private static String resolveStart(String tableName) {
        final String express = tableName + "_\\d{4}_\\d{2}$";
        LocalDateTime firstOfCurrentMonth = LocalDateTime.now().withDayOfMonth(1).withHour(0).withMinute(0).withSecond(0);
        HashSet<String> tableNameSet = ShardingAlgorithmTool.cacheTableNames();
        return tableNameSet.stream()
                .filter(element -> Pattern.matches(express, element))
                .sorted()
                .findFirst()
                .map(first -> {
                    String[] yearMonth = first.substring(tableName.length() + 1).split("_");
                    return LocalDateTime.of(Integer.parseInt(yearMonth[0]), Integer.parseInt(yearMonth[1]), 1, 0, 0, 0);
                })
                .orElse(firstOfCurrentMonth)
                .format(DATE_TIME_FORMATTER);
    }

    /**
     * Reads the SQL of the intercepted statement.
     *
     * @param invocation intercepted Executor call
     * @return the bound SQL text
     */
    private String getSqlByInvocation(Invocation invocation) {
        final Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object parameterObject = args[1];
        return ms.getBoundSql(parameterObject).getSql();
    }

    /**
     * Replaces the statement of the invocation with a copy that carries the rewritten SQL.
     *
     * @param invocation intercepted Executor call
     * @param sql        rewritten SQL
     */
    private void resetSql2Invocation(Invocation invocation, String sql) throws SQLException {
        final Object[] args = invocation.getArgs();
        MappedStatement statement = (MappedStatement) args[0];
        Object parameterObject = args[1];
        BoundSql boundSql = statement.getBoundSql(parameterObject);
        MappedStatement newStatement = newMappedStatement(statement, new BoundSqlSqlSource(boundSql));
        MetaObject msObject = MetaObject.forObject(newStatement, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(), new DefaultReflectorFactory());
        // Overwrite the SQL text inside the wrapped BoundSql via MyBatis' reflection helper.
        msObject.setValue("sqlSource.boundSql.sql", sql);
        args[0] = newStatement;
    }

    /** Copies a mapped statement, swapping in a different SQL source. */
    private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder =
                new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length != 0) {
            builder.keyProperty(String.join(",", ms.getKeyProperties()));
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    /**
     * Names the command type of the intercepted statement.
     *
     * @return {@code "select"}, {@code "insert"}, {@code "update"}, {@code "delete"}, or
     *         {@code null} for any other type
     */
    private String getOperateType(Invocation invocation) {
        final Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        switch (ms.getSqlCommandType()) {
            case SELECT:
                return "select";
            case INSERT:
                return "insert";
            case UPDATE:
                return "update";
            case DELETE:
                return "delete";
            default:
                return null;
        }
    }

    /** SqlSource that always returns the BoundSql it wraps, so the SQL text can be replaced. */
    static class BoundSqlSqlSource implements SqlSource {
        private final BoundSql boundSql;

        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }
}
9、mybatis-plus配置
package com.base.mybatis.config;
import cn.hutool.core.util.ArrayUtil;
import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import com.baomidou.mybatisplus.extension.plugins.handler.TenantLineHandler;
import com.baomidou.mybatisplus.extension.plugins.inner.OptimisticLockerInnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor;
import com.base.property.TenantLineProperty;
import com.base.utils.MachineUtil;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.NullValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.Arrays;
import java.util.List;
/**
 * MyBatis-Plus configuration: registers the tenant-line, pagination and optimistic-lock inner
 * interceptors plus the custom sharding-condition interceptor.
 */
@Configuration
public class MybatisPlusConfig {
    @Autowired
    private TenantLineProperty tenantLineProperty;
    @Autowired
    private MachineUtil machineUtil;

    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        // Multi-tenant plugin: the tenant column is machine_code, the tenant id is the machine code.
        interceptor.addInnerInterceptor(new TenantLineInnerInterceptor(new TenantLineHandler() {
            @Override
            public Expression getTenantId() {
                String code = machineUtil.getMachineCode();
                return StringUtils.isEmpty(code) ? new NullValue() : new StringValue(code);
            }

            /**
             * Skips the tenant condition for tables listed in the ignore-tables property.
             * Quotes and backticks around the table name are stripped before comparing.
             */
            @Override
            public boolean ignoreTable(String tableName) {
                String[] ignoreTables = tenantLineProperty.getIgnoreTables();
                // Check for null/empty first: streaming a null array would throw.
                if (ArrayUtil.isEmpty(ignoreTables)) {
                    return false;
                }
                String normalized = tableName.replace("\"", "").replace("`", "");
                return Arrays.asList(ignoreTables).contains(normalized);
            }

            @Override
            public String getTenantIdColumn() {
                return "machine_code";
            }

            /** Skips adding the tenant column on INSERT when the statement already sets it. */
            @Override
            public boolean ignoreInsert(List<Column> columns, String tenantIdColumn) {
                return columns.stream().map(Column::getColumnName).anyMatch(i -> i.equalsIgnoreCase(tenantIdColumn));
            }
        }));
        // Pagination plugin. First- and second-level caches follow MyBatis' rules; set
        // MybatisConfiguration#useDeprecatedExecutor = false to avoid cache problems.
        interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.POSTGRE_SQL));
        // Optimistic-lock plugin.
        interceptor.addInnerInterceptor(new OptimisticLockerInnerInterceptor());
        return interceptor;
    }

    @Bean
    public CustomInterceptor mybatisSqlInterceptor(){
        return new CustomInterceptor();
    }
}
10、创建索引
package com.base.shading;
import com.base.event.UnisicBaseEvent;
import com.base.utils.TimeUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.event.EventListener;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import java.beans.Transient;
import java.sql.*;
import java.util.*;
@Slf4j
@Component
public class CreateTableIndex {
private static Connection connection;
@Value("${spring.shardingsphere.datasource.master.url}")
public String url;
@Value("${spring.shardingsphere.datasource.master.username}")
public String username;
@Value("${spring.shardingsphere.datasource.master.password}")
public String password;
@Async
@Transient
@EventListener(classes = {UnisicBaseEvent.class}, condition = "T(com.unisic.base.shading.ConstantInterface).CREATE_INDEX.equals(#event.topic)")
public synchronized void procedureFinish(UnisicBaseEvent<List<String>> event) {
TimeUtil.delay(3);
List<String> tableList = event.getData();
String logicTableName = tableList.get(0);
String actualTableName = tableList.get(1);
String sqlQuery = "SELECT \n" +
" i.relname AS index_name,\n" +
" a.attname AS column_name\n" +
"FROM \n" +
" pg_index idx\n" +
"JOIN \n" +
" pg_class t ON t.oid = idx.indrelid\n" +
"JOIN \n" +
" pg_class i ON i.oid = idx.indexrelid\n" +
"JOIN \n" +
" pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(idx.indkey)\n" +
"WHERE \n" +
" t.relname = '" + logicTableName +"';";
// 用来存储索引和字段
Map<String,List<String>> indexMap = new HashMap<>();
try {
// DataSource dataSource = ApplicationContextUtil.getBean(DataSource.class);
// Connection connection = dataSource.getConnection();
// Statement statement = connection.createStatement();
if (connection == null || connection.isClosed()) {
connection = DriverManager.getConnection(url, username, password);
}
Statement statement = connection.createStatement();
ResultSet rs = statement.executeQuery(sqlQuery);
while (rs.next()) {
String index_name = rs.getString(1);
String column_name = rs.getString(2);
if(indexMap.containsKey(index_name)){
List<String> list = indexMap.get(index_name);
list.add(column_name);
}else{
ArrayList<String> list = new ArrayList<>();
list.add(column_name);
indexMap.put(index_name,list);
}
}
if(indexMap.size() > 0){
indexMap.keySet().forEach(element->{
List<String> list = indexMap.get(element);
String columnName = String.join(",", list);
String indexName = actualTableName + "_" + String.join("_", list);
String indexSql = "CREATE INDEX "+indexName +" ON "+ actualTableName +" ("+columnName+");";
try {
statement.executeUpdate(indexSql);
} catch (SQLException e) {
log.info("create-index-SQLException: " + e.getMessage());
throw new RuntimeException(e);
}
});
}
} catch (SQLException e) {
log.info("SQLException: " + e.getMessage());
throw new RuntimeException(e);
}
}
}