Notes on splitting MongoDB by database, by collection, and by shard in a Spring Boot project

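 Background:

  The project is a SaaS platform that keeps a large amount of unstructured data in MongoDB. As the data volume grows, storing everything in a single collection has become a bottleneck, so we looked at changing how the data is stored to improve performance.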

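 Analysis:

  Because the platform is multi-tenant, the initial plan is one database per tenant. Inside each tenant database, collections are then split into several collections or sharded as needed.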
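To make the layout concrete, here is a minimal, stdlib-only sketch of the two naming rules the code below applies: a tenant database named `<tenantId>_<database>` (falling back to the shared database when there is no tenant), and one collection per business code named `exampleObject_<busCode>`. `ShardNaming` is a hypothetical helper written for illustration, and `t42`, `saas` and `order` are made-up values.

```java
import java.util.Objects;

/**
 * Hypothetical helper mirroring the naming rules used in this post:
 * one database per tenant and one collection per business code.
 */
final class ShardNaming {
    private ShardNaming() {
    }

    // Tenant database: use the shared database when there is no tenant context
    static String databaseName(String tenantId, String database) {
        Objects.requireNonNull(database, "database");
        return (tenantId == null || tenantId.isEmpty()) ? database : tenantId + "_" + database;
    }

    // Business collection: unknown business codes fall into a catch-all collection
    static String collectionName(String busCode) {
        return (busCode == null || busCode.isEmpty()) ? "exampleObject_undefine" : "exampleObject_" + busCode;
    }

    public static void main(String[] args) {
        System.out.println(databaseName("t42", "saas")); // prints t42_saas
        System.out.println(databaseName(null, "saas"));  // prints saas
        System.out.println(collectionName("order"));     // prints exampleObject_order
    }
}
```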

  Implementation:

1. Database per tenant + collection per business type

Mongo entity definition

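import io.swagger.annotations.ApiModelProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.data.annotation.Id;
import org.springframework.data.mongodb.core.index.CompoundIndex;
import org.springframework.data.mongodb.core.index.CompoundIndexes;
import org.springframework.data.mongodb.core.mapping.Document;
import org.springframework.data.mongodb.core.mapping.Field;

@Data
@Document("ExampleObject")
@CompoundIndexes({
        @CompoundIndex(name = "idx_example_other_id", def = "{'otherId':1}"),
})
@AllArgsConstructor
@NoArgsConstructor
public class ExampleObject {
    @Id
    private String id;

    // Routing key: the proxy groups documents by this value and maps it to a business
    // collection; the routing code works with List<Long>, so the field is a Long
    @Field
    @ApiModelProperty(value = "Other id", name = "otherId")
    private Long otherId;
}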

The database-switching aspect is implemented as follows:

import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.annotation.Order;
import org.springframework.data.mongodb.core.SimpleMongoClientDbFactory;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import javax.annotation.PostConstruct;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Aspect
@Component
@Order(1)
public class MongoSwitchAspect {
    // One DbFactory per database name; the advice runs on many threads, so the map must be thread-safe
    private final Map<String, SimpleMongoClientDbFactory> templateMultiMap = new ConcurrentHashMap<>();
    // A single MongoClient (one connection pool) shared by all tenant databases
    private MongoClient mongoClient;

    // Replica-set connection string from the configuration file
    @Value("${spring.data.mongodb.uri}")
    private String uri;

    // Shared database name; tenant databases are named "<tenantId>_<database>"
    @Value("${spring.data.mongodb.database}")
    private String database;

    @PostConstruct
    public void init() {
        mongoClient = MongoClients.create(uri);
    }

    @Pointcut("execution (* org.springframework.data.mongodb.core.MongoTemplate.*(..))")
    public void routeMongoDB() {
    }

    @Around("routeMongoDB()")
    public Object routeMongoDB(ProceedingJoinPoint joinPoint) throws Throwable {
        Object target = joinPoint.getTarget();
        // Routing only works when the template bean is the custom MultiMongoTemplate subclass,
        // which holds the DbFactory per thread; any other template is left untouched
        if (!(target instanceof MultiMongoTemplate)) {
            return joinPoint.proceed();
        }

        // Resolve the tenant database (code on the "main" thread, e.g. startup, has no tenant context)
        String tenantId = null;
        if (!"main".equals(Thread.currentThread().getName())) {
            tenantId = ContextUtil.getTenant();
        }
        String dbName = StringUtils.isEmpty(tenantId) ? database : tenantId + "_" + database;

        // Create the factory for this database on first use, then reuse it
        SimpleMongoClientDbFactory factory = templateMultiMap.computeIfAbsent(
                dbName, name -> new SimpleMongoClientDbFactory(mongoClient, name));

        MultiMongoTemplate mongoTemplate = (MultiMongoTemplate) target;
        mongoTemplate.setMongoDbFactory(factory);
        try {
            // Exceptions propagate to the caller instead of being swallowed
            return joinPoint.proceed();
        } finally {
            // Always clear the ThreadLocal, even when the call throws
            mongoTemplate.removeMongoDbFactory();
        }
    }
}
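The aspect only hands the chosen factory over; the actual switch happens inside `MultiMongoTemplate`, whose source is not shown in this post. Assuming it keeps the factory in a `ThreadLocal` (which is what `setMongoDbFactory`/`removeMongoDbFactory` suggest), the mechanism reduces to the sketch below. `TenantRouter` is hypothetical and uses plain strings in place of `MongoDbFactory` so it runs without Spring; what it shows is the bind, proceed, always-unbind pattern.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

/**
 * Hypothetical stand-in for MultiMongoTemplate's ThreadLocal routing.
 * Strings play the role of MongoDbFactory instances.
 */
final class TenantRouter {
    // One cached "factory" per database name, shared by all threads
    private final Map<String, String> factories = new ConcurrentHashMap<>();
    // Factory bound to the call currently running on this thread
    private final ThreadLocal<String> bound = new ThreadLocal<>();
    private final String defaultDb;

    TenantRouter(String defaultDb) {
        this.defaultDb = defaultDb;
    }

    // What the template would use: the factory bound to this thread, else the default database
    String currentFactory() {
        String factory = bound.get();
        return factory != null ? factory : "factory:" + defaultDb;
    }

    // Mirrors the @Around advice: bind the factory, run the call, always unbind
    <T> T withDatabase(String dbName, Supplier<T> call) {
        bound.set(factories.computeIfAbsent(dbName, name -> "factory:" + name));
        try {
            return call.get();
        } finally {
            // Pooled threads serve many requests; a leftover binding would leak into another tenant's call
            bound.remove();
        }
    }

    int cachedFactories() {
        return factories.size();
    }
}
```

Outside `withDatabase` the router falls back to the shared database, which matches how the aspect treats calls without a tenant.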

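A strategy pattern combined with an aspect proxies the calls for selected collections, which makes collection splitting easier to handle: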

@Aspect
@Component
@Slf4j
@Order(2)
public class MongoAspect {

    @Autowired
    private Map<String, MongoProxyStrategy> strategyMap;

    // Entity class -> name of the strategy bean that handles it
    private Map<Class<?>, String> typeMap = new HashMap<>();

    @PostConstruct
    public void init() {
        // Spring's default bean name is the simple class name with its first letter lower-cased
        typeMap.put(ExampleObject.class, StrUtil.lowerFirst(ExampleObjectMongoProxy.class.getSimpleName()));
    }
    // Intercept calls whose parameters are (Query, Class), e.g. find, findOne, count, exists
    @Pointcut( "execution(public * org.springframework.data.mongodb.core.MongoTemplate.*(org.springframework.data.mongodb.core.query.Query,Class))")
    public void mongoFind() {

    }

    // Intercept single-document insert(Object) and save(Object)
    @Pointcut( "(execution(public * org.springframework.data.mongodb.core.MongoTemplate.insert(Object+)) || execution(public * org.springframework.data.mongodb.core.MongoTemplate.save(Object+)) )")
    public void mongoSave() {

    }

    // Intercept batch inserts: insertAll(Collection)
    @Pointcut("execution(public * org.springframework.data.mongodb.core.MongoTemplate.insertAll(java.util.Collection+))")
    public void mongoInsertAll(){}

    // Intercept update* methods whose parameters are (Query, Update, Class)
    @Pointcut("execution(public * org.springframework.data.mongodb.core.MongoTemplate.update*(..)) " +
            "&& args(org.springframework.data.mongodb.core.query.Query,org.springframework.data.mongodb.core.query.Update,Class)")
    public void mongoUpdate(){}

    // Intercept aggregate*(Aggregation, Class inputType, Class outputType)
    @Pointcut("execution(public * org.springframework.data.mongodb.core.MongoTemplate.aggregate*(..)) && args(org.springframework.data.mongodb.core.aggregation.Aggregation,Class,Class)")
    public void mongoAggregate(){}

    // Route queries
    @Around("mongoFind()")
    public Object aroundFind(ProceedingJoinPoint joinPoint) throws Throwable {
        Object[] args = joinPoint.getArgs();
        Query query=(Query) args[0];
        Class clazz=(Class) args[1];
        if(!typeMap.containsKey(clazz)){
            return joinPoint.proceed();
        }
        return strategyMap.get(typeMap.get(clazz)).aroundFind(joinPoint,query,clazz);

    }

    // Route single inserts/saves
    @Around("mongoSave()")
    public Object aroundInsert(ProceedingJoinPoint joinPoint) throws Throwable {
        Object[] args = joinPoint.getArgs();
        Object entity=args[0];

        if (ObjectUtil.isEmpty(entity)){
            return null;
        }
        if(!typeMap.containsKey(entity.getClass())){
            return joinPoint.proceed();
        }
        return strategyMap.get(typeMap.get(entity.getClass())).aroundInsert(joinPoint,entity);
    }
    // Route batch inserts
    @Around("mongoInsertAll()")
    public Object aroundInsertAll(ProceedingJoinPoint joinPoint) throws Throwable {
        Object[] args = joinPoint.getArgs();
        Object entity=args[0];

        if (ObjectUtil.isEmpty(entity)){
            return null;
        }

        Collection<Object> collection=(Collection<Object>) entity;
        if(CollUtil.isEmpty(collection)){
            return null;
        }
        Class clazz=collection.iterator().next().getClass();

        if(!typeMap.containsKey(clazz)){
            return joinPoint.proceed();
        }
        return strategyMap.get(typeMap.get(clazz)).aroundInsertAll(joinPoint,collection);
    }


    @Around("mongoUpdate()")
    public Object aroundUpdate(ProceedingJoinPoint joinPoint) throws Throwable{
        Object[] args = joinPoint.getArgs();
        Class clazz=(Class) args[2];

        if(!typeMap.containsKey(clazz)){
            return joinPoint.proceed();
        }
        return strategyMap.get(typeMap.get(clazz)).aroundUpdate(joinPoint, (Query) args[0], (Update) args[1]);

    }
    @Around("mongoAggregate()")
    public Object aroundAggregate(ProceedingJoinPoint joinPoint) throws Throwable{
        Object[] args = joinPoint.getArgs();
        Class clazz=(Class) args[1];

        if(!typeMap.containsKey(clazz)){
            return joinPoint.proceed();
        }

        return  strategyMap.get(typeMap.get(clazz)).aroundAggregate(joinPoint,(Aggregation) args[0], (Class) args[2]);

    }
}
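Stripped of AOP, the dispatch in `MongoAspect` is a two-step lookup: entity class to strategy bean name, then bean name to strategy, with a fall-through to the original call for entities that are not split. The sketch below shows only that lookup; `CollectionStrategy` and `StrategyDispatcher` are hypothetical names, and a `Supplier` stands in for `joinPoint.proceed()`.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

// Hypothetical strategy contract; the real MongoProxyStrategy methods also receive the join point
interface CollectionStrategy {
    Object aroundFind(Object query, Supplier<Object> proceed);
}

// Hypothetical dispatcher mirroring MongoAspect's typeMap + strategyMap lookup
final class StrategyDispatcher {
    // Entity class -> strategy bean name
    private final Map<Class<?>, String> typeMap = new HashMap<>();
    // Strategy bean name -> strategy (Spring fills such a map with all beans of the type)
    private final Map<String, CollectionStrategy> strategyMap = new HashMap<>();

    void register(Class<?> entityType, String beanName, CollectionStrategy strategy) {
        typeMap.put(entityType, beanName);
        strategyMap.put(beanName, strategy);
    }

    Object aroundFind(Object query, Class<?> entityType, Supplier<Object> proceed) {
        String beanName = typeMap.get(entityType);
        if (beanName == null) {
            // Entity is not split: run the original MongoTemplate call unchanged
            return proceed.get();
        }
        return strategyMap.get(beanName).aroundFind(query, proceed);
    }
}
```

Adding another split entity then only needs a new strategy bean and one more `typeMap` entry; the aspect itself does not change.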

The concrete collection-splitting implementation:
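Every routing method in the class first has to find the routing ids (the busId values stored in the entity's `otherId` field) in the query document, following plain equality, `$in`, and nested `$and` clauses; that is the job of `getBusWithDoc`. A stdlib-only sketch of the same traversal, assuming the query has already been turned into nested `Map`/`List` values (`org.bson.Document` is itself a `Map<String, Object>`); `BusIdExtractor` is hypothetical, and like the original it ignores `$or` and other operators.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;

final class BusIdExtractor {
    private BusIdExtractor() {
    }

    // Collects Long values of `field` from {field: v}, {field: {$in: [...]}} and nested $and clauses
    @SuppressWarnings("unchecked")
    static List<Long> extract(Map<String, Object> doc, String field) {
        // LinkedHashSet removes duplicates while keeping first-seen order
        LinkedHashSet<Long> ids = new LinkedHashSet<>();
        for (Map.Entry<String, Object> e : doc.entrySet()) {
            Object value = e.getValue();
            if ("$and".equals(e.getKey()) && value instanceof List) {
                for (Object clause : (List<Object>) value) {
                    if (clause instanceof Map) {
                        ids.addAll(extract((Map<String, Object>) clause, field));
                    }
                }
            } else if (field.equals(e.getKey())) {
                if (value instanceof Long) {
                    ids.add((Long) value);
                } else if (value instanceof Map && ((Map<?, ?>) value).get("$in") instanceof List) {
                    for (Object v : (List<Object>) ((Map<?, ?>) value).get("$in")) {
                        if (v instanceof Long) {
                            ids.add((Long) v);
                        }
                    }
                }
            }
        }
        return new ArrayList<>(ids);
    }
}
```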


@Component
public class ExampleObjectMongoProxy extends MongoProxyStrategy {
    private final Logger log = LoggerFactory.getLogger(ExampleObjectMongoProxy.class);
    @Autowired
    private BusService busService;
    @Autowired
    private IRedisCommonService redisService;
    // Cache of per-tenant DbFactory instances used by updateIndex(); a plain field, not an injected bean
    private final Map<String, MongoDbFactory> templateMultiMap = new ConcurrentHashMap<>();

    @Value("${spring.data.mongodb.uri}")
    private String uri;

    @Value("${spring.data.mongodb.database}")
    private String database;


    /**
     * Route queries
     */
    public Object aroundFind(ProceedingJoinPoint joinPoint, Query query, Class clazz) throws Throwable {

        List<Long> busId = getOtherId(query);
        if (CollUtil.isEmpty(busId)) {
            log.info("This ExampleObject query is not supported, Query: {} ", query.getQueryObject().toJson());
            throw new BizException("ExampleObject query cannot be routed to a split collection: missing busId condition");
        } else {
            if (isSingle(busId)) {
                // Single-collection query
                return singleFind(joinPoint.getSignature().getName(), query, busId.get(0));
            } else {
                // Query across several collections
                return mulFind(joinPoint.getSignature().getName(), query, busId);
            }
        }
    }

    /**
     * Route single inserts/saves
     */
    public Object aroundInsert(ProceedingJoinPoint joinPoint, Object o) {
        String methodName = joinPoint.getSignature().getName();
        ExampleObject exampleObject = (ExampleObject) o;
        if (ObjectUtil.isEmpty(exampleObject.getOtherId())) {
            log.error("Cannot route to a split collection: busId is empty");
            throw new BizException("ExampleObject without busId cannot be written to a split collection");
        }
        String collectionName = getExampleObjectCollectionName(exampleObject.getOtherId(), false);
        if (methodName.equals("save")) {
            return mongoTemplate.save(exampleObject, collectionName);
        } else if (methodName.equals("insert")) {
            // insert must fail on a duplicate _id instead of overwriting, so do not map it to save
            return mongoTemplate.insert(exampleObject, collectionName);
        } else {
            return null;
        }
    }

    /**
     * Route batch inserts
     */
    public Object aroundInsertAll(ProceedingJoinPoint joinPoint, Collection<?> collection) throws Throwable {
        List<ExampleObject> exampleObjectList = collection.stream().map(item -> (ExampleObject) item).collect(Collectors.toList());

        if (exampleObjectList.stream().anyMatch(exampleObject -> ObjectUtil.isEmpty(exampleObject.getOtherId()))) {
            log.error("Batch insert contains an ExampleObject without busId");
            throw new BizException("Batch insert contains an ExampleObject without busId");
        }
        Map<Long, List<ExampleObject>> busIdToExampleObjectMap = exampleObjectList.stream().collect(Collectors.groupingBy(ExampleObject::getOtherId));
        List<Object> saveObject = new ArrayList<>();
        Map<String, List<ExampleObject>> collectionNameToMap = new HashMap<>();
        busIdToExampleObjectMap.forEach((busId, value) -> {
            // Several busIds can resolve to the same collection (e.g. exampleObject_undefine): merge, do not overwrite
            collectionNameToMap.computeIfAbsent(getExampleObjectCollectionName(busId, false), k -> new ArrayList<>()).addAll(value);
        });
        // Protected MongoTemplate method, invoked reflectively with (collectionName, batch, writer)
        Method doInsertBatch = mongoTemplateMethod.get("doInsertBatch");

        for (Map.Entry<String, List<ExampleObject>> entry : collectionNameToMap.entrySet()) {
            List insertAnsList = (List) doInsertBatch.invoke(mongoTemplate, entry.getKey(), entry.getValue(), mongoConverter);
            saveObject.addAll(insertAnsList);
        }
        return saveObject;
    }

    /**
     * Route updates
     */
    public Object aroundUpdate(ProceedingJoinPoint joinPoint, Query query, Update update) throws Throwable {
        List<Long> busIds = getOtherId(query);
        if (CollUtil.isEmpty(busIds)) {
            log.error("ExampleObject update cannot be routed to a split collection, query: {}", query.getQueryObject().toJson());
            throw new BizException("ExampleObject update cannot be routed to a split collection");
        }
        String methodName = joinPoint.getSignature().getName();
        // Protected MongoTemplate method: doUpdate(collectionName, query, update, entityClass, upsert, multi)
        Method doUpdate = mongoTemplateMethod.get("doUpdate");
        boolean multi;
        if (methodName.equals("updateFirst")) {
            multi = false;
        } else if (methodName.equals("updateMulti")) {
            multi = true;
        } else {
            return null;
        }
        long matchedCount = 0L;
        long modifiedCount = 0L;
        for (Long id : busIds) {
            UpdateResult result = (UpdateResult) doUpdate.invoke(mongoTemplate, getExampleObjectCollectionName(id, false), query, update, ExampleObject.class, false, multi);
            matchedCount += result.getMatchedCount();
            modifiedCount += result.getModifiedCount();
            // updateFirst touches at most one document overall, so stop at the first collection with a match
            if (!multi && matchedCount > 0) {
                break;
            }
        }
        // upsert is always false here, so there is no upserted id to report
        return UpdateResult.acknowledged(matchedCount, modifiedCount, null);
    }

    /**
     * Route aggregations
     * @param joinPoint
     * @param aggregation aggregation pipeline
     * @param clazz output type
     * @return
     * @throws Throwable
     */
    @Override
    public Object aroundAggregate(ProceedingJoinPoint joinPoint, Aggregation aggregation, Class clazz) throws Throwable {
        List<Long> busId = getOtherId(aggregation);
        if (CollUtil.isEmpty(busId)) {
            log.error("ExampleObject aggregation cannot be routed to a split collection, aggregation: {}", aggregation.toString());
            throw new BizException("ExampleObject aggregation cannot be routed to a split collection");
        }
        // Single-collection aggregation
        if (busId.size() == 1) {
            return mongoTemplateMethod.get("aggregate").invoke(mongoTemplate, aggregation, getExampleObjectCollectionName(busId.get(0), true), clazz);
        }
        // Cross-collection aggregation
        else {
            return mulAggregate(aggregation, busId, clazz);
        }
    }

    /**
     * Build the index definitions from the entity's @CompoundIndexes annotation
     */
    @PostConstruct
    protected void initIndex() {

        indexBuss = new ArrayList<>();

        {
            CompoundIndexes annotation = ExampleObject.class.getAnnotation(CompoundIndexes.class);
            CompoundIndex[] value = annotation.value();
            for (CompoundIndex index : value) {
                IndexBus indexBus = new IndexBus(new Document(JSON.parseObject(index.def(),Map.class)),
                        new IndexOptions()
                                .name(index.name())
                                .unique(index.unique())
                                .sparse(index.sparse())
                );
                indexBuss.add(indexBus);
            }
        }
    }
    /**
     * Create any indexes the given collection is missing
     */
    private void updateIndex(String collectionName) {
        if (ObjectUtil.isEmpty(indexBuss)) {
            initIndex();

        }

        Map<String, IndexBus> indexBusMap = indexBuss.stream().collect(Collectors.toMap(i -> i.getOptions().getName(), i -> i, (a, b) -> a));
        Set<String> indexNames = indexBusMap.keySet();
        // Index operations for this collection
        IndexOperations operations = mongoTemplate.indexOps(collectionName);
        String dbName = ContextUtil.getTenant() + "_" + database;
        if (templateMultiMap.containsKey(dbName)) {
            ((MultiMongoTemplate) mongoTemplate).setMongoDbFactory(templateMultiMap.get(dbName));

        } else {
            SimpleMongoClientDbFactory simpleMongoClientDbFactory = new SimpleMongoClientDbFactory(MongoClients.create(this.uri), dbName);
            templateMultiMap.put(dbName, simpleMongoClientDbFactory);
            ((MultiMongoTemplate) mongoTemplate).setMongoDbFactory(simpleMongoClientDbFactory);
        }


        List<IndexInfo> indexInfo = operations.getIndexInfo();
        Map<String, IndexInfo> indexMap = indexInfo.stream().collect(Collectors.toMap(IndexInfo::getName, i -> i));
        List<IndexBus> createIndexList = new ArrayList<>();
        // Work out which indexes this collection is missing
        indexNames.forEach(name -> {
                    if (!indexMap.containsKey(name)) {
                        createIndexList.add(indexBusMap.get(name));
                    }
                }
        );
        if (CollUtil.isNotEmpty(createIndexList)) {
            mongoTemplate.getCollection(collectionName).createIndexes(createIndexList);
            log.info("Updated indexes of collection {}", collectionName);
        }

        log.info("Index check finished");
    }

    /**
     * Extract the busId values from the query criteria
     * @param query
     * @return
     */
    private List<Long> getOtherId(Query query) {

        Document doc = query.getQueryObject();

        List<Long> busIds = getBusWithDoc(doc);

        if (CollUtil.isEmpty(busIds)) {
            log.info("Query not supported: {}", doc.toJson());
            throw new BizException("Query not supported: missing busId condition");
        }
        return busIds.stream().distinct().collect(Collectors.toList());
    }

    /**
     *  Extract the busId values from the first $match stage of the aggregation
     * @param aggregation
     * @return
     */
    private List<Long> getOtherId(Aggregation aggregation) {
        List<Document> pipeline = aggregation.toPipeline(Aggregation.DEFAULT_CONTEXT);
        Optional<Document> optional = pipeline.stream().filter(item -> ObjectUtil.isNotEmpty(item.get("$match"))).findFirst();
        if (!optional.isPresent()) {
            throw new BizException("No $match stage found in the aggregation");
        }
        Document doc = optional.get().get("$match", Document.class);
        List<Long> busIds = getBusWithDoc(doc);
        if (CollUtil.isEmpty(busIds)) {
            log.info("Query not supported: {}", doc);
            throw new BizException("Query not supported: missing busId condition");
        }
        return busIds.stream().distinct().collect(Collectors.toList());
    }

    /**
     *  Walk the query document (including nested $and clauses) and collect the busId values
     * @param doc
     * @return
     */

    private List<Long> getBusWithDoc(Document doc) {
        List<Long> busIds = new ArrayList<>();
        doc.forEach((key, value) -> {
            if (key.equals("$and")) {
                if (value instanceof List) {
                    ((List<Document>) value).forEach(item -> busIds.addAll(getBusWithDoc(item)));
                } else if (value instanceof Document) {
                    busIds.addAll(getBusWithDoc((Document) value));
                }

            } else if (key.equals("otherId")) {
                if (value instanceof Long) {
                    busIds.add((Long) value);
                } else if (value instanceof Document) {
                    Document idDoc = (Document) value;
                    if (idDoc.containsKey("$in")) {
                        busIds.addAll(idDoc.getList("$in", Long.class));
                    }
                }
            }

        });
        return busIds;
    }

    /**
     *  Resolve the collection name for a busId
     * @param busId id
     * @param isQuery whether this is a read; reads never create the collection
     * @return
     */
    private String getExampleObjectCollectionName(Long busId, boolean isQuery) {
//        String tenantId = ContextUtil.getTenantId().toString();
//        String key = String.format("busCode_cache:%s:%s", tenantId, busId.toString());
//        Object o = redisService.get(key);
//        if (ObjectUtil.isNotEmpty(o)) {
//            updateIndex(o.toString());
//            return o.toString();
//        }
        Bus bus = busService.getByIdFromCache(busId);
        if (ObjectUtil.isEmpty(bus) || ObjectUtil.isEmpty(bus.getCode())) {
            log.info("Business model not found: {}", busId);
            return "exampleObject_undefine";
        }

        String collectionName = "exampleObject_$busCode$".replace("$busCode$", bus.getCode());

        // For writes, create the collection (with its indexes) if it does not exist yet
        if (!isQuery && !mongoTemplate.collectionExists(collectionName)) {
            createCollectionName(collectionName);
        }
        return collectionName;
    }

    /**
     *
     * @param collectionName name of the collection to create; its indexes are created with it
     */
    @Override
    protected void createCollectionName(String collectionName) {
        if (ObjectUtil.isEmpty(indexBuss)) {
            initIndex();
        }
        mongoTemplate.createCollection(collectionName).createIndexes(indexBuss);
    }

    /**
     *  Whether the query targets a single collection
     * @param ids
     * @return
     */

    private boolean isSingle(List<Long> ids) {
        return ids.size() == 1;
    }

    /**
     *  Whether this is a paged query
     * @param query query criteria
     */
    private boolean isPage(Query query) {
        // Paged if either skip or limit is set
        return query.getSkip() > 0 || query.getLimit() > 0;
    }

    /**
     *  Whether this is a paged aggregation (it has both a $skip and a $limit stage)
     * @param aggregation aggregation pipeline
     */
    private boolean isPage(Aggregation aggregation) {
        List<Document> pipeline = aggregation.toPipeline(Aggregation.DEFAULT_CONTEXT);
        Optional<Document> skip = pipeline.stream().filter(item -> item.containsKey("$skip")).findFirst();
        Optional<Document> limit = pipeline.stream().filter(item -> item.containsKey("$limit")).findFirst();
        return limit.isPresent() && skip.isPresent();

    }

    /**
     *  Query a single collection
     * @param methodName name of the intercepted method
     * @param query query criteria
     * @param id business model id
     * @return
     * @throws InvocationTargetException
     * @throws IllegalAccessException
     */
    private Object singleFind(String methodName, Query query, Long id) throws InvocationTargetException, IllegalAccessException {
        return mongoTemplateMethod.get(methodName).invoke(mongoTemplate, query, ExampleObject.class, getExampleObjectCollectionName(id, true));
    }
    /**
     *  Query across several collections
     * @param methodName name of the intercepted method
     * @param query query criteria
     * @param busId business model ids
     * @return
     */
    private Object mulFind(String methodName, Query query, List<Long> busId) {
        if (!isPage(query)) {
            if (methodName.equals("count")) {
                // count() returns long; returning an Integer would fail when the proxy unboxes the result
                long count = 0L;
                for (Long id : busId) {
                    count += mongoTemplate.count(query, ExampleObject.class, getExampleObjectCollectionName(id, true));
                }
                return count;

            } else if (methodName.equals("exists")) {
                for (Long id : busId) {
                    if (mongoTemplate.exists(query, ExampleObject.class, getExampleObjectCollectionName(id, true))) {
                        return true;
                    }
                }
                return false;
            } else if (methodName.equals("findOne")) {

                for (Long id : busId) {
                    ExampleObject ans = mongoTemplate.findOne(query, ExampleObject.class, getExampleObjectCollectionName(id, true));
                    if (ans != null) {
                        return ans;
                    }
                }
                return null;
            } else if (methodName.equals("find")) {
                List<ExampleObject> ansList = new ArrayList<>();
                for (Long id : busId) {
                    ansList.addAll(mongoTemplate.find(query, ExampleObject.class, getExampleObjectCollectionName(id, true)));
                }
                return ansList;
            }
            // Extend here for other methods
            return null;
        }
        // Paged queries only: walk the collections in busId order, consuming skip first and then limit.
        // Results are concatenated per collection; no global sort is applied across collections.
        else {
            // Work on a copy so the caller's Query keeps its original skip/limit
            Query pageQuery = Query.of(query);
            int pageSize = query.getLimit();
            List<ExampleObject> ans = new ArrayList<>();
            for (Long id : busId) {
                String collectionName = getExampleObjectCollectionName(id, true);
                Query queryNoPage = Query.of(query);
                queryNoPage.skip(0L).limit(0);
                long count = mongoTemplate.count(queryNoPage, ExampleObject.class, collectionName);
                if (count == 0) {
                    continue;
                }
                long skip = pageQuery.getSkip();
                int limit = pageQuery.getLimit();
                if (skip >= count) {
                    // This whole collection falls inside the skipped range
                    pageQuery.skip(skip - count);
                } else {
                    List<ExampleObject> exampleObjects = mongoTemplate.find(pageQuery, ExampleObject.class, collectionName);
                    ans.addAll(exampleObjects);
                    if (exampleObjects.size() < limit) {
                        pageQuery.limit(limit - exampleObjects.size());
                    }
                    // Later collections are read from their first document
                    pageQuery.skip(0);
                }
                if (pageSize > 0 && ans.size() >= pageSize) {
                    return ans;
                }
            }
            return ans;
        }
    }

    /**
     *  Aggregate across several collections
     * @param aggregation aggregation pipeline
     * @param busIds business model ids
     * @param clazz output type
     * @return
     */
    private Object mulAggregate(Aggregation aggregation, List<Long> busIds, Class clazz)  {
        List mappedResultsList = new ArrayList();
        List<AggregationResults> list = new ArrayList<>();
        if (!isPage(aggregation)) {
            for (Long busId : busIds) {
                AggregationResults results = mongoTemplate.aggregate(aggregation, getExampleObjectCollectionName(busId, true), clazz);
                list.add(results);
            }
            Document document = new Document();
            document.put("ok", 1.0);
            long total=0;
            // Page output: add up the per-collection totals
            if (Page.class == clazz){
                for (AggregationResults item : list) {
                    if(CollUtil.isNotEmpty(item.getMappedResults())){
                        total+= ((Page) item.getMappedResults().get(0)).getTotal();
                    }

                }
                Page<Object> objectPage = new Page<>();
                objectPage.setTotal(total);
                document.put("results",objectPage);
                return new AggregationResults(ListUtil.list(false,objectPage), document);
            } else if (ExampleObject.class == clazz) {
                // ExampleObject output: concatenate the documents from every collection
                list.forEach(item -> {
                    mappedResultsList.addAll(item.getMappedResults());
                    if (!item.getRawResults().getDouble("ok").equals(1.0)) {
                        document.put("ok", 0.0);
                    }
                });
                document.put("results", mappedResultsList);
                return new AggregationResults(mappedResultsList, document);
            }
            throw new BizException("This output type is not supported for cross-collection aggregation");
        } else {
            // Strip the paging stages; each collection is counted and paged separately
            List<Document> pipeline = aggregation.toPipeline(Aggregation.DEFAULT_CONTEXT);
            Optional<Document> skipDoc = pipeline.stream().filter(item -> item.containsKey("$skip")).findFirst();
            Optional<Document> limitDoc = pipeline.stream().filter(item -> item.containsKey("$limit")).findFirst();
            Long skip = 0L;
            if (skipDoc.isPresent()) {
                skip = skipDoc.get().getLong("$skip");
            }
            Long limit = 10L;
            if (limitDoc.isPresent()) {
                limit = limitDoc.get().getLong("$limit");
            }

            List<Document> notPagePipeline = pipeline.stream()
                    .filter(item -> !(item.containsKey("$skip") || item.containsKey("$limit")))
                    .collect(Collectors.toList());
            // Accumulated results
            List ansList = new LinkedList<>();
            for (Long busId : busIds) {
                // Read from this collection and compute the skip/limit to carry over to the next one
                AggregateRecord aggregateRecord = buildPipeline(notPagePipeline, skip, limit, busId, clazz, ansList);
                if (aggregateRecord.skip >= aggregateRecord.count) {
                    skip = aggregateRecord.skip;
                } else {
                    if (aggregateRecord.limit == 0L) {
                        break;
                    }
                    limit = aggregateRecord.limit;
                    skip = aggregateRecord.skip;
                }
            }
            Document document = new Document();
            document.put("ok", 1.0);
            document.put("results", ansList);
            return new AggregationResults(ansList, document);
        }

    }

    private AggregateRecord buildPipeline(List<Document> notPagePipeline, Long skip, Long limit, Long busId, Class clazz, List list) {
        String collectionName = getExampleObjectCollectionName(busId, true);
        // Reuse the original stages (without $skip/$limit) as raw pipeline stages
        List<AggregationOperation> defaultpageOption = new ArrayList<>();
        for (Document document : notPagePipeline) {
            AggregationOperation operation = aoc -> document;
            defaultpageOption.add(operation);
        }
        // First count how many documents in this collection match
        List<AggregationOperation> notPageOption = new ArrayList<>(defaultpageOption);
        notPageOption.add(Aggregation.count().as("total"));
        notPageOption.add(Aggregation.project("total"));

        AggregationResults<Page> countResult = mongoTemplate.aggregate(Aggregation.newAggregation(notPageOption), collectionName, Page.class);
        long totalCount = countResult.getMappedResults().isEmpty() ? 0L : countResult.getMappedResults().get(0).getTotal();
        // Everything in this collection is skipped: pass the remaining skip on to the next collection
        if (skip >= totalCount) {
            return new AggregateRecord(0L, skip - totalCount, limit);
        }

        List<AggregationOperation> pageOption = new ArrayList<>(defaultpageOption);
        pageOption.add(Aggregation.skip(skip));
        pageOption.add(Aggregation.limit(limit));

        // A read: must not create the collection
        AggregationResults results = mongoTemplate.aggregate(Aggregation.newAggregation(pageOption), collectionName, clazz);
        List mapped = results.getMappedResults();
        long count = mapped.size();
        // $limit already caps the result, so every returned document belongs to the page
        list.addAll(mapped);
        // Fewer than limit documents: the page continues in the next collection from its first document
        if (count < limit) {
            return new AggregateRecord(count, 0L, limit - count);
        }
        // A full page was returned: nothing left to read
        return new AggregateRecord(count, 0L, 0L);
    }

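    @AllArgsConstructor
    private static class AggregateRecord {
        // Documents returned from this collection
        public Long count;
        // Skip still to apply to the next collection
        public Long skip;
        // Documents still needed from the next collection (0 = page complete)
        public Long limit;
    }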
}
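The cross-collection paging in `mulFind` and `buildPipeline` comes down to a small piece of arithmetic: walk the collections in busId order, subtract each collection's match count from the remaining skip until the skip lands inside a collection, then read from there until the limit is used up. The sketch below computes that plan from per-collection counts; `CrossCollectionPager` is a hypothetical helper. Like the original, it concatenates collections in order, so the page is only correct when that order is the intended sort order, and it costs one count per collection before any data is read.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical helper: plans one global page (skip, limit) over several collections
 * that are read one after another, given how many documents match in each.
 */
final class CrossCollectionPager {
    private CrossCollectionPager() {
    }

    // One planned read: which collection, and the skip/limit to send to it
    static final class Slice {
        final int collection;
        final long skip;
        final long limit;

        Slice(int collection, long skip, long limit) {
            this.collection = collection;
            this.skip = skip;
            this.limit = limit;
        }

        @Override
        public String toString() {
            return collection + ":" + skip + "+" + limit;
        }
    }

    // counts[i] = matching documents in collection i; limit must be > 0
    // (a MongoDB limit of 0 means "no limit", which this sketch does not model)
    static List<Slice> plan(long[] counts, long skip, long limit) {
        List<Slice> slices = new ArrayList<>();
        long remainingSkip = skip;
        long remainingLimit = limit;
        for (int i = 0; i < counts.length && remainingLimit > 0; i++) {
            if (remainingSkip >= counts[i]) {
                // The whole collection lies inside the skipped range
                remainingSkip -= counts[i];
                continue;
            }
            long take = Math.min(counts[i] - remainingSkip, remainingLimit);
            slices.add(new Slice(i, remainingSkip, take));
            remainingLimit -= take;
            // Every later collection is read from its first document
            remainingSkip = 0;
        }
        return slices;
    }

    public static void main(String[] args) {
        System.out.println(plan(new long[]{5, 3, 10}, 6, 5)); // prints [1:1+2, 2:0+3]
    }
}
```

With counts `{5, 3, 10}`, `skip = 6` and `limit = 5`, the plan skips collection 0 entirely, reads 2 documents from collection 1 after skipping 1, and reads the first 3 documents of collection 2.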

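2. Sharding (outline)

  First build a sharded cluster. When sharding a given collection, choose the shard key carefully (understand the difference between hashed and ranged sharding) so that the data does not end up skewed onto a few shards. Create the corresponding index before sharding the collection on that key.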
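To see why the shard key matters, the sketch below compares ranged and hashed placement for keys that only ever increase (timestamps, auto-incremented ids, ObjectId-like values). It is a simplified model, not MongoDB's implementation: chunk boundaries are fixed, and `hashedShard` uses a SplitMix64-style mixer rather than MongoDB's own hash function. In a real cluster the balancer splits and migrates chunks, but with a monotonically increasing key every new insert still goes to the chunk holding the current maximum, so writes pile up on one shard. Hashing spreads the inserts, at the cost of turning range queries on that key into scatter-gather queries.

```java
import java.util.Arrays;

final class ShardKeySkewDemo {
    private ShardKeySkewDemo() {
    }

    // Ranged placement: fixed chunk upper bounds; the last chunk is unbounded
    static int rangeShard(long key, long[] upperBounds) {
        for (int i = 0; i < upperBounds.length; i++) {
            if (key < upperBounds[i]) {
                return i;
            }
        }
        return upperBounds.length;
    }

    // Hashed placement: mix the key first (SplitMix64 finalizer, not MongoDB's actual hash)
    static int hashedShard(long key, int shards) {
        long z = key + 0x9E3779B97F4A7C15L;
        z = (z ^ (z >>> 30)) * 0xBF58476D1CE4E5B9L;
        z = (z ^ (z >>> 27)) * 0x94D049BB133111EBL;
        z = z ^ (z >>> 31);
        return (int) Math.floorMod(z, (long) shards);
    }

    // Count how many keys in [from, to) land on each shard
    static int[] distribute(long from, long to, int shards, boolean hashed, long[] upperBounds) {
        int[] counts = new int[shards];
        for (long k = from; k < to; k++) {
            counts[hashed ? hashedShard(k, shards) : rangeShard(k, upperBounds)]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        // 3 chunks fixed when the first 1000 keys were loaded: [min,333), [333,666), [666,max)
        long[] bounds = {333, 666};
        // The next 1000 keys keep increasing
        System.out.println(Arrays.toString(distribute(1000, 2000, 3, false, bounds))); // prints [0, 0, 1000]
        System.out.println(Arrays.toString(distribute(1000, 2000, 3, true, bounds)));
    }
}
```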
