Springboot生成树工具类,可通过 id/code 编码生成 2.0版本

  • 优化工具类中,查询父级时便利多次的问题

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.mutable.MutableLong;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

import java.util.*;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * 树结构工具类
 *
 * @author
 * @date 2024/7/16 下午4:58
 * @description 提供树结构的构建、查询、转换等功能
 */
@SuppressWarnings("unused")
public class TreeUtil {

    /**
     * 使用 ParentId 构建树结构,适用于大数据量,避免使用递归,提高性能。
     *
     * @param list        所有节点的列表
     * @param getId       获取节点 ID 的函数
     * @param getParentId 获取父节点 ID 的函数
     * @param comparator  同级节点排序的比较器(可选)
     * @param setSub      设置子节点列表的函数
     * @param <T>         节点类型
     * @param <I>         节点 ID 类型
     * @return 树的根节点列表
     */
    public static <T, I> List<T> buildByParentId(
            @NonNull List<T> list,
            @NonNull Function<T, I> getId,
            @NonNull Function<T, I> getParentId,
            @Nullable Comparator<T> comparator,
            @NonNull BiConsumer<T, List<T>> setSub) {

        // 1. 构建 ID 到节点的映射,方便快速查找节点
        Map<I, T> idNodeMap = list.stream()
                .collect(Collectors.toMap(getId, Function.identity(), (existing, replacement) -> existing));

        // 2. 构建父 ID 到子节点列表的映射
        Map<I, List<T>> parentIdMap = new HashMap<>();
        for (T node : list) {
            I parentId = getParentId.apply(node);
            parentIdMap.computeIfAbsent(parentId, k -> new ArrayList<>()).add(node);
        }

        // 3. 设置每个节点的子节点列表
        for (T node : list) {
            I id = getId.apply(node);
            List<T> children = parentIdMap.get(id);
            if (children != null) {
                // 对子节点进行排序(如果需要)
                sortList(children, comparator);
                // 设置子节点列表
                setSub.accept(node, children);
            }
        }

        // 4. 提取根节点(父 ID 为 null 或者父 ID 不存在于节点映射中的节点)
        List<T> roots = list.stream()
                .filter(node -> {
                    I parentId = getParentId.apply(node);
                    return parentId == null || !idNodeMap.containsKey(parentId);
                })
                .collect(Collectors.toList());

        // 对根节点进行排序(如果需要)
        sortList(roots, comparator);

        return roots;
    }

    /**
     * 对列表进行排序
     *
     * @param list       要排序的列表
     * @param comparator 比较器(可选)
     * @param <T>        列表元素类型
     */
    private static <T> void sortList(List<T> list, Comparator<T> comparator) {
        if (comparator != null && list != null && !list.isEmpty()) {
            list.sort(comparator);
        }
    }

    /**
     * 编码形式的树构建,当节点的编码不以任何其他节点编码为前缀时,该节点为根节点。
     * 所有节点的子节点列表必须不为 null。
     *
     * @param list       所有节点的列表
     * @param getCode    获取节点编码的函数
     * @param comparator 同级节点排序的比较器(可选)
     * @param getSub     获取子节点列表的函数
     * @param setSub     设置子节点列表的函数
     * @param <T>        节点类型
     * @param <C>        节点编码类型(必须是 String 或其子类)
     * @return 树的根节点列表
     */
    public static <T, C extends String> List<T> buildByCode(
            @NonNull List<T> list,
            @NonNull Function<T, C> getCode,
            @Nullable Comparator<T> comparator,
            @NonNull Function<T, List<T>> getSub,
            @NonNull BiConsumer<T, List<T>> setSub) {

        // 按照编码排序,将节点分组
        List<T> sortedCodeList = list.stream()
                .sorted(Comparator.comparing(getCode))
                .collect(Collectors.toList());

        Map<C, List<T>> codeGroupMap = new HashMap<>();
        C flagCode = null;
        for (T item : sortedCodeList) {
            C currentCode = getCode.apply(item);
            if (flagCode == null || !currentCode.startsWith(flagCode)) {
                flagCode = currentCode;
            }
            codeGroupMap.computeIfAbsent(flagCode, k -> new ArrayList<>()).add(item);
        }

        // 构建树
        List<T> tree = new ArrayList<>();
        codeGroupMap.forEach((k, v) -> tree.add(buildNodeByCode(v, getCode, getSub, setSub)));
        sortTree(tree, comparator, getSub);

        return tree;
    }

    /**
     * 构建节点(编码形式),用于辅助 buildByCode 方法
     *
     * @param subList 子节点列表
     * @param getCode 获取编码的函数
     * @param getSub  获取子节点列表的函数
     * @param setSub  设置子节点列表的函数
     * @param <T>     节点类型
     * @param <C>     编码类型
     * @return 构建好的节点
     */
    private static <T, C extends String> T buildNodeByCode(
            List<T> subList,
            Function<T, C> getCode,
            Function<T, List<T>> getSub,
            BiConsumer<T, List<T>> setSub) {

        if (subList.isEmpty()) {
            throw new IllegalStateException("树构建异常:子节点列表为空");
        }

        // 反转列表,方便子节点找父节点
        Collections.reverse(subList);
        for (int i = 0; i < subList.size() - 1; i++) {
            T child = subList.get(i);
            T parent = findParentByCode(child, subList.subList(i + 1, subList.size()), getCode);
            List<T> children = getSub.apply(parent);
            if (children == null) {
                children = new ArrayList<>();
                setSub.accept(parent, children);
            }
            children.add(child);
        }

        return subList.get(subList.size() - 1);
    }

    /**
     * 根据编码查找父节点
     *
     * @param currentNode 当前节点
     * @param subList     子节点列表
     * @param getCode     获取编码的函数
     * @param <T>         节点类型
     * @param <C>         编码类型
     * @return 父节点
     */
    private static <T, C extends String> T findParentByCode(
            T currentNode,
            List<T> subList,
            Function<T, C> getCode) {

        C currentCode = getCode.apply(currentNode);
        for (T node : subList) {
            C parentCode = getCode.apply(node);
            if (currentCode.startsWith(parentCode) && !currentCode.equals(parentCode)) {
                return node;
            }
        }
        throw new IllegalStateException("构建异常:未找到父节点");
    }

    /**
     * 对树进行排序
     *
     * @param tree       树的根节点列表
     * @param comparator 比较器
     * @param getSub     获取子节点列表的函数
     * @param <T>        节点类型
     */
    private static <T> void sortTree(
            List<T> tree,
            Comparator<T> comparator,
            Function<T, List<T>> getSub) {

        sortList(tree, comparator);
        for (T node : tree) {
            List<T> sub = getSub.apply(node);
            if (sub != null && !sub.isEmpty()) {
                sortTree(sub, comparator, getSub);
            }
        }
    }

    /**
     * 获取指定节点的所有父节点
     *
     * @param list              节点列表
     * @param ids               目标节点 ID 列表
     * @param idExtractor       获取节点 ID 的函数
     * @param parentIdExtractor 获取父节点 ID 的函数
     * @param containSelf       是否包含自身
     * @param <T>               节点类型
     * @param <R>               ID 类型
     * @return 父节点列表
     */
    public static <T, R> List<T> getParent(
            List<T> list,
            List<R> ids,
            Function<? super T, ? extends R> idExtractor,
            Function<? super T, ? extends R> parentIdExtractor,
            boolean containSelf) {

        if (CollectionUtils.isEmpty(list) || CollectionUtils.isEmpty(ids)) {
            return new ArrayList<>();
        }

        // 构建 ID -> 节点的映射,避免重复查找
        Map<R, T> idNodeMap = list.stream()
                .collect(Collectors.toMap(idExtractor, Function.identity()));

        Set<R> parentIds = new HashSet<>();
        Deque<R> stack = new LinkedList<>(ids);

        while (!stack.isEmpty()) {
            R currentId = stack.pop();
            if (!parentIds.contains(currentId)) {
                parentIds.add(currentId);
                T node = idNodeMap.get(currentId);
                if (node != null) {
                    R parentId = parentIdExtractor.apply(node);
                    if (parentId != null && !parentIds.contains(parentId)) {
                        stack.push(parentId);
                    }
                }
            }
        }

        return list.stream()
                .filter(node -> parentIds.contains(idExtractor.apply(node)))
                .collect(Collectors.toList());
    }

    /**
     * 获取指定节点的所有子节点
     *
     * @param list              节点列表
     * @param ids               目标节点 ID 列表
     * @param idExtractor       获取节点 ID 的函数
     * @param parentIdExtractor 获取父节点 ID 的函数
     * @param containSelf       是否包含自身
     * @param <T>               节点类型
     * @param <R>               ID 类型
     * @return 子节点列表
     */
    public static <T, R> List<T> getChildren(
            List<T> list,
            List<R> ids,
            Function<? super T, ? extends R> idExtractor,
            Function<? super T, ? extends R> parentIdExtractor,
            boolean containSelf) {

        if (CollectionUtils.isEmpty(list) || CollectionUtils.isEmpty(ids)) {
            return new ArrayList<>();
        }

        Map<R, T> idNodeMap = list.stream()
                .collect(Collectors.toMap(idExtractor, Function.identity(), (existing, replacement) -> existing));

        Map<R, List<T>> parentIdMap = list.stream()
                .collect(Collectors.groupingBy(parentIdExtractor));

        Set<R> resultIds = new HashSet<>();
        if (containSelf) {
            resultIds.addAll(ids);
        }

        Queue<R> queue = new LinkedList<>(ids);
        while (!queue.isEmpty()) {
            R parentId = queue.poll();
            List<T> children = parentIdMap.get(parentId);
            if (children != null) {
                for (T child : children) {
                    R childId = idExtractor.apply(child);
                    if (!resultIds.contains(childId)) {
                        resultIds.add(childId);
                        queue.add(childId);
                    }
                }
            }
        }

        return list.stream()
                .filter(node -> resultIds.contains(idExtractor.apply(node)))
                .collect(Collectors.toList());
    }

    /**
     * 在树中搜索所有符合条件的节点
     *
     * @param tree   树的根节点列表
     * @param getKey 获取节点属性的函数
     * @param getSub 获取子节点列表的函数
     * @param key    要匹配的属性值
     * @param <T>    节点类型
     * @param <I>    属性值类型
     * @return 符合条件的节点列表
     */
    public static <T, I> List<T> searchTree4All(
            @NonNull List<T> tree,
            @NonNull Function<T, I> getKey,
            @NonNull Function<T, List<T>> getSub,
            @NonNull I key) {

        List<T> matched = new ArrayList<>();
        Queue<T> queue = new LinkedList<>(tree);

        while (!queue.isEmpty()) {
            T node = queue.poll();
            I nodeKey = getKey.apply(node);
            if (nodeKey != null && nodeKey.equals(key)) {
                matched.add(node);
            }
            List<T> sub = getSub.apply(node);
            if (sub != null && !sub.isEmpty()) {
                queue.addAll(sub);
            }
        }

        return matched;
    }

    /**
     * 在树中搜索第一个符合条件的节点
     *
     * @param tree   树的根节点列表
     * @param getKey 获取节点属性的函数
     * @param getSub 获取子节点列表的函数
     * @param key    要匹配的属性值
     * @param <T>    节点类型
     * @param <I>    属性值类型
     * @return 符合条件的节点(Optional)
     */
    public static <T, I> Optional<T> searchTree4One(
            @NonNull List<T> tree,
            @NonNull Function<T, I> getKey,
            @NonNull Function<T, List<T>> getSub,
            @NonNull I key) {

        Queue<T> queue = new LinkedList<>(tree);

        while (!queue.isEmpty()) {
            T node = queue.poll();
            I nodeKey = getKey.apply(node);
            if (nodeKey != null && nodeKey.equals(key)) {
                return Optional.of(node);
            }
            List<T> sub = getSub.apply(node);
            if (sub != null && !sub.isEmpty()) {
                queue.addAll(sub);
            }
        }

        return Optional.empty();
    }

    /**
     * 将树转换为列表
     *
     * @param tree   树的根节点列表
     * @param getSub 获取子节点列表的函数
     * @param <T>    节点类型
     * @return 展开的节点列表
     */
    public static <T> List<T> tree2List(
            @NonNull List<T> tree,
            @NonNull Function<T, List<T>> getSub) {

        List<T> list = new ArrayList<>();
        Queue<T> queue = new LinkedList<>(tree);

        while (!queue.isEmpty()) {
            T node = queue.poll();
            list.add(node);
            List<T> sub = getSub.apply(node);
            if (sub != null && !sub.isEmpty()) {
                queue.addAll(sub);
            }
        }

        return list;
    }

    /**
     * 为树节点添加随机 ID
     *
     * @param tree        树的根节点列表
     * @param getSub      获取子节点列表的函数
     * @param setId       设置节点 ID 的函数
     * @param setParentId 设置父节点 ID 的函数
     * @param parentId    初始父节点 ID(根节点的父 ID,一般为 0 或 null)
     * @param idCounter   ID 计数器(可选)
     * @param <T>         节点类型
     */
    public static <T> void addRandomId(
            @NonNull List<T> tree,
            @NonNull Function<T, List<T>> getSub,
            @NonNull BiConsumer<T, Long> setId,
            @NonNull BiConsumer<T, Long> setParentId,
            @Nullable Long parentId,
            @Nullable MutableLong idCounter) {

        if (idCounter == null) {
            idCounter = new MutableLong(1L);
        }
        if (parentId == null) {
            parentId = 0L;
        }

        Queue<T> queue = new LinkedList<>(tree);
        Map<T, Long> parentMap = new HashMap<>();

        while (!queue.isEmpty()) {
            T node = queue.poll();
            long id = idCounter.longValue();
            idCounter.increment();
            setId.accept(node, id);
            setParentId.accept(node, parentMap.getOrDefault(node, parentId));
            List<T> sub = getSub.apply(node);
            if (sub != null && !sub.isEmpty()) {
                for (T child : sub) {
                    parentMap.put(child, id);
                    queue.add(child);
                }
            }
        }
    }

    /**
     * 根据名称过滤树节点
     *
     * @param tree         树的根节点列表
     * @param getSub       获取子节点列表的函数
     * @param getName      获取节点名称的函数
     * @param searchName   要搜索的名称
     * @param reserveChild 父节点匹配时是否保留所有子节点
     * @param <T>          节点类型
     */
    public static <T> void filterTreeByName(
            @NonNull List<T> tree,
            @NonNull Function<T, List<T>> getSub,
            @NonNull Function<T, String> getName,
            @NonNull String searchName,
            @NonNull Boolean reserveChild) {

        if (!StringUtils.hasLength(searchName)) {
            return;
        }

        Queue<T> queue = new LinkedList<>(tree);
        while (!queue.isEmpty()) {
            T node = queue.poll();
            String name = getName.apply(node);
            List<T> sub = getSub.apply(node);

            if (reserveChild && StringUtils.hasLength(name) && name.contains(searchName)) {
                continue;
            }

            if (sub != null && !sub.isEmpty()) {
                filterTreeByName(sub, getSub, getName, searchName, reserveChild);
            }

            if ((sub == null || sub.isEmpty()) && (name == null || !name.contains(searchName))) {
                tree.remove(node);
            }
        }
    }

    /**
     * 根据 ID 过滤树节点
     *
     * @param tree         树的根节点列表
     * @param getSub       获取子节点列表的函数
     * @param getId        获取节点 ID 的函数
     * @param searchId     要搜索的 ID
     * @param reserveChild 父节点匹配时是否保留所有子节点
     * @param <T>          节点类型
     */
    public static <T> void filterTreeById(
            @NonNull List<T> tree,
            @NonNull Function<T, List<T>> getSub,
            @NonNull Function<T, Long> getId,
            @NonNull Long searchId,
            @NonNull Boolean reserveChild) {

        Queue<T> queue = new LinkedList<>(tree);
        while (!queue.isEmpty()) {
            T node = queue.poll();
            Long id = getId.apply(node);
            List<T> sub = getSub.apply(node);

            if (reserveChild && id != null && id.equals(searchId)) {
                continue;
            }

            if (sub != null && !sub.isEmpty()) {
                filterTreeById(sub, getSub, getId, searchId, reserveChild);
            }

            if ((sub == null || sub.isEmpty()) && (id == null || !id.equals(searchId))) {
                tree.remove(node);
            }
        }
    }
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值