需求
将平级（扁平）的数组构造成树形结构的数组：每个元素通过 ParentId 指向其父节点的 Id。
工具类
reflect.go
// @Author xzx 2023/3/21 17:21:00
package util
import "reflect"
// GetId returns the value of the integer field named "Id" on the struct that
// t points to, converted to int. Both signed and unsigned integer fields are
// accepted. It panics if t is nil, if *t is not a struct, if the struct has no
// "Id" field, or if that field is not an integer kind.
//
// The field is read through the pointer (Elem) instead of dereferencing *t:
// passing *t to reflect.ValueOf would copy the entire struct on every call,
// reading every other field too (e.g. a Children slice that another goroutine
// may be writing at the same time).
func GetId[T any](t *T) int {
	field := reflect.ValueOf(t).Elem().FieldByName("Id")
	switch field.Kind() {
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return int(field.Uint())
	default:
		// Signed integer kinds; any other kind (or a missing field) panics in
		// Int, exactly as before.
		return int(field.Int())
	}
}
// GetParentId returns the value of the integer field named "ParentId" on the
// struct that t points to, converted to int. Both signed and unsigned integer
// fields are accepted. It panics if t is nil, if *t is not a struct, if the
// struct has no "ParentId" field, or if that field is not an integer kind.
//
// The field is read through the pointer (Elem) rather than via *t, so only
// ParentId is read and the struct is not copied on every call.
func GetParentId[T any](t *T) int {
	field := reflect.ValueOf(t).Elem().FieldByName("ParentId")
	switch field.Kind() {
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return int(field.Uint())
	default:
		// Signed integer kinds; anything else (or a missing field) panics in
		// Int, as the original did.
		return int(field.Int())
	}
}
tree.go
// Package util @Author xzx 2023/3/21 17:21:00
package util
// BuildTree links the flat slice list into a forest and returns its top-level
// nodes: every element whose ParentId equals rootId, in list order. The
// returned slice is never nil.
//
// For every reachable node that has at least one child, setChildren is called
// once with that node and its children (in list order). Nodes without children
// are never passed to setChildren. Id and ParentId are read via GetId and
// GetParentId.
//
// Elements are grouped by ParentId in a single pass, so each node's children
// are found with one map lookup instead of a full rescan of list (O(n) rather
// than O(n²)). Every node is expanded at most once, so id cycles (for example a
// node whose Id equals rootId) terminate instead of recursing forever. Such a
// cycle is still present in the resulting structure.
//
// @Author xzx 2023-03-21 17:22:46
func BuildTree[T any](list []*T, rootId int, setChildren func(parent *T, children []*T)) []*T {
	// parent id -> children, each group kept in list order.
	childrenOf := make(map[int][]*T, len(list))
	for _, n := range list {
		pid := GetParentId(n)
		childrenOf[pid] = append(childrenOf[pid], n)
	}

	visited := make(map[*T]bool, len(list))
	// link attaches the descendants of parent depth-first (pre-order, as the
	// original recursive version did).
	var link func(parent *T)
	link = func(parent *T) {
		if visited[parent] {
			return // already expanded: malformed ids form a cycle
		}
		visited[parent] = true
		children := childrenOf[GetId(parent)]
		if len(children) == 0 {
			return // leaf: setChildren is only called for nodes with children
		}
		setChildren(parent, children)
		for _, child := range children {
			link(child)
		}
	}

	roots := childrenOf[rootId]
	retList := make([]*T, 0, len(roots))
	for _, root := range roots {
		link(root)
		retList = append(retList, root)
	}
	return retList
}
// recursionFn attaches the children of t, and recursively those of every
// descendant. It calls setChildren for each node that has at least one child,
// passing the children collected by getChildList in list order.
//
// A leaf returns immediately after its single getChildList scan. A separate
// "does this child have children?" check before recursing would rescan list
// once more per child for no benefit.
//
// NOTE(review): there is no guard against id cycles; malformed input whose
// ParentId links loop back will recurse until the stack overflows.
func recursionFn[T any](list []*T, t *T, setChildren func(parent *T, children []*T)) {
	childList := getChildList(list, t)
	if len(childList) == 0 {
		return // leaf: nothing to attach, nothing to descend into
	}
	setChildren(t, childList)
	for _, tChild := range childList {
		recursionFn(list, tChild, setChildren)
	}
}
// getChildList returns the direct children of t: the elements of list whose
// ParentId equals t's Id, kept in list order. When t has no children the
// result is an empty but non-nil slice.
//
// @Author xzx 2023-03-21 17:14:15
func getChildList[T any](list []*T, t *T) []*T {
	parentID := GetId(t)
	children := []*T{}
	for i := range list {
		if GetParentId(list[i]) == parentID {
			children = append(children, list[i])
		}
	}
	return children
}
func hasChild[T any](list []*T, item *T) bool {
return len(getChildList(list, item)) > 0
}
测试
package test
import (
"fmt"
"github.com/goccy/go-json"
"info_manage/pkg/util"
"testing"
)
// Menu is the fixture used by TestTreeBuild: a flat record that points at its
// parent through ParentId. util.BuildTree reads the Id and ParentId fields by
// name via reflection (util.GetId / util.GetParentId), so those field names
// must not be renamed.
type Menu struct {
Id int `json:"id"`
// ParentId is the Id of the parent menu; in this test 0 marks a top-level menu
// (TestTreeBuild passes rootId 0).
ParentId int `json:"parent_id"`
Name string `json:"name"`
// Children is assigned by the setChildren callback passed to BuildTree. It is
// left nil (encoded as JSON null) for menus without children.
Children []*Menu `json:"children"`
}
// TestTreeBuild builds the menu forest from a flat list and checks the exact
// nesting by comparing the JSON encoding against the expected tree. Menus 1 and
// 2 are roots, 3 and 4 hang under 1, and 5 hangs under 3. Leaves keep a nil
// Children slice, which encodes as null.
func TestTreeBuild(t *testing.T) {
	menus := []*Menu{
		{Id: 1, ParentId: 0, Name: "Menu 1"},
		{Id: 2, ParentId: 0, Name: "Menu 2"},
		{Id: 3, ParentId: 1, Name: "Submenu 1"},
		{Id: 4, ParentId: 1, Name: "Submenu 2"},
		{Id: 5, ParentId: 3, Name: "Subsubmenu 1"},
	}
	tree := util.BuildTree(menus, 0, func(p *Menu, c []*Menu) {
		p.Children = c
	})

	got, err := json.Marshal(tree)
	if err != nil {
		// Without a valid encoding there is nothing meaningful to compare.
		t.Fatalf("marshal tree: %v", err)
	}
	fmt.Println(string(got))

	const want = `[{"id":1,"parent_id":0,"name":"Menu 1","children":[{"id":3,"parent_id":1,"name":"Submenu 1","children":[{"id":5,"parent_id":3,"name":"Subsubmenu 1","children":null}]},{"id":4,"parent_id":1,"name":"Submenu 2","children":null}]},{"id":2,"parent_id":0,"name":"Menu 2","children":null}]`
	if string(got) != want {
		t.Fatalf("unexpected tree:\n got: %s\nwant: %s", got, want)
	}
}
=== RUN TestTreeBuild
[{"id":1,"parent_id":0,"name":"Menu 1","children":[{"id":3,"parent_id":1,"name":"Submenu 1","children":[{"id":5,"parent_id":3,"name":"Subsubmenu 1","children":null}]},{"id":4,"parent_id":1,"name":"Submenu 2","children":null}]},{"id":2,"parent_id":0,"name":"Menu 2","children":null}]
--- PASS: TestTreeBuild (0.00s)
PASS
优化
使用 goroutine
异步并行收集子节点，并优化部分函数。
// Package util @Author xzx 2023/3/21 17:21:00
package util
import "sync"
// BuildTree
// BuildTree[T any]
// @Description 构建树形结构
// @Author xzx 2023-03-21 17:22:46
// @Param list
// @Param rootId
// @Param setChildren
// @Return []*T
//
func BuildTree[T any](list []*T, rootId int, setChildren func(parent *T, children []*T)) []*T {
// 等待所有 goroutine 执行结束
wg := &sync.WaitGroup{}
// 存储最终的结果集
retList := make([]*T, 0, 0)
for _, t := range list {
// 找到根节点,进行递归构建
if GetParentId(t) == rootId {
// 递归构建出当前根节点下的所有子节点
wg.Add(1)
go recursionFn(list, t, setChildren, wg)
// 将根节点添加至结果集合中
retList = append(retList, t)
}
}
// 等待所有递归都结束再返回结果
wg.Wait()
return retList
}
// recursionFn links the whole subtree under t. For t and every descendant that
// has at least one child, it calls setChildren with that node and its children
// (collected by getChildList, in list order). It calls wg.Done exactly once,
// when the subtree is finished, so the caller must have called wg.Add(1) for it.
//
// The subtree is walked on the current goroutine instead of spawning a new
// goroutine per node. The per-node approach started one goroutine for every
// element, and never stopped spawning if the ids contained a cycle.
// Parallelism now comes only from the caller starting one recursionFn per
// top-level node. visited ensures each node is expanded at most once.
//
// NOTE(review): getChildList reads every element of list via reflection while
// recursionFn calls for other roots may be running setChildren concurrently.
// This is only race-free if GetId/GetParentId read just the Id/ParentId fields
// and not the whole struct. Confirm before relying on it.
func recursionFn[T any](list []*T, t *T, setChildren func(parent *T, children []*T), wg *sync.WaitGroup) {
	defer wg.Done()
	visited := map[*T]bool{}
	pending := []*T{t}
	for len(pending) > 0 {
		node := pending[len(pending)-1]
		pending = pending[:len(pending)-1]
		if visited[node] {
			continue
		}
		visited[node] = true
		children := getChildList(list, node)
		if len(children) == 0 {
			continue // leaf: nothing to attach
		}
		setChildren(node, children)
		// Push in reverse so children are expanded in list order (pre-order).
		for i := len(children) - 1; i >= 0; i-- {
			pending = append(pending, children[i])
		}
	}
}
// getChildList gathers the direct children of t: every element of list whose
// ParentId equals t's Id, in the order they appear in list. For a leaf the
// result is an empty, non-nil slice.
//
// @Author xzx 2023-03-21 17:14:15
func getChildList[T any](list []*T, t *T) []*T {
	var (
		parent = GetId(t)
		found  = make([]*T, 0)
	)
	for _, candidate := range list {
		if GetParentId(candidate) != parent {
			continue
		}
		found = append(found, candidate)
	}
	return found
}