TreeUtil:构造和操作树结构的工具类

后端编程中,我们经常需要将数据库中的数据构造成树结构返回给前端展示,还需要对树结构进行一些操作,例如获取树中的某个节点,以及某个节点的父节点、根节点、孩子节点、兄弟节点等。

定义一个树工具接口

package org.mlx.util.base;

import java.util.List;


/**
 * Query API over a forest (one or more trees) that has been built from a flat collection of nodes.
 *
 * <p>Nodes are addressed by the value of their id property.
 *
 * @param <T> node type
 */
public interface TreeUtil<T> {

    /**
     * Returns the trees, represented by their root nodes.
     *
     * @return list of root nodes
     */
    List<T> getTreeList();

    /**
     * Looks up a single node by id.
     *
     * @param id node id
     * @return the node with that id
     */
    T getNode(Object id);

    /**
     * Returns every node of the forest.
     *
     * @return all nodes
     */
    List<T> getNodeList();

    /**
     * Returns the given node together with all of its descendants.
     *
     * @param id node id
     * @return the node followed by its subtree
     */
    List<T> getNodeList(Object id);

    /**
     * Returns the direct children of a node.
     *
     * @param id node id
     * @return direct children
     */
    List<T> getChildList(Object id);

    /**
     * Returns all descendants of a node (children, grandchildren, ...), excluding the node itself.
     *
     * @param id node id
     * @return descendants
     */
    List<T> getOffspringList(Object id);

    /**
     * Returns the siblings of a node, excluding the node itself.
     *
     * @param id node id
     * @return siblings
     */
    List<T> getBrotherList(Object id);

    /**
     * Returns the direct parent of a node.
     *
     * @param id node id
     * @return parent node
     */
    T getParent(Object id);

    /**
     * Returns every ancestor of a node.
     *
     * @param id node id
     * @return ancestors
     */
    List<T> getParentList(Object id);

    /**
     * Returns the root of the tree that contains the node.
     *
     * @param id node id
     * @return root node
     */
    T getRoot(Object id);

    /**
     * Returns the nodes located at a given level.
     *
     * @param level level number
     * @return nodes on that level
     */
    List<T> getLevelList(int level);

    /**
     * Returns the nodes that have no children.
     *
     * @return leaf nodes
     */
    List<T> getLeafList();

    /**
     * Returns the level assigned to root nodes.
     *
     * @return initial level
     */
    int getInitLevel();

    /**
     * Returns the deepest level present in the forest.
     *
     * @return maximum level
     */
    int getMaxLevel();

    /**
     * Returns the number of levels in the forest.
     *
     * @return level count
     */
    int getLevel();

    /**
     * Tells whether some nodes could not be attached to any tree ("free" nodes).
     *
     * @return true if free nodes exist, false otherwise
     */
    boolean hasFreeNode();

    /**
     * Returns the nodes that could not be attached to any tree.
     *
     * @return free nodes
     */
    List<T> getFreeNodeList();
}

树结构的核心概念是树节点。
树节点的属性一般包括:
id属性(节点的唯一标识)、pid属性(父节点的唯一标识)、childList属性(孩子节点列表)以及一些额外的属性。
数据库中保存的是树节点的id属性、pid属性以及一些额外的属性;后端读取数据库数据后,根据id属性与pid属性的对应关系设置childList属性。构造树结构,其实就是为每个节点设置childList属性。

构造树结构前,需要一个工具类为类型不确定的树节点读取和设置属性。

定义JavaBean接口,用于获取和设置对象的属性

package org.mlx.util.base.bean;

import java.util.Set;

/**
 * Type-agnostic property accessor: wraps one node type so that its properties can be read and
 * written by name, whether the node is an entity object or a map.
 */
public interface JavaBean {

    /**
     * Reads a property.
     *
     * @param bean      object to read from
     * @param fieldName property name
     * @return the property value
     */
    Object get(Object bean, String fieldName);

    /**
     * Writes a property.
     *
     * @param bean      object to write to
     * @param fieldName property name
     * @param value     new value
     */
    void set(Object bean, String fieldName, Object value);

    /**
     * Tells whether the wrapped type is known to have the given property.
     *
     * @param fieldName property name
     * @return true if the property is known
     */
    boolean containsField(String fieldName);

    /**
     * Returns the names of all known properties.
     *
     * @return property names
     */
    Set<String> fieldSet();

    /**
     * Returns the wrapped type.
     *
     * @return wrapped class
     */
    Class<?> getType();

    /**
     * Returns the declared type of a property.
     *
     * @param fieldName property name
     * @return property type
     */
    Class<?> getFieldType(String fieldName);
}


操作实体类的javaBean实现

package org.mlx.util.base.bean;

import java.beans.BeanInfo;
import java.beans.IntrospectionException;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/**
 * {@link JavaBean} implementation for entity (POJO) classes, based on JavaBeans introspection.
 *
 * <p>Only properties that expose both a getter and a setter are treated as fields; read-only or
 * write-only properties (for example {@code getClass()}) are ignored.
 */
public class EntityBean implements JavaBean {

    /** The wrapped entity class. */
    private final Class<?> type;

    /** Property name to descriptor, restricted to properties with both a getter and a setter. */
    private final Map<String, PropertyDescriptor> descriptorMap;

    /**
     * Introspects {@code beanClass} and records its read/write properties.
     *
     * @param beanClass entity class to wrap
     * @throws RuntimeException if introspection of {@code beanClass} fails
     */
    public EntityBean(Class<?> beanClass) {
        this.type = beanClass;
        this.descriptorMap = new HashMap<>();

        BeanInfo beanInfo;
        try {
            beanInfo = Introspector.getBeanInfo(beanClass);
        } catch (IntrospectionException e) {
            throw new RuntimeException(String.format("获取【%s】setter失败,原因是获取【%s】BeanInfo发生异常,异常信息:%s", beanClass, beanClass, e.getMessage()), e);
        }

        PropertyDescriptor[] propertyDescriptors = beanInfo.getPropertyDescriptors();
        if (propertyDescriptors != null) {
            for (PropertyDescriptor descriptor : propertyDescriptors) {
                if (descriptor.getWriteMethod() != null && descriptor.getReadMethod() != null) {
                    this.descriptorMap.put(descriptor.getName(), descriptor);
                }
            }
        }
    }

    /**
     * Reads {@code fieldName} from {@code bean} through its getter.
     *
     * @throws RuntimeException if the class has no such read/write property or the getter fails
     */
    @Override
    public Object get(Object bean, String fieldName) {
        Method getter = requireDescriptor(fieldName).getReadMethod();
        try {
            return getter.invoke(bean);
        } catch (IllegalAccessException e) {
            throw new RuntimeException(String.format("JavaBean获取【%s】类型的【%s】属性值失败:没有getter方法访问权限", type, fieldName), e);
        } catch (InvocationTargetException e) {
            // The wrapper's own message is normally null; report the getter's exception instead.
            throw new RuntimeException(String.format("JavaBean获取【%s】类型的【%s】属性值失败:getter方法执行时发生异常;异常信息:【%s】", type, fieldName, e.getCause()), e.getCause());
        }
    }

    /**
     * Writes {@code value} into {@code fieldName} of {@code bean} through its setter.
     *
     * @throws RuntimeException if the class has no such read/write property or the setter fails
     */
    @Override
    public void set(Object bean, String fieldName, Object value) {
        Method setter = requireDescriptor(fieldName).getWriteMethod();
        try {
            setter.invoke(bean, value);
        } catch (IllegalAccessException e) {
            throw new RuntimeException(String.format("JavaBean设置【%s】类型的【%s】属性值失败:没有setter方法访问权限", type, fieldName), e);
        } catch (InvocationTargetException e) {
            // The wrapper's own message is normally null; report the setter's exception instead.
            throw new RuntimeException(String.format("JavaBean设置【%s】类型的【%s】属性值失败:setter方法执行时发生异常;异常信息:【%s】", type, fieldName, e.getCause()), e.getCause());
        }
    }

    @Override
    public boolean containsField(String fieldName) {
        return descriptorMap.containsKey(fieldName);
    }

    /** Returns a read-only view of the read/write property names. */
    @Override
    public Set<String> fieldSet() {
        return java.util.Collections.unmodifiableSet(descriptorMap.keySet());
    }

    @Override
    public Class<?> getType() {
        return type;
    }

    /**
     * Returns the getter return type of {@code fieldName}.
     *
     * @throws RuntimeException if the class has no such read/write property
     */
    @Override
    public Class<?> getFieldType(String fieldName) {
        return requireDescriptor(fieldName).getReadMethod().getReturnType();
    }

    /** Looks up the descriptor for {@code fieldName}, failing with a descriptive message if absent. */
    private PropertyDescriptor requireDescriptor(String fieldName) {
        PropertyDescriptor descriptor = descriptorMap.get(fieldName);
        if (descriptor == null) {
            throw new RuntimeException(String.format("【%s】类没有【%s】属性", type, fieldName));
        }
        return descriptor;
    }
}

操作Map的JavaBean实现

package org.mlx.util.base.bean;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;


/**
 * {@link JavaBean} implementation for {@link Map} nodes: a "field" is simply a map key.
 *
 * <p>Field metadata ({@link #containsField}, {@link #fieldSet}, {@link #getFieldType}) is only
 * known when an example map is passed to the constructor; with the no-arg constructor it is empty.
 */
public class MapBean implements JavaBean {

    /** Always {@code Map.class}. */
    private final Class<?> type;

    /** Key name to the runtime class of its example value ({@code null} when the value was null). */
    private final Map<String, Class<?>> fieldTypeMap;

    /** Creates a MapBean without field metadata. */
    public MapBean() {
        this.type = Map.class;
        this.fieldTypeMap = new HashMap<>();
    }

    /**
     * Creates a MapBean whose field metadata is taken from {@code example}.
     * Null keys are ignored; other keys are recorded by their {@code toString()}.
     *
     * @param example sample map
     * @throws NullPointerException if {@code example} is null
     */
    public MapBean(Map<?, ?> example) {
        this();
        Objects.requireNonNull(example);
        example.forEach((k, v) -> {
            if (k != null) {
                fieldTypeMap.put(k.toString(), v == null ? null : v.getClass());
            }
        });
    }

    /**
     * Returns {@code ((Map) bean).get(fieldName)}.
     *
     * @throws IllegalArgumentException if {@code bean} is not a Map
     */
    @Override
    public Object get(Object bean, String fieldName) {
        Map<Object, Object> map = asMap(bean, "获取", fieldName);
        try {
            return map.get(fieldName);
        } catch (RuntimeException e) {
            throw new RuntimeException(String.format("MapBean获取【%s】类型的【%s】属性值失败:【%s】方法执行时发生异常;异常信息:【%s】", type, fieldName, "get", e), e);
        }
    }

    /**
     * Performs {@code ((Map) bean).put(fieldName, value)}.
     *
     * @throws IllegalArgumentException if {@code bean} is not a Map
     * @throws RuntimeException         if the map rejects the put (e.g. an unmodifiable map)
     */
    @Override
    public void set(Object bean, String fieldName, Object value) {
        Map<Object, Object> map = asMap(bean, "设置", fieldName);
        try {
            map.put(fieldName, value);
        } catch (RuntimeException e) {
            throw new RuntimeException(String.format("MapBean设置【%s】类型的【%s】属性值失败:【%s】方法执行时发生异常;异常信息:【%s】", type, fieldName, "put", e), e);
        }
    }

    @Override
    public boolean containsField(String fieldName) {
        return fieldTypeMap.containsKey(fieldName);
    }

    /** Returns a read-only view of the known key names. */
    @Override
    public Set<String> fieldSet() {
        return Collections.unmodifiableSet(fieldTypeMap.keySet());
    }

    @Override
    public Class<?> getType() {
        return type;
    }

    /** Returns the recorded value class for {@code fieldName}, or null if unknown. */
    @Override
    public Class<?> getFieldType(String fieldName) {
        return fieldTypeMap.get(fieldName);
    }

    /** Casts {@code bean} to a map, failing with a descriptive message when it is not one. */
    @SuppressWarnings("unchecked")
    private Map<Object, Object> asMap(Object bean, String action, String fieldName) {
        if (!(bean instanceof Map)) {
            throw new IllegalArgumentException(String.format("MapBean%s【%s】类型的【%s】属性值失败:对象类型【%s】不是java.util.Map", action, type, fieldName, bean == null ? null : bean.getClass()));
        }
        return (Map<Object, Object>) bean;
    }
}

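树形工具类的实现:
id属性、pid属性、孩子列表属性的默认名称分别是id、pid、childList(对应配置项idFieldName、pidFieldName、childListFieldName),如需修改,重新配置即可。
层级属性、根节点id属性、路径列表属性(对应配置项levelFieldName、rootIdFieldName、pathListFieldName)默认为空,只有配置后才会生效。
提示:根节点id属性的使用场景是区分不同根节点下的节点。有时我们会用一个类型字段来区分不同根节点下的节点,根节点id属性可以充当这个类型字段。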

package org.mlx.util.base.tree;

import org.mlx.util.base.TreeUtil;
import org.mlx.util.base.bean.EntityBean;
import org.mlx.util.base.bean.JavaBean;
import org.mlx.util.base.bean.MapBean;

import java.io.*;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.*;


/**
 * Default {@link TreeUtil} implementation.
 *
 * <p>It builds a forest from a flat collection of nodes (entity objects or {@link Map}s) by linking
 * each node to its parent through the id/pid properties. It then answers structural queries
 * against the built forest.
 *
 * <p>Configuration keys (all optional):
 * <ul>
 *   <li>{@code idFieldName} (default "id"), {@code pidFieldName} (default "pid"),
 *       {@code childListFieldName} (default "childList");</li>
 *   <li>{@code levelFieldName}, {@code rootIdFieldName}, {@code pathListFieldName}: written only when configured;</li>
 *   <li>{@code rootPidValue}: pid value marking root nodes (null/"" means "pid is null or empty");</li>
 *   <li>{@code rootIdValueCollection}: explicit root ids (overrides {@code rootPidValue});</li>
 *   <li>{@code initLevel} (default 1): level of the roots;</li>
 *   <li>{@code sortRule}: {@code "field[,asc|desc[,first|last]];..."}, used to order roots and children;</li>
 *   <li>{@code emptyChildIsNull} (default false): set the child list of a leaf to null instead of an empty list;</li>
 *   <li>{@code fixedFieldValueMap}: properties assigned the same value on every attached node.</li>
 * </ul>
 *
 * <p>Building mutates the supplied nodes. The constructors run {@link #build()}, which calls the
 * protected hooks; subclasses overriding them must not rely on their own field initializers.
 *
 * @param <T> node type
 */
public class DefaultTreeUtil<T> implements TreeUtil<T> {

    // Nodes considered by level/leaf queries; after build, reduced to attached nodes if free nodes exist
    protected Collection<T> nodeCollection;

    // id -> node, for every node attached to a tree
    protected Map<Object, T> nodeMap;

    // id -> node, for nodes not reachable from any root
    protected Map<Object, T> freeNodeMap;

    // Root nodes (sorted when a sort rule is configured)
    protected List<T> rootList;

    // Name of the id property
    protected String idFieldName;

    // Name of the parent-id property
    protected String pidFieldName;

    // Name of the level property (optional)
    protected String levelFieldName;

    // Name of the root-id property (optional); receives the id of the node's tree root
    protected String rootIdFieldName;

    // Name of the child-list property
    protected String childListFieldName;

    // Name of the path property (optional); receives the id list from the root to the node
    protected String pathListFieldName;

    // Ids of the root nodes
    protected Collection<Object> rootIdValueCollection;

    // pid value identifying root nodes
    protected Object rootPidValue;

    // Level of root nodes (default 1)
    protected Integer initLevel;

    // Deepest level reached while building
    protected Integer maxLevel;

    // Sort rule used to order roots and children
    protected String sortRule;

    // Comparator derived from sortRule (null when no rule)
    protected Comparator<T> comparator;

    // Whether a leaf's child list is set to null instead of an empty list
    protected Boolean emptyChildIsNull;

    // Fixed property values written onto every attached node
    protected Map<String, Object> fixedFieldValueMap;

    // Property accessor for the node type
    protected JavaBean javaBean;

    // Raw configuration
    protected Map<String, Object> configuration;

    // pid -> children in nodeCollection order; only populated while building
    private Map<Object, List<T>> childIndex;

    // Ids on the current root-to-node path while building, used to detect cycles
    private Set<Object> buildingPath;

    public DefaultTreeUtil(Collection<T> nodeCollection) {
        this(nodeCollection, new HashMap<String, Object>());
    }

    public DefaultTreeUtil(Collection<T> nodeCollection, String idFieldName, String pidFieldName) {
        this(nodeCollection, fieldNameConfig(idFieldName, pidFieldName));
    }

    public DefaultTreeUtil(Collection<T> nodeCollection, String idFieldName, String pidFieldName, Object rootPidValue) {
        this(nodeCollection, fieldNameConfig(idFieldName, pidFieldName, rootPidValue));
    }

    /**
     * @param nodeCollection nodes to arrange into trees; must be non-empty
     * @param configuration  configuration map (see class doc); null means all defaults
     * @throws RuntimeException if the nodes are empty, no root can be found, or the data contains a cycle
     */
    public DefaultTreeUtil(Collection<T> nodeCollection, Map<String, Object> configuration) {
        this.nodeCollection = nodeCollection;
        this.configuration = configuration == null ? new HashMap<>() : configuration;
        build();
    }

    /** Creates a configuration holding only the id/pid property names. */
    private static Map<String, Object> fieldNameConfig(String idFieldName, String pidFieldName) {
        Map<String, Object> configMap = new HashMap<>();
        configMap.put("idFieldName", idFieldName);
        configMap.put("pidFieldName", pidFieldName);
        return configMap;
    }

    /** Creates a configuration holding the id/pid property names and the root pid value. */
    private static Map<String, Object> fieldNameConfig(String idFieldName, String pidFieldName, Object rootPidValue) {
        Map<String, Object> configMap = fieldNameConfig(idFieldName, pidFieldName);
        configMap.put("rootPidValue", rootPidValue);
        return configMap;
    }

    /** Runs the build pipeline: read config, initialise, build trees, collect free nodes. */
    protected final void build() {
        readerConfig();
        initConfig();
        buildTrees();
        checkFreeNode();
    }

    /** Reads configuration values into fields; a null value falls back to the default. */
    @SuppressWarnings("unchecked")
    protected void readerConfig() {
        this.idFieldName = stringConfig("idFieldName", "id");
        this.pidFieldName = stringConfig("pidFieldName", "pid");
        this.childListFieldName = stringConfig("childListFieldName", "childList");
        this.levelFieldName = stringConfig("levelFieldName", null);
        this.rootIdFieldName = stringConfig("rootIdFieldName", null);
        this.pathListFieldName = stringConfig("pathListFieldName", null);
        this.rootPidValue = configuration.get("rootPidValue");
        this.rootIdValueCollection = (Collection<Object>) configuration.get("rootIdValueCollection");
        // Accept any Number (Integer, Long, ...) for the initial level.
        Object initLevelValue = configuration.get("initLevel");
        this.initLevel = initLevelValue == null ? 1 : ((Number) initLevelValue).intValue();
        this.sortRule = stringConfig("sortRule", null);
        Object emptyChildIsNullValue = configuration.get("emptyChildIsNull");
        this.emptyChildIsNull = emptyChildIsNullValue != null && (Boolean) emptyChildIsNullValue;
        this.fixedFieldValueMap = (Map<String, Object>) configuration.get("fixedFieldValueMap");
    }

    /** Returns the String configuration value for {@code key}, or {@code defaultValue} when absent/null. */
    private String stringConfig(String key, String defaultValue) {
        Object value = configuration.get(key);
        return value == null ? defaultValue : (String) value;
    }

    /**
     * Chooses the property accessor, builds the comparator, registers all nodes as free, and
     * resolves the root ids.
     */
    protected void initConfig() {
        if (nodeCollection == null || nodeCollection.isEmpty()) {
            throw new RuntimeException("构建树形工具类失败:节点列表为空");
        }

        T node0 = nodeCollection.iterator().next();
        if (node0 == null) {
            throw new RuntimeException("构建树形工具类失败:节点列表中存在null节点");
        }
        javaBean = node0 instanceof Map ? new MapBean() : new EntityBean(node0.getClass());

        maxLevel = initLevel;

        comparator = buildComparator(sortRule);

        nodeMap = new HashMap<>(nodeCollection.size());

        // Every node starts out free; buildTree moves reachable ones into nodeMap.
        freeNodeMap = new HashMap<>(nodeCollection.size());
        for (T node : nodeCollection) {
            if (node == null) {
                throw new RuntimeException("构建树形工具类失败:节点列表中存在null节点");
            }
            freeNodeMap.put(javaBean.get(node, idFieldName), node);
        }

        // Copy into an insertion-ordered set. This avoids mutating the caller's (possibly immutable)
        // collection, keeps the root order deterministic, and makes contains() cheap.
        Collection<Object> configuredRootIds = rootIdValueCollection;
        rootIdValueCollection = new LinkedHashSet<>();

        if (configuredRootIds == null || configuredRootIds.isEmpty()) {
            boolean blankRootPid = rootPidValue == null || "".equals(rootPidValue);
            for (T node : nodeCollection) {
                Object pid = javaBean.get(node, pidFieldName);
                boolean isRoot = blankRootPid ? (pid == null || "".equals(pid)) : rootPidValue.equals(pid);
                if (isRoot) {
                    rootIdValueCollection.add(javaBean.get(node, idFieldName));
                }
            }
            if (rootIdValueCollection.isEmpty()) {
                throw new RuntimeException(String.format("指定【%s】属性值为【%s】的根节点,没有找到匹配此条件的根节点", pidFieldName, blankRootPid ? "空" : rootPidValue));
            }
        } else {
            for (Object rootIdValue : configuredRootIds) {
                if (!freeNodeMap.containsKey(rootIdValue)) {
                    throw new RuntimeException(String.format("检查根节点异常:节点列表中不存在id属性值为【%s】的节点", rootIdValue));
                }
                rootIdValueCollection.add(rootIdValue);
            }
        }
    }

    /** Builds a comparator chain from a {@code ;}-separated sort rule, or returns null when there is no rule. */
    private Comparator<T> buildComparator(String rule) {
        if (rule == null) {
            return null;
        }
        Comparator<T> result = null;
        for (String fieldRule : rule.split(";")) {
            if (fieldRule.trim().isEmpty()) {
                continue;
            }
            Comparator<T> next = new NodeComparator<>(javaBean, fieldRule);
            result = result == null ? next : result.thenComparing(next);
        }
        return result;
    }

    /** Builds one tree per root id and sorts the root list if a comparator is configured. */
    protected void buildTrees() {
        rootList = new ArrayList<>(rootIdValueCollection.size());
        childIndex = null;
        buildingPath = null;
        try {
            for (Object rootIdValue : rootIdValueCollection) {
                T root = freeNodeMap.get(rootIdValue);
                if (root == null) {
                    throw new RuntimeException(idFieldName + "属性值为【" + rootIdValue + "】的根节点可能是其他根节点的下级节点");
                }
                rootList.add(root);
                buildTree(root, rootIdValue, initLevel);
            }
        } finally {
            // The helper structures are only needed while building.
            childIndex = null;
            buildingPath = null;
        }

        if (comparator != null) {
            rootList.sort(comparator);
        }
    }

    /**
     * Attaches {@code root} and, recursively, its subtree.
     *
     * <p>It writes the optional rootId/level/fixed/path properties and the child list, and moves
     * each node from freeNodeMap to nodeMap.
     *
     * @param root   node to attach (root of the subtree being built)
     * @param rootId id of the tree's root
     * @param level  level of {@code root}
     * @throws RuntimeException if a node is reached again along its own ancestor path (a cycle)
     */
    protected void buildTree(T root, Object rootId, int level) {
        Object id = javaBean.get(root, idFieldName);

        // Without this guard a cyclic pid chain would recurse until StackOverflowError.
        if (buildingPath == null) {
            buildingPath = new HashSet<>();
        }
        if (!buildingPath.add(id)) {
            throw new RuntimeException(String.format("构建树结构失败:%s属性值为【%s】的节点存在循环引用", idFieldName, id));
        }

        if (rootIdFieldName != null) {
            javaBean.set(root, rootIdFieldName, rootId);
        }
        if (levelFieldName != null) {
            javaBean.set(root, levelFieldName, level);
        }
        if (fixedFieldValueMap != null) {
            fixedFieldValueMap.forEach((k, v) -> javaBean.set(root, k, v));
        }
        if (level > maxLevel) {
            maxLevel = level;
        }

        freeNodeMap.remove(id);
        nodeMap.put(id, root);

        if (pathListFieldName != null) {
            javaBean.set(root, pathListFieldName, buildPathList(root, id, rootId));
        }

        List<T> childList = findChildren(id);
        if (comparator != null) {
            childList.sort(comparator);
        }
        boolean useNull = childList.isEmpty() && Boolean.TRUE.equals(emptyChildIsNull);
        javaBean.set(root, childListFieldName, useNull ? null : childList);

        for (T child : childList) {
            buildTree(child, rootId, level + 1);
        }

        buildingPath.remove(id);
    }

    /** Returns the parent's path plus {@code id}; for the tree root it returns just {@code [id]}. */
    private List<Object> buildPathList(T node, Object id, Object rootId) {
        List<Object> pathList = new ArrayList<>();
        if (!Objects.equals(id, rootId)) {
            // The parent is attached before its children, so its path is already set.
            T parent = nodeMap.get(javaBean.get(node, pidFieldName));
            Object parentPath = parent == null ? null : javaBean.get(parent, pathListFieldName);
            if (parentPath != null) {
                pathList.addAll((Collection<?>) parentPath);
            }
        }
        pathList.add(id);
        return pathList;
    }

    /** Returns a fresh list of the nodes whose pid equals {@code id}, in nodeCollection order. */
    private List<T> findChildren(Object id) {
        if (childIndex == null) {
            // One pass over the nodes instead of rescanning the whole collection for every node.
            childIndex = new HashMap<>();
            for (T node : nodeCollection) {
                childIndex.computeIfAbsent(javaBean.get(node, pidFieldName), k -> new ArrayList<>()).add(node);
            }
        }
        List<T> children = childIndex.get(id);
        return children == null ? new ArrayList<>() : new ArrayList<>(children);
    }

    /** If free nodes exist, restricts nodeCollection to attached nodes so level/leaf queries ignore them. */
    protected void checkFreeNode() {
        if (!freeNodeMap.isEmpty()) {
            nodeCollection = nodeMap.values();
        }
    }

    /** Returns the child list of {@code node}, or an empty list when it is null. */
    @SuppressWarnings("unchecked")
    private List<T> childrenOf(T node) {
        List<T> children = (List<T>) javaBean.get(node, childListFieldName);
        return children == null ? Collections.<T>emptyList() : children;
    }

    /** Returns the attached parent of {@code node}, or null when {@code node} is a root. */
    private T parentOf(T node) {
        if (rootIdValueCollection.contains(javaBean.get(node, idFieldName))) {
            return null;
        }
        return nodeMap.get(javaBean.get(node, pidFieldName));
    }

    /** Returns a copy of the root list. */
    @Override
    public List<T> getTreeList() {
        return new ArrayList<>(rootList);
    }

    /** Returns the attached node with this id, or null. */
    @Override
    public T getNode(Object id) {
        return nodeMap.get(id);
    }

    /** Returns all attached nodes (no particular order). */
    @Override
    public List<T> getNodeList() {
        return new ArrayList<>(nodeMap.values());
    }

    /** Returns the node followed by its subtree in pre-order, or an empty list if unknown. */
    @Override
    public List<T> getNodeList(Object id) {
        return getTreeNodeList(id);
    }

    /**
     * Collects the subtree rooted at {@code treeId} in pre-order (node first).
     * It uses an explicit stack, so deep trees cannot overflow the call stack.
     */
    protected List<T> getTreeNodeList(Object treeId) {
        List<T> nodeList = new ArrayList<>();
        T start = nodeMap.get(treeId);
        if (start == null) {
            return nodeList;
        }
        Deque<T> stack = new ArrayDeque<>();
        stack.push(start);
        while (!stack.isEmpty()) {
            T node = stack.pop();
            nodeList.add(node);
            List<T> children = childrenOf(node);
            // Push in reverse so the first child is visited first.
            for (int i = children.size() - 1; i >= 0; i--) {
                stack.push(children.get(i));
            }
        }
        return nodeList;
    }

    /** Returns a copy of the node's children (never null); empty if the node is unknown or a leaf. */
    @Override
    public List<T> getChildList(Object id) {
        T node = nodeMap.get(id);
        if (node == null) {
            return new ArrayList<>();
        }
        return new ArrayList<>(childrenOf(node));
    }

    /** Returns all descendants of the node in pre-order, excluding the node itself. */
    @Override
    public List<T> getOffspringList(Object id) {
        List<T> nodeList = getTreeNodeList(id);
        if (!nodeList.isEmpty()) {
            // Pre-order: the first element is the node itself.
            nodeList.remove(0);
        }
        return nodeList;
    }

    /** Returns the node's siblings (other roots, for a root), excluding the node itself. */
    @Override
    public List<T> getBrotherList(Object id) {
        List<T> brotherList = new ArrayList<>();
        T node = nodeMap.get(id);
        if (node == null) {
            return brotherList;
        }

        if (rootIdValueCollection.contains(id)) {
            brotherList.addAll(rootList);
        } else {
            T parent = nodeMap.get(javaBean.get(node, pidFieldName));
            if (parent != null) {
                brotherList.addAll(childrenOf(parent));
            }
        }

        brotherList.removeIf(brother -> Objects.equals(id, javaBean.get(brother, idFieldName)));
        return brotherList;
    }

    /** Returns the direct parent, or null for roots and unknown ids. */
    @Override
    public T getParent(Object id) {
        T node = nodeMap.get(id);
        return node == null ? null : parentOf(node);
    }

    /** Returns the ancestors from the direct parent up to the root; empty for roots and unknown ids. */
    @Override
    public List<T> getParentList(Object id) {
        List<T> parentList = new ArrayList<>();
        T node = nodeMap.get(id);
        T parent = node == null ? null : parentOf(node);
        while (parent != null) {
            parentList.add(parent);
            parent = parentOf(parent);
        }
        return parentList;
    }

    /** Returns the root of the node's tree (the node itself for a root), or null for unknown ids. */
    @Override
    public T getRoot(Object id) {
        T node = nodeMap.get(id);
        if (node == null) {
            return null;
        }

        if (rootIdValueCollection.contains(id)) {
            return node;
        }

        if (rootIdFieldName == null || rootIdFieldName.isEmpty()) {
            List<T> parentList = getParentList(id);
            return parentList.get(parentList.size() - 1);
        }
        return nodeMap.get(javaBean.get(node, rootIdFieldName));
    }

    /**
     * Returns the nodes on {@code level} (roots are on initLevel), in breadth-first tree order.
     * It walks the child lists, so it works whether or not {@code levelFieldName} is configured.
     */
    @Override
    public List<T> getLevelList(int level) {
        if (level < initLevel || level > maxLevel) {
            return new ArrayList<>();
        }
        List<T> current = new ArrayList<>(rootList);
        for (int depth = initLevel; depth < level; depth++) {
            List<T> next = new ArrayList<>();
            for (T node : current) {
                next.addAll(childrenOf(node));
            }
            current = next;
        }
        return current;
    }

    /** Returns the nodes whose child list is null or empty. */
    @Override
    public List<T> getLeafList() {
        List<T> leafList = new ArrayList<>();
        for (T node : nodeCollection) {
            if (childrenOf(node).isEmpty()) {
                leafList.add(node);
            }
        }
        return leafList;
    }

    @Override
    public int getInitLevel() {
        return initLevel;
    }

    @Override
    public int getMaxLevel() {
        return maxLevel;
    }

    @Override
    public int getLevel() {
        return maxLevel - initLevel + 1;
    }

    @Override
    public boolean hasFreeNode() {
        return !freeNodeMap.isEmpty();
    }

    @Override
    public List<T> getFreeNodeList() {
        return new ArrayList<>(freeNodeMap.values());
    }

    /**
     * Returns a {@link TreeUtil} proxy in which every non-primitive return value is deep-copied
     * through Java serialization. Returned objects (and the nodes they contain) must be Serializable.
     * Each call produces an independent copy.
     */
    @SuppressWarnings("unchecked")
    public TreeUtil<T> getDeepCopyTreeUtil() {
        return (TreeUtil<T>) Proxy.newProxyInstance(this.getClass().getClassLoader(), new Class[]{TreeUtil.class}, new DeepCopyInvocationHandler(this));
    }

    /**
     * Invocation handler that delegates to {@code source} and deep-copies the result.
     */
    private static class DeepCopyInvocationHandler implements InvocationHandler {

        private final Object source;

        public DeepCopyInvocationHandler(Object source) {
            this.source = source;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            Object result;
            try {
                result = method.invoke(source, args);
            } catch (java.lang.reflect.InvocationTargetException e) {
                // Surface the delegate's own exception instead of a reflection wrapper.
                throw e.getCause();
            }
            if (method.getReturnType().isPrimitive() || result == null) {
                return result;
            }
            return deepCopy(result);
        }

        /** Serializes {@code value} to memory and reads it back. The bytes are produced in-process, not from untrusted input. */
        private static Object deepCopy(Object value) {
            if (!(value instanceof Serializable)) {
                throw new RuntimeException("深度复制对象失败,原因是【" + value.getClass() + "】类型未实现序列化接口");
            }
            try {
                ByteArrayOutputStream bytes = new ByteArrayOutputStream();
                try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                    out.writeObject(value);
                }
                try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                    return in.readObject();
                }
            } catch (IOException e) {
                throw new RuntimeException(String.format("深度复制对象失败,原因是发生了IO异常,异常类型【%s】异常信息【%s】", e.getClass(), e.getMessage()), e);
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(String.format("深度复制对象失败,原因是输入流重新读取对象时发现对象类型不存在,异常类型【%s】异常信息【%s】", e.getClass(), e.getMessage()), e);
            }
        }
    }

    /**
     * Comparator for one sort rule {@code "field[,asc|desc[,first|last]]"}.
     *
     * <p>Direction defaults to asc and null placement defaults to first; parts are trimmed.
     * Null placement is independent of direction. Two non-null values compare naturally when the
     * first is Comparable and the second is an instance of its class (so numbers and dates order
     * correctly). Otherwise they compare by their {@code toString()}.
     *
     * @param <TT> node type
     */
    private static class NodeComparator<TT> implements Comparator<TT> {

        private final JavaBean javaBean;

        private final String fieldName;

        private final boolean asc;

        private final boolean nullFirst;

        public NodeComparator(JavaBean javaBean, String sortFiledRule) {
            String[] parts = sortFiledRule.split(",");
            if (parts.length < 1 || parts.length > 3 || parts[0].trim().isEmpty()) {
                throw new RuntimeException(sortFiledRule + "是错误的排序字段规则");
            }
            this.javaBean = javaBean;
            this.fieldName = parts[0].trim();

            String direction = parts.length > 1 ? parts[1].trim() : "";
            this.asc = direction.isEmpty() || "asc".equalsIgnoreCase(direction);
            if (!asc && !"desc".equalsIgnoreCase(direction)) {
                throw new RuntimeException(sortFiledRule + "是错误的排序字段规则");
            }

            String nulls = parts.length > 2 ? parts[2].trim() : "";
            this.nullFirst = nulls.isEmpty() || "first".equalsIgnoreCase(nulls);
            if (!nullFirst && !"last".equalsIgnoreCase(nulls)) {
                throw new RuntimeException(sortFiledRule + "是错误的排序字段规则");
            }
        }

        @Override
        public int compare(TT o1, TT o2) {
            Object v1 = javaBean.get(o1, fieldName);
            Object v2 = javaBean.get(o2, fieldName);
            if (v1 == null || v2 == null) {
                if (v1 == v2) {
                    return 0;
                }
                return (v1 == null) == nullFirst ? -1 : 1;
            }
            // Swap the operands for descending order instead of negating (no overflow risk).
            return asc ? compareValues(v1, v2) : compareValues(v2, v1);
        }

        /** Natural order for compatible Comparable values, otherwise string order. */
        @SuppressWarnings({"unchecked", "rawtypes"})
        private static int compareValues(Object a, Object b) {
            if (a instanceof Comparable && a.getClass().isInstance(b)) {
                return ((Comparable) a).compareTo(b);
            }
            return a.toString().compareTo(b.toString());
        }
    }
}

测试例子

package org.mlx.util.base;

import org.mlx.util.base.tree.DefaultTreeUtil;

import java.io.Serializable;
import java.util.*;

public class UtilTest {


    public static void main(String[] args) {
        test04();
    }

    /**
     * Deep-copy demo: builds the entity tree, obtains a deep copy via
     * {@code getDeepCopyTreeUtil()} and runs every query against the copy.
     * The last line prints whether two lookups of the same id on the copy
     * return the identical instance.
     */
    public static void test05() {
        TreeUtil<Node> treeUtil = new DefaultTreeUtil<>(buildEntityNodes(), buildConfigMap(null))
                .getDeepCopyTreeUtil();

        System.out.println("树列表");
        printItems(treeUtil.getTreeList());

        String nodeId = "111-222";

        System.out.println("指定节点");
        Node node = treeUtil.getNode(nodeId);
        System.out.println(node);

        System.out.println("全部节点");
        printItems(treeUtil.getNodeList());

        System.out.println("指定节点及其所有子节点");
        printItems(treeUtil.getNodeList(nodeId));

        System.out.println("兄弟节点");
        printItems(treeUtil.getBrotherList(nodeId));

        System.out.println("父节点");
        System.out.println(treeUtil.getParent(nodeId));

        System.out.println("父节点列表");
        printItems(treeUtil.getParentList(nodeId));

        System.out.println("孩子节点");
        printItems(treeUtil.getChildList(nodeId));

        System.out.println("后代节点");
        printItems(treeUtil.getOffspringList(nodeId));

        System.out.println("所属根节点");
        System.out.println(treeUtil.getRoot(nodeId));

        System.out.println("叶子节点");
        printItems(treeUtil.getLeafList());

        System.out.println("指定层级节点");
        printItems(treeUtil.getLevelList(2));

        System.out.println("是否有游离节点");
        System.out.println(treeUtil.hasFreeNode());

        System.out.println("游离节点列表");
        printItems(treeUtil.getFreeNodeList());

        System.out.println("初始层级");
        System.out.println(treeUtil.getInitLevel());

        System.out.println("层级数量");
        System.out.println(treeUtil.getLevel());

        System.out.println("最大层级");
        System.out.println(treeUtil.getMaxLevel());

        // Identity check: does a repeated lookup on the copy return the same object?
        Node n1 = treeUtil.getNode(nodeId);
        System.out.println(node == n1);
    }

    /** Query demo: runs every lookup method on a tree built from entity nodes. */
    public static void test04() {
        DefaultTreeUtil<Node> treeUtil = new DefaultTreeUtil<>(buildEntityNodes(), buildConfigMap(null));

        System.out.println("树列表");
        printItems(treeUtil.getTreeList());

        String nodeId = "111-222";

        System.out.println("指定节点");
        System.out.println(treeUtil.getNode(nodeId));

        System.out.println("全部节点");
        printItems(treeUtil.getNodeList());

        System.out.println("指定节点及其所有子节点");
        printItems(treeUtil.getNodeList(nodeId));

        System.out.println("兄弟节点");
        printItems(treeUtil.getBrotherList(nodeId));

        System.out.println("父节点");
        System.out.println(treeUtil.getParent(nodeId));

        System.out.println("父节点列表");
        printItems(treeUtil.getParentList(nodeId));

        System.out.println("孩子节点");
        printItems(treeUtil.getChildList(nodeId));

        System.out.println("后代节点");
        printItems(treeUtil.getOffspringList(nodeId));

        System.out.println("所属根节点");
        System.out.println(treeUtil.getRoot(nodeId));

        System.out.println("叶子节点");
        printItems(treeUtil.getLeafList());

        System.out.println("指定层级节点");
        printItems(treeUtil.getLevelList(2));

        System.out.println("是否有游离节点");
        System.out.println(treeUtil.hasFreeNode());

        System.out.println("游离节点列表");
        printItems(treeUtil.getFreeNodeList());

        System.out.println("初始层级");
        System.out.println(treeUtil.getInitLevel());

        System.out.println("层级数量");
        System.out.println(treeUtil.getLevel());

        System.out.println("最大层级");
        System.out.println(treeUtil.getMaxLevel());

    }

    /** Entity-node demo: builds the tree from {@link Node} objects and prints each root. */
    public static void test03() {
        DefaultTreeUtil<Node> treeUtil = new DefaultTreeUtil<>(buildEntityNodes(), buildConfigMap(null));
        printItems(treeUtil.getTreeList());
    }

    /** Map-node demo (pidList): builds the tree from {@code Map} nodes and prints each root. */
    public static void test02() {
        List<Map<String, Object>> nodes = Arrays.asList(
                mapNode("111", null, "北京", 15),
                mapNode("111-111", "111", "北京", 600),
                mapNode("111-222", "111", "北京", 700),
                mapNode("111-111-111", "111-111", "北京", 500),
                mapNode("111-111-222", "111-111", "上海", 800),
                mapNode("111-222-111", "111-222", "上海", 400),
                mapNode("111-222-222", "111-222", "上海", 900),
                mapNode("222", null, "北京", 25),
                mapNode("222-111", "222", "北京", 550),
                mapNode("222-222", "222", "北京", 570),
                mapNode("222-333", "222", "北京", 560),
                mapNode("333", null, "北京", 20),
                mapNode("333-111", "333", "北京", 200),
                mapNode("333-222", "333", "北京", 150),
                mapNode("333-333", "333", "北京", 250));

        DefaultTreeUtil<Map<String, Object>> treeUtil = new DefaultTreeUtil<>(nodes, buildConfigMap(null));
        printItems(treeUtil.getTreeList());
    }

    /**
     * Map-node demo (sorting): siblings are sorted by city ascending, then age descending.
     * An alternative rule without an explicit direction is {@code "city;age"}.
     */
    public static void test01() {
        List<Map<String, Object>> nodes = Arrays.asList(
                mapNode("111", null, "北京", 15),
                mapNode("111-111", "111", "北京", 600),
                mapNode("111-222", "111", "北京", 700),
                mapNode("111-333", "111", "北京", 500),
                mapNode("111-444", "111", "上海", 800),
                mapNode("111-555", "111", "上海", 400),
                mapNode("111-666", "111", "上海", 900),
                mapNode("222", null, "北京", 25),
                mapNode("222-111", "222", "北京", 550),
                mapNode("222-222", "222", "北京", 570),
                mapNode("222-333", "222", "北京", 560),
                mapNode("333", null, "北京", 20),
                mapNode("333-111", "333", "北京", 200),
                mapNode("333-222", "333", "北京", 150),
                mapNode("333-333", "333", "北京", 250));

        DefaultTreeUtil<Map<String, Object>> treeUtil =
                new DefaultTreeUtil<>(nodes, buildConfigMap("city;age,desc"));
        printItems(treeUtil.getTreeList());
    }

    /**
     * Builds a fresh copy of the entity fixture: three roots ("111", "222", "333"),
     * where "111" has two children with two children each, and "222"/"333" have three leaves each.
     * New instances are returned on every call because building a tree mutates the nodes.
     */
    private static List<Node> buildEntityNodes() {
        return Arrays.asList(
                new Node("111", null),
                new Node("111-111", "111"),
                new Node("111-222", "111"),
                new Node("111-111-111", "111-111"),
                new Node("111-111-222", "111-111"),
                new Node("111-222-111", "111-222"),
                new Node("111-222-222", "111-222"),
                new Node("222", null),
                new Node("222-111", "222"),
                new Node("222-222", "222"),
                new Node("222-333", "222"),
                new Node("333", null),
                new Node("333-111", "333"),
                new Node("333-222", "333"),
                new Node("333-333", "333"));
    }

    /** Creates one map-based tree node, keeping the key order id, pid, city, age. */
    private static Map<String, Object> mapNode(String id, String pid, String city, int age) {
        Map<String, Object> node = new LinkedHashMap<>();
        node.put("id", id);
        node.put("pid", pid);
        node.put("city", city);
        node.put("age", age);
        return node;
    }

    /**
     * Builds the tree configuration shared by all demos.
     *
     * @param sortRule value for the {@code sortRule} key, or {@code null} to omit the key
     * @return a new insertion-ordered config map
     */
    private static Map<String, Object> buildConfigMap(String sortRule) {
        Map<String, Object> configMap = new LinkedHashMap<>();
        if (sortRule != null) {
            configMap.put("sortRule", sortRule);
        }
        configMap.put("levelFieldName", "level");
        configMap.put("rootIdFieldName", "rootId");
        configMap.put("pidListFieldName", "pidList");
        return configMap;
    }

    /** Prints each element of {@code items} on its own line. */
    private static void printItems(List<?> items) {
        for (Object item : items) {
            System.out.println(item);
        }
    }

    /** Entity class used as a test tree node. */
    public static class Node implements Serializable {

        private String id;

        private String pid;

        private List<Node> childList;

        private String rootId;

        private int level;

        private List<String> pidList;

        public Node() {
        }

        public Node(String id, String pid) {
            this.id = id;
            this.pid = pid;
        }

        public String getId() {
            return id;
        }

        public void setId(String id) {
            this.id = id;
        }

        public String getPid() {
            return pid;
        }

        public void setPid(String pid) {
            this.pid = pid;
        }

        public List<Node> getChildList() {
            return childList;
        }

        public void setChildList(List<Node> childList) {
            this.childList = childList;
        }

        public String getRootId() {
            return rootId;
        }

        public void setRootId(String rootId) {
            this.rootId = rootId;
        }

        public int getLevel() {
            return level;
        }

        public void setLevel(int level) {
            this.level = level;
        }

        public List<String> getPidList() {
            return pidList;
        }

        public void setPidList(List<String> pidList) {
            this.pidList = pidList;
        }

        @Override
        public String toString() {
            return "Node{" +
                    "id='" + id + '\'' +
                    ", pid='" + pid + '\'' +
                    ", rootId='" + rootId + '\'' +
                    ", level=" + level +
                    ", pidList=" + pidList +
                    ", childList=" + childList +
                    '}';
        }
    }
}

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值