1. 场景
在日常开发中,我们会常常使用到树形结构表,比如:菜单、部门、分组等。在关系型数据库中对树形结构表的设计通常来说有4种方案,这些方案能在某些方面提高查询效率或更新效率,q各位看官请自行搜索。但是,由于部分关系型数据库本身不支持联级查询(比如mysql),所以一般情况下,我们需要在代码当中完成树形结构列表的组装。
通常情况下,我们会使用邻近表的方式,在树形结构表中存储当前数据行的父级数据行id(pId或parentId),然后使用树形工具类完成树形组装。
2.说明
将一个数据列表组装成树形结构,需将每个数据行都是视为树的一个节点,根据id,parentId两个标识寻找数据之间的上下级关系。每个节点都有一个子节点列表(children)。
自定义的工具类支持对同一个父节点的节点进行排序,以及由下而上地统计节点计数。除了必须的children属性外,不侵入其它属性(比如用于支持排序和层级数量统计的字典)。
3.实现
3.1 定义树节点基类
import lombok.Data;
import java.util.ArrayList;
import java.util.List;
/**
* 树节点
* <p>
* 只增加children属性,不侵入多余字段(如排序、计数等)
* 配合TreeUtils使用
*
* @author jackLee
*/
@Data
public abstract class TreeNode<T> {
/**
* 子节点列表
*/
private List<T> children = new ArrayList<>();
}
2. 定义自定义注解
import java.lang.annotation.*;
/**
* 树字段注解
*
* @author ysgj
*/
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface TreeField {
/**
* 字段类型
*/
FieldType type();
/**
* 字段类型
*/
enum FieldType {
ID,
PARENT_ID,
SORT,
COUNT
}
}
3. 树形结构工具类
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import com.ysgj.common.exception.ServiceException;
import com.ysgj.common.utils.StringUtils;
import lombok.extern.slf4j.Slf4j;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.*;
/**
* 树形工具类
* <p>
* 不侵入字段(children字段是必须的), 通过注解指定节点标识字段
* 支持排序和层级计数统计
*
* @author jackLee
*/
@Slf4j
public class TreeUtils {
// 默认字段名
private static final String ID_FIELD = "id";
private static final String PARENT_FIELD = "parentId";
public static <T extends TreeNode<T>> List<T> build(List<T> list, Object... rootIds) {
return build(list, true, rootIds);
}
public static <T extends TreeNode<T>> List<T> build(List<T> list, boolean withCount, Object... rootIds) {
if (CollUtil.size(list) <= 1) {
return list;
}
long start = System.currentTimeMillis();
List<Field> fieldList = findAllFields(list);
Map<TreeField.FieldType, Field> treeFieldMap = findTreeFields(fieldList);
List<T> tree = buildTreeSmart(list, rootIds != null ? Arrays.asList(rootIds) : null, treeFieldMap, withCount);
log.info("TreeUtils.build costs time: {} ms", System.currentTimeMillis() - start);
return tree;
}
// 反射获取所有字段
private static <T> List<Field> findAllFields(List<T> list) {
T t = list.get(0);
List<Field> fieldList = new ArrayList<>();
Class<?> tempClass = t.getClass();
while (tempClass != null) {
Field[] fields = tempClass.getDeclaredFields();
if (fields.length > 0) {
fieldList.addAll(Arrays.asList(fields));
}
tempClass = tempClass.getSuperclass();
}
return fieldList;
}
// 获取树形结构相关字段
private static Map<TreeField.FieldType, Field> findTreeFields(List<Field> fieldList) {
Map<TreeField.FieldType, Field> fieldMap = new HashMap<>();
for (Field field : fieldList) {
if (Modifier.isStatic(field.getModifiers())) {
continue;
}
if (StringUtils.equals(ID_FIELD, field.getName())) {
fieldMap.putIfAbsent(TreeField.FieldType.ID, field);
} else if (StringUtils.equals(PARENT_FIELD, field.getName())) {
fieldMap.putIfAbsent(TreeField.FieldType.PARENT_ID, field);
}
TreeField treeField = field.getAnnotation(TreeField.class);
if (treeField != null) {
fieldMap.put(treeField.type(), field);
}
}
//校验字段
checkTreeFields(fieldMap);
return fieldMap;
}
// 字段类型检查
private static void checkTreeFields(Map<TreeField.FieldType, Field> fieldMap) {
Assert.notNull(fieldMap.get(TreeField.FieldType.ID), () -> new ServiceException("请设置树节点标识"));
Assert.notNull(fieldMap.get(TreeField.FieldType.PARENT_ID), () -> new ServiceException("请设置树节点父标识"));
Field sortField = fieldMap.get(TreeField.FieldType.SORT);
if (sortField != null) {
boolean isAllowType = Integer.class.equals(sortField.getType());
Assert.isTrue(isAllowType, () -> new ServiceException("排序字段仅支持Integer类型"));
}
Field countField = fieldMap.get(TreeField.FieldType.COUNT);
if (countField != null) {
boolean isAllowType = Integer.class.equals(countField.getType());
Assert.isTrue(isAllowType, () -> new ServiceException("计数字段仅支持Integer类型"));
}
for (Field field : fieldMap.values()) {
field.setAccessible(true);
}
}
/**
* 转为树形结构
*
* @param treeNodes 节点列表
* @return 树形结构列表
*/
private static <T extends TreeNode<T>> List<T> buildTreeSmart(List<T> treeNodes, List<Object> rootIds, Map<TreeField.FieldType, Field> fieldMap, boolean withCount) {
Field idField = fieldMap.get(TreeField.FieldType.ID);
Field parentIdField = fieldMap.get(TreeField.FieldType.PARENT_ID);
Field countField = fieldMap.get(TreeField.FieldType.COUNT);
Field sortField = fieldMap.get(TreeField.FieldType.SORT);
List<T> tree = new ArrayList<>();
try {
Map<Object, T> mapping = new HashMap<>();
for (T node : treeNodes) {
if (idField.get(node) != null) {
mapping.put(idField.get(node), node);
}
}
//自动查找根节点
if (CollUtil.isEmpty(rootIds)) {
rootIds = findRootIds(treeNodes, idField, parentIdField);
}
for (T node : treeNodes) {
Object parentId = parentIdField.get(node);
if (rootIds.contains(parentId)) {
tree.add(node);
continue;
}
T parent = mapping.get(parentId);
if (parent != null) {
if (parent.getChildren() == null) {
parent.setChildren(new ArrayList<>());
}
parent.getChildren().add(node);
}
}
if (withCount && countField != null) {
count(tree, countField);
}
sort(tree, sortField);
} catch (Exception e) {
log.error(e.getMessage(), e);
}
return tree;
}
// 自动查找root节点
private static <T extends TreeNode<T>> List<Object> findRootIds(List<T> treeNodes, Field idField, Field parentIdField) throws IllegalAccessException {
List<Object> ids = new ArrayList<>();
List<Object> pIds = new ArrayList<>();
for (T t : treeNodes) {
ids.add(idField.get(t));
pIds.add(parentIdField.get(t));
}
pIds.removeAll(ids);
return pIds;
}
//统计计数
private static <T extends TreeNode<T>> Integer count(List<T> tree, Field countField) throws IllegalAccessException {
int count = 0;
if (tree != null && tree.size() > 0) {
for (T treeNode : tree) {
int nodeCount = (Integer) countField.get(treeNode);
int childrenCounts = count(treeNode.getChildren(), countField);
countField.set(treeNode, nodeCount + childrenCounts);
count += nodeCount + childrenCounts;
}
}
return count;
}
//排序
@SuppressWarnings("unchecked")
private static <T extends TreeNode<T>> void sort(List<T> tree, Field sortField) {
if (sortField != null) {
for (T treeNode : tree) {
if (treeNode.getChildren() != null && treeNode.getChildren().size() > 0) {
treeNode.getChildren().sort(comparator(sortField));
}
}
tree.sort(comparator(sortField));
}
}
private static Comparator comparator(Field sortField) {
return (o1, o2) -> {
try {
int sort1 = sortField.get(o1) != null ? (Integer) sortField.get(o1) : 0;
int sort2 = sortField.get(o2) != null ? (Integer) sortField.get(o2) : 0;
return sort1 - sort2;
} catch (IllegalAccessException e) {
log.error(e.getMessage(), e);
}
return 0;
};
}
}
点个赞点个赞!