1. TreeModel
package com.yl.tree;
import org.springframework.beans.BeanUtils;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
/**
 * Generic helper that turns a flat list of nodes into a tree and back.
 *
 * <p>Nodes are linked through three fields accessed by reflection, whose names are configurable:
 * the node id ({@link #setIdName}, default {@code "id"}), the parent id ({@link #setPIdName},
 * default {@code "pId"}) and the children list ({@link #setChildrenName}, default
 * {@code "children"}). A node whose parent id is {@code null} or {@code ""} is a root. The fields
 * may be declared on the node class itself or on any of its superclasses.
 *
 * <p>The tree-building methods never modify the nodes passed in: every node they return is a new
 * instance of the entity class filled in with Spring's {@link BeanUtils#copyProperties}.
 *
 * @param <T> node type; needs a no-arg constructor reachable via reflection
 * @author liuxb
 * @date 2021/11/30 11:18
 */
public class TreeModel<T> {
    /** Flat list of nodes the tree is built from. */
    private List<T> nodeList;
    /** Optional comparator used to order siblings; {@code null} keeps list order. */
    private Comparator<T> comparator;
    /** Name of the node id field. */
    private String idName = "id";
    /** Name of the parent id field. */
    private String pIdName = "pId";
    /** Name of the children collection field. */
    private String childrenName = "children";
    /** Runtime class of the first node; new copies are instantiated from it. */
    private Class<T> entityClass;

    /**
     * Creates a model over the given nodes, keeping siblings in list order.
     *
     * @param nodeList flat node list; must be non-null, non-empty, with a non-null first element
     * @throws RuntimeException if {@code nodeList} is null/empty or its first element is null
     */
    public TreeModel(List<T> nodeList) {
        this(nodeList, null);
    }

    /**
     * Creates a model over the given nodes, sorting siblings with {@code comparator}.
     *
     * @param nodeList   flat node list; must be non-null, non-empty, with a non-null first element
     * @param comparator sibling order, or {@code null} to keep list order
     * @throws RuntimeException if {@code nodeList} is null/empty or its first element is null
     */
    public TreeModel(List<T> nodeList, Comparator<T> comparator) {
        this.entityClass = resolveEntityClass(nodeList);
        this.nodeList = nodeList;
        this.comparator = comparator;
    }

    /** Validates the node list and returns the runtime class of its first element. */
    @SuppressWarnings("unchecked")
    private static <E> Class<E> resolveEntityClass(List<E> nodeList) {
        if (nodeList == null || nodeList.isEmpty() || nodeList.get(0) == null) {
            throw new RuntimeException("节点集合不能为null或空元素");
        }
        // getClass() of an E instance is a Class<? extends E>; copies are created from it.
        return (Class<E>) nodeList.get(0).getClass();
    }

    /**
     * Builds the whole forest: one copied node per root, each with its descendants attached.
     *
     * @return the root copies, sorted by the comparator when one is set
     * @throws IllegalStateException if the parent-id links below a root form a cycle
     */
    public List<T> treeList() {
        List<T> roots = new ArrayList<>();
        for (T node : this.nodeList) {
            if (isNullOrBlank(node, this.pIdName)) {
                roots.add(copyWithChildren(node, new HashSet<>()));
            }
        }
        sortIfNeeded(roots);
        return roots;
    }

    /**
     * Returns the subtree below the node with the given id, i.e. copies of its direct children
     * with their own descendants attached.
     *
     * @param id node id; must have the same type as the id field (e.g. {@code Integer}, not
     *           {@code String}). {@code null} yields an empty list.
     * @return the child copies, sorted by the comparator when one is set
     * @throws IllegalStateException if the parent-id links below {@code id} form a cycle
     */
    public List<T> getChildren(Object id) {
        return buildChildren(id, new HashSet<>());
    }

    /**
     * Collects copies of the direct children of {@code id}, recursing into each one.
     * {@code path} holds the ids currently being expanded, so a cycle is reported instead of
     * recursing until StackOverflowError.
     */
    private List<T> buildChildren(Object id, Set<Object> path) {
        List<T> children = new ArrayList<>();
        if (id == null) {
            return children;
        }
        if (!path.add(id)) {
            throw new IllegalStateException("树节点存在循环引用, id=" + id);
        }
        for (T node : this.nodeList) {
            if (id.equals(getFieldValue(node, this.pIdName))) {
                children.add(copyWithChildren(node, path));
            }
        }
        // Only the current ancestor chain matters; siblings may legitimately revisit an id.
        path.remove(id);
        sortIfNeeded(children);
        return children;
    }

    /** Copies {@code node} and attaches its (recursively built) children to the copy. */
    private T copyWithChildren(T node, Set<Object> path) {
        T copy = getT();
        BeanUtils.copyProperties(node, copy);
        setFieldValue(copy, this.childrenName, buildChildren(getFieldValue(node, this.idName), path));
        return copy;
    }

    /** Sorts {@code list} in place when a comparator is configured. */
    private void sortIfNeeded(List<T> list) {
        if (this.comparator != null) {
            Collections.sort(list, this.comparator);
        }
    }

    /**
     * Creates a new, empty instance of the entity class.
     *
     * @return a fresh instance
     * @throws RuntimeException wrapping any reflective failure (no accessible no-arg constructor,
     *                          abstract class, constructor threw)
     */
    private T getT() {
        try {
            return this.entityClass.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Flattens a tree back into individual nodes, appending them to {@code allList} in pre-order
     * (each node before its descendants). Appended nodes are copies whose children field is an
     * empty list.
     *
     * @param treeList nodes whose children field holds a {@code List<T>} (or {@code null})
     * @param allList  receives the flattened copies
     */
    public void childrenToList(List<T> treeList, List<T> allList) {
        for (T tree : treeList) {
            T t = getT();
            BeanUtils.copyProperties(tree, t);
            setFieldValue(t, this.childrenName, new ArrayList<>());
            allList.add(t);
            @SuppressWarnings("unchecked")
            List<T> children = (List<T>) getFieldValue(tree, this.childrenName);
            if (children != null && !children.isEmpty()) {
                childrenToList(children, allList);
            }
        }
    }

    /**
     * Returns every ancestor of the node with the given id (the original node objects, not
     * copies). The node itself is not included unless the parent-id chain loops back to it.
     *
     * @param id node id; must have the same type as the id field. {@code null} or an unknown id
     *           yields an empty set.
     * @return the ancestors found by following parent ids upward
     */
    public Set<T> getParentList(Object id) {
        Set<T> parentSet = new HashSet<>();
        T node = findById(id);
        if (Objects.nonNull(node)) {
            getParent(getFieldValue(node, this.pIdName), parentSet);
        }
        return parentSet;
    }

    /**
     * Walks up the parent-id chain starting at {@code id}, adding each node found to
     * {@code parentSet}. Stops at a blank parent id, an unknown id, or an id already visited
     * (a cycle).
     *
     * @param id        id of the first ancestor to look up
     * @param parentSet receives the ancestors
     */
    private void getParent(Object id, Set<T> parentSet) {
        Set<Object> visitedIds = new HashSet<>();
        Object currentId = id;
        while (!isBlank(currentId) && visitedIds.add(currentId)) {
            T parent = findById(currentId);
            if (parent == null) {
                return;
            }
            parentSet.add(parent);
            currentId = getFieldValue(parent, this.pIdName);
        }
    }

    /** Returns the first node whose id field equals {@code id}, or {@code null} if none. */
    private T findById(Object id) {
        if (id == null) {
            return null;
        }
        for (T node : this.nodeList) {
            if (Objects.equals(id, getFieldValue(node, this.idName))) {
                return node;
            }
        }
        return null;
    }

    /**
     * Tells whether the given field of {@code tree} is {@code null} or the empty string.
     * Numeric values (boxed or primitive) can never equal {@code ""}, so for them this is a plain
     * null check.
     *
     * @param tree    node to inspect
     * @param pIdName name of the field to read
     * @return {@code true} when the value is {@code null} or {@code ""}
     */
    private boolean isNullOrBlank(T tree, String pIdName) {
        return isBlank(getFieldValue(tree, pIdName));
    }

    /** {@code true} for {@code null} and for the empty string. */
    private static boolean isBlank(Object value) {
        return value == null || "".equals(value);
    }

    /**
     * Finds a field by name on {@code clazz} or the nearest superclass that declares it.
     *
     * @throws NoSuchFieldException if no class in the hierarchy declares {@code fieldName}
     */
    private static Field findField(Class<?> clazz, String fieldName) throws NoSuchFieldException {
        for (Class<?> c = clazz; c != null && c != Object.class; c = c.getSuperclass()) {
            try {
                return c.getDeclaredField(fieldName);
            } catch (NoSuchFieldException ignored) {
                // Not declared at this level; keep looking in the superclass.
            }
        }
        throw new NoSuchFieldException(fieldName);
    }

    /**
     * Reads a field value by reflection (bypassing access checks).
     *
     * @param bean      object to read from
     * @param fieldName field name, declared on the bean's class or a superclass
     * @return the field value
     * @throws RuntimeException if the field is missing or cannot be read (cause preserved)
     */
    private Object getFieldValue(Object bean, String fieldName) {
        Class<? extends Object> clazz = bean.getClass();
        try {
            Field field = findField(clazz, fieldName);
            field.setAccessible(true);
            return field.get(bean);
        } catch (NoSuchFieldException | IllegalAccessException e) {
            throw new RuntimeException("无法根据对象[" + clazz.getSimpleName() + "]的属性名[" + fieldName + "]获取属性值", e);
        }
    }

    /**
     * Writes a field value by reflection (bypassing access checks).
     *
     * @param bean      object to write to
     * @param fieldName field name, declared on the bean's class or a superclass
     * @param value     value to assign
     * @throws RuntimeException if the field is missing or cannot be written (cause preserved)
     */
    private void setFieldValue(Object bean, String fieldName, Object value) {
        Class<? extends Object> clazz = bean.getClass();
        try {
            Field field = findField(clazz, fieldName);
            field.setAccessible(true);
            field.set(bean, value);
        } catch (NoSuchFieldException | IllegalAccessException e) {
            throw new RuntimeException("无法给对象[" + clazz.getSimpleName() + "]的属性名[" + fieldName + "]设置值", e);
        }
    }

    /** @return the flat node list */
    public List<T> getNodeList() {
        return nodeList;
    }

    /**
     * Replaces the flat node list. Note: the entity class used for copies is still the one
     * taken from the constructor's list.
     *
     * @param nodeList new node list
     */
    public void setNodeList(List<T> nodeList) {
        this.nodeList = nodeList;
    }

    /** @return the sibling comparator, or {@code null} */
    public Comparator<T> getComparator() {
        return comparator;
    }

    /** @param comparator sibling comparator, or {@code null} to keep list order */
    public void setComparator(Comparator<T> comparator) {
        this.comparator = comparator;
    }

    /** @return name of the node id field */
    public String getIdName() {
        return idName;
    }

    /** @param idName name of the node id field */
    public void setIdName(String idName) {
        this.idName = idName;
    }

    /** @return name of the parent id field */
    public String getPIdName() {
        return pIdName;
    }

    /** @param pIdName name of the parent id field */
    public void setPIdName(String pIdName) {
        this.pIdName = pIdName;
    }

    /** @return name of the children collection field */
    public String getChildrenName() {
        return childrenName;
    }

    /** @param childrenName name of the children collection field */
    public void setChildrenName(String childrenName) {
        this.childrenName = childrenName;
    }
}
2. 测试
tree.json
[
{"id":1, "name": "生产部", "pId":null, "order":1},
{"id":2, "name": "生产部一室", "pId":1, "order":2},
{"id":3, "name": "生产部二室", "pId":1, "order":1},
{"id":4, "name": "开发部", "pId":null, "order":2},
{"id":5, "name": "开发部一室", "pId":4, "order":2},
{"id":6, "name": "开发部二室", "pId":4, "order":1},
{"id":7, "name": "测试部", "pId":null, "order":1},
{"id":8, "name": "销售部", "pId":null, "order":4},
{"id":9, "name": "销售部一室", "pId":8, "order":1},
{"id":10, "name": "销售部二室", "pId":8, "order":2},
{"id":11, "name": "办公室", "pId":-1, "order":0},
{"id":12, "name": "销售部一室11", "pId":9, "order":0}
]
Department.java
/**
 * Department node used by the TreeModel example.
 *
 * <p>{@code @Data} (presumably Lombok) generates the accessors, equals/hashCode and toString.
 * TreeModel reads {@code id}, {@code pId} and {@code children} by reflection under its default
 * field names, so renaming these fields requires configuring TreeModel accordingly.
 */
@Data
public class Department {
/** Node id. */
private Integer id;
/** Department name. */
private String name;
/** Parent node id; {@code null} marks a root department for TreeModel. */
private Integer pId;
/** The "order" value from tree.json; TreeModel only uses it if a comparator given to it reads it. */
private Integer order;
/** Child departments; filled in on the copies TreeModel builds. */
private List<Department> children;
}
测试类
//读取tree.json文件
InputStream inputStream = ClassLoader.getSystemClassLoader().getResourceAsStream("tree.json");
String jsonContent = IOUtils.toString(inputStream);
//解析json字符串
List<Department> list = JSONArray.parseArray(jsonContent).toJavaList(Department.class);
// System.out.println(list);
TreeModel<Department> treeModel = new TreeModel(list);
List<Department> treeList = treeModel.treeList();
treeList.forEach(o-> System.out.println(o));
List<Department> child = treeModel.getChildren(8);
System.out.println("======");
List<Department> allList = new ArrayList<>();
treeModel.childrenToList(treeList, allList);
allList.forEach(o-> System.out.println(o));
System.out.println("======");
System.out.println(child);
System.out.println("======");
//注意 id为Integer, 不能传字符串
Set<Department> parentList = treeModel.getParentList(12);
System.out.println(parentList);