通过修改 ShardingSphere 的分片规则，实现将某些表的请求固定路由（指定）到某个 db 上。
ShardingSphere-JDBC（Spring Boot Starter）版本：
<!-- ShardingSphere-JDBC Spring Boot starter. The patched condition-engine classes below are modified copies of classes from this ShardingSphere release, so re-check them whenever this version changes. -->
<dependency>
<groupId>org.apache.shardingsphere</groupId>
<artifactId>shardingsphere-jdbc-core-spring-boot-starter</artifactId>
<version>5.1.2</version>
</dependency>
需要修改 insert（添加）语句和 where 条件两处对应的分片条件生成逻辑：
即修改 ShardingSphere 包 org.apache.shardingsphere.sharding.route.engine.condition.engine.impl
下的 InsertClauseShardingConditionEngine 和 WhereClauseShardingConditionEngine 类中的 createShardingConditions 方法来实现（在项目中新建同包名、同类名的类来覆盖 jar 包中的原实现），并新增 NoRuleShardingCondition 类。
完整代码
NoRuleShardingCondition
package org.apache.shardingsphere.sharding.route.engine.condition.engine.impl;
import org.apache.shardingsphere.sharding.route.engine.condition.ShardingCondition;
import org.apache.shardingsphere.sharding.route.engine.condition.value.ListShardingConditionValue;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Sharding condition that marks a table as "no rule".
 *
 * <p>It carries a single list sharding value whose column name and value are both {@code "@noRule"}.
 * A custom complex sharding algorithm can detect that column name and route the table to a fixed
 * data node instead of evaluating real sharding columns.
 *
 * @author kittlen
 * @date 2024-04-02 15:54
 */
public class NoRuleShardingCondition extends ShardingCondition {

    /** Marker used both as the pseudo sharding column name and as its only sharding value. */
    private static final String NO_RULE = "@noRule";

    /**
     * Creates a no-rule condition for the given logic table.
     *
     * @param tableName logic table name the condition applies to
     */
    public NoRuleShardingCondition(String tableName) {
        super.setStartIndex(0);
        // Each instance owns its own mutable value list. Sharding value collections may be modified
        // in place during condition processing (e.g. retainAll while merging list values), so a
        // static list shared by every instance is unsafe.
        List<String> values = Stream.of(NO_RULE).collect(Collectors.toList());
        super.getValues().add(new ListShardingConditionValue<>(NO_RULE, tableName, values));
    }

    /**
     * Builds a no-rule sharding condition for the given logic table.
     *
     * <p>A new instance is returned on every call, on purpose. {@link ShardingCondition} is mutable:
     * its start index and value list can be changed, and routing code appends generated-key values
     * to returned conditions. Caching instances in a static, unsynchronized map is therefore unsafe
     * under concurrent routing and leaks state between requests.
     *
     * @param tableName logic table name the condition applies to
     * @return a fresh no-rule condition for {@code tableName}
     */
    public static NoRuleShardingCondition buildNotRuleShardingCondition(String tableName) {
        return new NoRuleShardingCondition(tableName);
    }
}
WhereClauseShardingConditionEngine
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.shardingsphere.sharding.route.engine.condition.engine.impl;
import com.google.common.collect.Range;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.type.WhereAvailable;
import org.apache.shardingsphere.infra.database.type.DatabaseTypeEngine;
import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.schema.decorator.model.ShardingSphereSchema;
import org.apache.shardingsphere.sharding.api.config.strategy.sharding.ComplexShardingStrategyConfiguration;
import org.apache.shardingsphere.sharding.api.config.strategy.sharding.ShardingStrategyConfiguration;
import org.apache.shardingsphere.sharding.route.engine.condition.AlwaysFalseShardingCondition;
import org.apache.shardingsphere.sharding.route.engine.condition.Column;
import org.apache.shardingsphere.sharding.route.engine.condition.ShardingCondition;
import org.apache.shardingsphere.sharding.route.engine.condition.engine.ShardingConditionEngine;
import org.apache.shardingsphere.sharding.route.engine.condition.generator.ConditionValueGeneratorFactory;
import org.apache.shardingsphere.sharding.route.engine.condition.value.AlwaysFalseShardingConditionValue;
import org.apache.shardingsphere.sharding.route.engine.condition.value.ListShardingConditionValue;
import org.apache.shardingsphere.sharding.route.engine.condition.value.RangeShardingConditionValue;
import org.apache.shardingsphere.sharding.route.engine.condition.value.ShardingConditionValue;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
import org.apache.shardingsphere.sharding.rule.TableRule;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.AndPredicate;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.sql.common.util.ColumnExtractor;
import org.apache.shardingsphere.sql.parser.sql.common.util.ExpressionExtractUtil;
import org.apache.shardingsphere.sql.parser.sql.common.util.SafeNumberOperationUtil;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
/**
 * Sharding condition engine for where clause.
 *
 * <p>This is a customized copy of the ShardingSphere engine. Before the regular WHERE-clause
 * analysis, it checks whether the statement's tables use a complex database sharding strategy whose
 * sharding columns are exactly {@code "@noRule"}. If so, it returns one {@link NoRuleShardingCondition}
 * per such table, so that the configured sharding algorithm can route them to a fixed node.
 */
@RequiredArgsConstructor
public final class WhereClauseShardingConditionEngine implements ShardingConditionEngine<SQLStatementContext<?>> {

    private final ShardingRule shardingRule;

    private final ShardingSphereDatabase database;

    /**
     * Creates sharding conditions for a statement that may carry WHERE clauses.
     *
     * @param sqlStatementContext SQL statement context
     * @param parameters SQL parameters
     * @return an empty list when the statement does not support WHERE clauses; the no-rule conditions
     *         when at least one table uses the {@code "@noRule"} strategy and no table uses a complex
     *         strategy with real sharding columns; otherwise the conditions derived from the WHERE segments
     */
    @Override
    public List<ShardingCondition> createShardingConditions(final SQLStatementContext<?> sqlStatementContext, final List<Object> parameters) {
        if (!(sqlStatementContext instanceof WhereAvailable)) {
            return Collections.emptyList();
        }
        List<ShardingCondition> noRuleConditions = findByNoRule(sqlStatementContext);
        if (!CollectionUtils.isEmpty(noRuleConditions)) {
            return noRuleConditions;
        }
        Collection<ColumnSegment> columnSegments = ((WhereAvailable) sqlStatementContext).getColumnSegments();
        String defaultSchemaName = DatabaseTypeEngine.getDefaultSchemaName(sqlStatementContext.getDatabaseType(), database.getName());
        ShardingSphereSchema schema = sqlStatementContext.getTablesContext().getSchemaName()
                .map(optional -> database.getSchemas().get(optional)).orElseGet(() -> database.getSchemas().get(defaultSchemaName));
        Map<String, String> columnExpressionTableNames = sqlStatementContext.getTablesContext().findTableNamesByColumnSegment(columnSegments, schema);
        List<ShardingCondition> result = new ArrayList<>();
        for (WhereSegment each : ((WhereAvailable) sqlStatementContext).getWhereSegments()) {
            result.addAll(createShardingConditions(each.getExpr(), parameters, columnExpressionTableNames));
        }
        return result;
    }

    /**
     * Collects no-rule conditions for the tables referenced by the statement.
     *
     * <p>Tables without a table rule, and tables whose database strategy is not a complex strategy,
     * are ignored. As soon as one table uses a complex strategy with real sharding columns, an empty
     * list is returned, and ShardingSphere's regular condition handling applies.
     *
     * @param sqlStatementContext SQL statement context
     * @return one fresh {@link NoRuleShardingCondition} per {@code "@noRule"} table, or an empty list
     */
    private List<ShardingCondition> findByNoRule(final SQLStatementContext<?> sqlStatementContext) {
        List<ShardingCondition> result = new ArrayList<>();
        Collection<ShardingRule> rules = database.getRuleMetaData().findRules(ShardingRule.class);
        if (CollectionUtils.isEmpty(rules)) {
            return result;
        }
        Collection<String> tableNames = sqlStatementContext.getTablesContext().getTableNames();
        if (CollectionUtils.isEmpty(tableNames)) {
            return result;
        }
        for (ShardingRule rule : rules) {
            for (String tableName : tableNames) {
                // Use findTableRule rather than getTableRule. In ShardingSphere 5.1.x, getTableRule does
                // not return null for a table without a table rule (e.g. a single/unsharded table); it
                // throws, which would break every statement that touches such a table.
                Optional<TableRule> tableRule = rule.findTableRule(tableName);
                if (!tableRule.isPresent()) {
                    continue;
                }
                ShardingStrategyConfiguration databaseShardingStrategyConfig = tableRule.get().getDatabaseShardingStrategyConfig();
                if (!(databaseShardingStrategyConfig instanceof ComplexShardingStrategyConfiguration)) {
                    continue;
                }
                String shardingColumns = ((ComplexShardingStrategyConfiguration) databaseShardingStrategyConfig).getShardingColumns();
                if (!"@noRule".equals(shardingColumns)) {
                    // One table with real sharding columns: return an empty list and let ShardingSphere
                    // handle the rules itself.
                    return new ArrayList<>(0);
                }
                // A fresh instance per statement: ShardingCondition is mutable and must not be shared
                // across requests.
                result.add(new NoRuleShardingCondition(tableName));
            }
        }
        return result;
    }

    /**
     * Creates one sharding condition per OR branch (AND predicate group) of the expression.
     *
     * @return the conditions, or an empty list as soon as one branch yields no sharding value
     */
    private Collection<ShardingCondition> createShardingConditions(final ExpressionSegment expression, final List<Object> parameters, final Map<String, String> columnExpressionTableNames) {
        Collection<AndPredicate> andPredicates = ExpressionExtractUtil.getAndPredicates(expression);
        Collection<ShardingCondition> result = new LinkedList<>();
        for (AndPredicate each : andPredicates) {
            Map<Column, Collection<ShardingConditionValue>> shardingConditionValues = createShardingConditionValueMap(each.getPredicates(), parameters, columnExpressionTableNames);
            if (shardingConditionValues.isEmpty()) {
                return Collections.emptyList();
            }
            ShardingCondition shardingCondition = createShardingCondition(shardingConditionValues);
            // TODO remove startIndex when federation has perfect support for subquery
            shardingCondition.setStartIndex(expression.getStartIndex());
            result.add(shardingCondition);
        }
        return result;
    }

    /**
     * Groups the sharding values generated from the predicates by sharding column.
     * Columns that cannot be resolved to a table, or that are not sharding columns, are skipped.
     */
    private Map<Column, Collection<ShardingConditionValue>> createShardingConditionValueMap(final Collection<ExpressionSegment> predicates,
                                                                                           final List<Object> parameters, final Map<String, String> columnTableNames) {
        Map<Column, Collection<ShardingConditionValue>> result = new HashMap<>(predicates.size(), 1);
        for (ExpressionSegment each : predicates) {
            for (ColumnSegment columnSegment : ColumnExtractor.extract(each)) {
                Optional<String> tableName = Optional.ofNullable(columnTableNames.get(columnSegment.getExpression()));
                Optional<String> shardingColumn = tableName.flatMap(optional -> shardingRule.findShardingColumn(columnSegment.getIdentifier().getValue(), optional));
                if (!tableName.isPresent() || !shardingColumn.isPresent()) {
                    continue;
                }
                Column column = new Column(shardingColumn.get(), tableName.get());
                Optional<ShardingConditionValue> shardingConditionValue = ConditionValueGeneratorFactory.generate(each, column, parameters);
                if (!shardingConditionValue.isPresent()) {
                    continue;
                }
                result.computeIfAbsent(column, unused -> new LinkedList<>()).add(shardingConditionValue.get());
            }
        }
        return result;
    }

    /**
     * Builds a sharding condition from the per-column values.
     *
     * @return the condition, or an {@link AlwaysFalseShardingCondition} when any column's merged values are contradictory
     * @throws ShardingSphereException if values of different types were found for the same column
     */
    private ShardingCondition createShardingCondition(final Map<Column, Collection<ShardingConditionValue>> shardingConditionValues) {
        ShardingCondition result = new ShardingCondition();
        for (Entry<Column, Collection<ShardingConditionValue>> entry : shardingConditionValues.entrySet()) {
            try {
                ShardingConditionValue shardingConditionValue = mergeShardingConditionValues(entry.getKey(), entry.getValue());
                if (shardingConditionValue instanceof AlwaysFalseShardingConditionValue) {
                    return new AlwaysFalseShardingCondition();
                }
                result.getValues().add(shardingConditionValue);
            } catch (final ClassCastException ex) {
                throw new ShardingSphereException("Found different types for sharding value `%s`.", entry.getKey());
            }
        }
        return result;
    }

    /**
     * Merges all values of one column.
     *
     * <p>List values are intersected with each other, ranges are intersected with each other, and
     * then the list is filtered by the range. An empty result becomes {@link AlwaysFalseShardingConditionValue}.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private ShardingConditionValue mergeShardingConditionValues(final Column column, final Collection<ShardingConditionValue> shardingConditionValues) {
        Collection<Comparable<?>> listValue = null;
        Range<Comparable<?>> rangeValue = null;
        for (ShardingConditionValue each : shardingConditionValues) {
            if (each instanceof ListShardingConditionValue) {
                listValue = mergeListShardingValues(((ListShardingConditionValue) each).getValues(), listValue);
                if (listValue.isEmpty()) {
                    return new AlwaysFalseShardingConditionValue();
                }
            } else if (each instanceof RangeShardingConditionValue) {
                try {
                    rangeValue = mergeRangeShardingValues(((RangeShardingConditionValue) each).getValueRange(), rangeValue);
                } catch (final IllegalArgumentException ex) {
                    return new AlwaysFalseShardingConditionValue();
                }
            }
        }
        if (null == listValue) {
            return new RangeShardingConditionValue<>(column.getName(), column.getTableName(), rangeValue);
        }
        if (null == rangeValue) {
            return new ListShardingConditionValue<>(column.getName(), column.getTableName(), listValue);
        }
        listValue = mergeListAndRangeShardingValues(listValue, rangeValue);
        return listValue.isEmpty() ? new AlwaysFalseShardingConditionValue() : new ListShardingConditionValue<>(column.getName(), column.getTableName(), listValue);
    }

    /**
     * Intersects two list values. Note: {@code value1} is modified in place and returned.
     */
    private Collection<Comparable<?>> mergeListShardingValues(final Collection<Comparable<?>> value1, final Collection<Comparable<?>> value2) {
        if (null == value2) {
            return value1;
        }
        value1.retainAll(value2);
        return value1;
    }

    /**
     * Intersects two ranges; {@code value2} may be {@code null} (first range seen).
     */
    private Range<Comparable<?>> mergeRangeShardingValues(final Range<Comparable<?>> value1, final Range<Comparable<?>> value2) {
        return null == value2 ? value1 : SafeNumberOperationUtil.safeIntersection(value1, value2);
    }

    /**
     * Keeps only the list values that fall inside the range.
     */
    private Collection<Comparable<?>> mergeListAndRangeShardingValues(final Collection<Comparable<?>> listValue, final Range<Comparable<?>> rangeValue) {
        Collection<Comparable<?>> result = new LinkedList<>();
        for (Comparable<?> each : listValue) {
            if (SafeNumberOperationUtil.safeContains(rangeValue, each)) {
                result.add(each);
            }
        }
        return result;
    }
}
InsertClauseShardingConditionEngine
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.shardingsphere.sharding.route.engine.condition.engine.impl;
import com.google.common.base.Preconditions;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.binder.segment.insert.keygen.GeneratedKeyContext;
import org.apache.shardingsphere.infra.binder.segment.insert.values.InsertValueContext;
import org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.datetime.DatetimeService;
import org.apache.shardingsphere.infra.datetime.DatetimeServiceFactory;
import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.sharding.api.config.strategy.sharding.ComplexShardingStrategyConfiguration;
import org.apache.shardingsphere.sharding.api.config.strategy.sharding.ShardingStrategyConfiguration;
import org.apache.shardingsphere.sharding.route.engine.condition.ExpressionConditionUtils;
import org.apache.shardingsphere.sharding.route.engine.condition.ShardingCondition;
import org.apache.shardingsphere.sharding.route.engine.condition.engine.ShardingConditionEngine;
import org.apache.shardingsphere.sharding.route.engine.condition.value.ListShardingConditionValue;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
import org.apache.shardingsphere.sharding.rule.TableRule;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.SimpleExpressionSegment;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
 * Sharding condition engine for insert clause.
 *
 * <p>This is a customized copy of the ShardingSphere engine. For {@code INSERT ... VALUES} into a
 * table whose database strategy is a complex strategy with sharding columns {@code "@noRule"}, a
 * {@link NoRuleShardingCondition} replaces the per-row conditions.
 */
@RequiredArgsConstructor
public final class InsertClauseShardingConditionEngine implements ShardingConditionEngine<InsertStatementContext> {

    private final ShardingRule shardingRule;

    private final ShardingSphereDatabase database;

    /**
     * Creates sharding conditions for an INSERT statement, then appends generated-key conditions.
     *
     * @param sqlStatementContext insert statement context
     * @param parameters SQL parameters
     * @return sharding conditions
     */
    @Override
    public List<ShardingCondition> createShardingConditions(final InsertStatementContext sqlStatementContext, final List<Object> parameters) {
        List<ShardingCondition> result = null == sqlStatementContext.getInsertSelectContext()
                ? createShardingConditionsWithInsertValues(sqlStatementContext, parameters)
                : createShardingConditionsWithInsertSelect(sqlStatementContext, parameters);
        appendGeneratedKeyConditions(sqlStatementContext, result);
        return result;
    }

    /**
     * Creates conditions for {@code INSERT ... VALUES}.
     * Returns the no-rule conditions when the target table uses the {@code "@noRule"} strategy;
     * otherwise returns one condition per VALUES row.
     */
    private List<ShardingCondition> createShardingConditionsWithInsertValues(final InsertStatementContext sqlStatementContext, final List<Object> parameters) {
        String tableName = sqlStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue();
        List<ShardingCondition> noRuleConditions = findByNoRule(tableName);
        if (!CollectionUtils.isEmpty(noRuleConditions)) {
            return noRuleConditions;
        }
        Collection<String> columnNames = getColumnNames(sqlStatementContext);
        List<InsertValueContext> insertValueContexts = sqlStatementContext.getInsertValueContexts();
        List<ShardingCondition> result = new ArrayList<>(insertValueContexts.size());
        for (InsertValueContext each : insertValueContexts) {
            result.add(createShardingCondition(tableName, columnNames.iterator(), each, parameters));
        }
        return result;
    }

    /**
     * Collects no-rule conditions for the insert target table.
     *
     * @param tableName logic table name of the INSERT target
     * @return one fresh {@link NoRuleShardingCondition} per sharding rule that maps the table to the
     *         {@code "@noRule"} complex strategy; an empty list when there is no sharding rule, or as soon
     *         as a rule uses a complex strategy with real sharding columns for this table
     */
    private List<ShardingCondition> findByNoRule(final String tableName) {
        List<ShardingCondition> result = new ArrayList<>();
        Collection<ShardingRule> rules = database.getRuleMetaData().findRules(ShardingRule.class);
        if (CollectionUtils.isEmpty(rules)) {
            return result;
        }
        for (ShardingRule rule : rules) {
            // Use findTableRule rather than getTableRule. In ShardingSphere 5.1.x, getTableRule does not
            // return null for a table without a table rule (e.g. a single/unsharded table); it throws.
            Optional<TableRule> tableRule = rule.findTableRule(tableName);
            if (!tableRule.isPresent()) {
                continue;
            }
            ShardingStrategyConfiguration databaseShardingStrategyConfig = tableRule.get().getDatabaseShardingStrategyConfig();
            if (!(databaseShardingStrategyConfig instanceof ComplexShardingStrategyConfiguration)) {
                continue;
            }
            String shardingColumns = ((ComplexShardingStrategyConfiguration) databaseShardingStrategyConfig).getShardingColumns();
            if (!"@noRule".equals(shardingColumns)) {
                // A real sharding rule exists: return an empty list and let ShardingSphere handle the rules itself.
                return new ArrayList<>(0);
            }
            // A fresh instance per statement: appendGeneratedKeyConditions adds values to every returned
            // condition, so a shared/cached instance would accumulate state across requests.
            result.add(new NoRuleShardingCondition(tableName));
        }
        return result;
    }

    /**
     * Returns the insert column names, excluding the generated key column when its value is generated.
     */
    private Collection<String> getColumnNames(final InsertStatementContext insertStatementContext) {
        Optional<GeneratedKeyContext> generatedKey = insertStatementContext.getGeneratedKeyContext();
        if (generatedKey.isPresent() && generatedKey.get().isGenerated()) {
            Collection<String> result = new LinkedList<>(insertStatementContext.getColumnNames());
            result.remove(generatedKey.get().getColumnName());
            return result;
        }
        return insertStatementContext.getColumnNames();
    }

    /**
     * Builds the sharding condition for one VALUES row by pairing each value expression with its column.
     *
     * @throws ShardingSphereException if a sharding column is given a NULL value
     */
    private ShardingCondition createShardingCondition(final String tableName, final Iterator<String> columnNames, final InsertValueContext insertValueContext, final List<Object> parameters) {
        ShardingCondition result = new ShardingCondition();
        DatetimeService datetimeService = null;
        for (ExpressionSegment each : insertValueContext.getValueExpressions()) {
            Optional<String> shardingColumn = shardingRule.findShardingColumn(columnNames.next(), tableName);
            if (!shardingColumn.isPresent()) {
                continue;
            }
            if (each instanceof SimpleExpressionSegment) {
                result.getValues().add(new ListShardingConditionValue<>(shardingColumn.get(), tableName, Collections.singletonList(getShardingValue((SimpleExpressionSegment) each, parameters))));
            } else if (ExpressionConditionUtils.isNowExpression(each)) {
                if (null == datetimeService) {
                    datetimeService = DatetimeServiceFactory.getInstance();
                }
                result.getValues().add(new ListShardingConditionValue<>(shardingColumn.get(), tableName, Collections.singletonList(datetimeService.getDatetime())));
            } else if (ExpressionConditionUtils.isNullExpression(each)) {
                throw new ShardingSphereException("Insert clause sharding column can't be null.");
            }
        }
        return result;
    }

    /**
     * Resolves a simple expression to its value: a parameter marker is looked up in {@code parameters},
     * otherwise the literal is used.
     *
     * @throws IllegalArgumentException if the value is not {@link Comparable}
     */
    @SuppressWarnings("rawtypes")
    private Comparable<?> getShardingValue(final SimpleExpressionSegment expressionSegment, final List<Object> parameters) {
        Object result;
        if (expressionSegment instanceof ParameterMarkerExpressionSegment) {
            result = parameters.get(((ParameterMarkerExpressionSegment) expressionSegment).getParameterMarkerIndex());
        } else {
            result = ((LiteralExpressionSegment) expressionSegment).getLiterals();
        }
        Preconditions.checkArgument(result instanceof Comparable, "Sharding value must implements Comparable.");
        return (Comparable) result;
    }

    /**
     * Creates conditions for {@code INSERT ... SELECT} by delegating to the WHERE-clause engine on the SELECT part.
     */
    private List<ShardingCondition> createShardingConditionsWithInsertSelect(final InsertStatementContext sqlStatementContext, final List<Object> parameters) {
        SelectStatementContext selectStatementContext = sqlStatementContext.getInsertSelectContext().getSelectStatementContext();
        return new LinkedList<>(new WhereClauseShardingConditionEngine(shardingRule, database).createShardingConditions(selectStatementContext, parameters));
    }

    /**
     * Generates key values for the insert. When the generated key column is also a sharding column,
     * adds it to each condition.
     */
    private void appendGeneratedKeyConditions(final InsertStatementContext sqlStatementContext, final List<ShardingCondition> shardingConditions) {
        Optional<GeneratedKeyContext> generatedKey = sqlStatementContext.getGeneratedKeyContext();
        String tableName = sqlStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue();
        if (generatedKey.isPresent() && generatedKey.get().isGenerated() && shardingRule.findTableRule(tableName).isPresent()) {
            generatedKey.get().getGeneratedValues().addAll(generateKeys(tableName, sqlStatementContext.getValueListCount()));
            if (shardingRule.findShardingColumn(generatedKey.get().getColumnName(), tableName).isPresent()) {
                appendGeneratedKeyCondition(generatedKey.get(), tableName, shardingConditions);
            }
        }
    }

    /**
     * Generates one key per VALUES row using the table's key generator.
     */
    private Collection<Comparable<?>> generateKeys(final String tableName, final int valueListCount) {
        return IntStream.range(0, valueListCount).mapToObj(each -> shardingRule.generateKey(tableName)).collect(Collectors.toList());
    }

    /**
     * Adds the next generated key value to each condition. Note: this mutates the given conditions.
     */
    private void appendGeneratedKeyCondition(final GeneratedKeyContext generatedKey, final String tableName, final List<ShardingCondition> shardingConditions) {
        Iterator<Comparable<?>> generatedValuesIterator = generatedKey.getGeneratedValues().iterator();
        for (ShardingCondition each : shardingConditions) {
            each.getValues().add(new ListShardingConditionValue<>(generatedKey.getColumnName(), tableName, Collections.<Comparable<?>>singletonList(generatedValuesIterator.next())));
        }
    }
}
ShardingSphere 分片规则配置（application.yml）：
spring:
  shardingsphere:
    rules:
      # ShardingSphere 5.1.x Spring Boot starter binds table rules under rules.sharding.tables
      # and algorithms under rules.sharding.shardingAlgorithms.
      sharding:
        tables:
          tb_partition_info: # member partition table: routed only to the first database
            actualDataNodes: ${mysql.actual-data-nodes}.tb_test
            databaseStrategy:
              complex:
                # '@noRule' is not a real column; it is the marker the patched condition engines
                # look for in order to emit a NoRuleShardingCondition for this table.
                shardingColumns: '@noRule'
                shardingAlgorithmName: snowflake-database
        shardingAlgorithms:
          snowflake-database: # name this with hyphens; camelCase names do not work here
            type: my-snowflake # custom sharding algorithm (must match the algorithm's getType())
在自定义分片算法 my-snowflake 中处理带有 @noRule 标记的请求：
import lombok.extern.slf4j.Slf4j;
import org.apache.shardingsphere.sharding.api.sharding.complex.ComplexKeysShardingAlgorithm;
import org.apache.shardingsphere.sharding.api.sharding.complex.ComplexKeysShardingValue;
import org.springframework.util.CollectionUtils;
import java.util.Collection;
import java.util.List;
import java.util.Properties;
import java.util.stream.Collectors;
/**
 * Custom complex-keys sharding algorithm, registered under the type {@code "my-snowflake"}.
 *
 * <p>When the sharding values contain the pseudo column {@code "@noRule"}, the request is routed to
 * the nodes returned by {@code getDefNotRuleNodes()}. Every other request is handled by the
 * project-specific sharding logic, which is elided in this example.
 *
 * @author kittlen
 * @version 1.0
 * @date 2023/01/17 9:53
 */
@Slf4j
public class SnowflakeShardingAlgorithm implements ComplexKeysShardingAlgorithm<String> {

    /** Algorithm properties handed over by ShardingSphere through {@link #init(Properties)}. */
    private Properties properties;

    /**
     * Chooses the target names for a request.
     *
     * @param collection available target names
     * @param complexKeysShardingValue sharding values keyed by column name
     * @return the target names; requests carrying the {@code "@noRule"} marker go to the default node(s)
     */
    @Override
    public Collection<String> doSharding(Collection<String> collection, ComplexKeysShardingValue<String> complexKeysShardingValue) {
        // If the sharding values contain "@noRule", route to the default node(s) (or to another fixed route you define).
        if (complexKeysShardingValue.getColumnNameAndShardingValuesMap().containsKey("@noRule")) {
            // NOTE(review): getDefNotRuleNodes() is not defined in this snippet; the implementer must provide it.
            return getDefNotRuleNodes();
        }
        //...
        // TODO: the remaining sharding rules are elided in this example; they must end by returning the target names.
    }

    /**
     * Returns the properties received in {@link #init(Properties)}.
     *
     * @return the algorithm properties
     */
    @Override
    public Properties getProps() {
        return this.properties;
    }

    /**
     * Stores the algorithm properties supplied by ShardingSphere.
     *
     * @param properties algorithm properties
     */
    @Override
    public void init(Properties properties) {
        this.properties = properties;
    }

    /**
     * Returns the SPI type name referenced by {@code type: my-snowflake} in the configuration.
     *
     * @return {@code "my-snowflake"}
     */
    @Override
    public String getType() {
        return "my-snowflake";
    }
}