目标:显示每层节点数不定的递归树,不同的类别用不同的颜色表示
所用python库:matplotlib
1. 初始化窗口
步骤:
- 计算所需画面大小,包括图幅高度height 和图幅宽度width。图幅宽度主要取决于叶子节点数目。a. 首先需要计算叶子节点的数目,b. 然后需要定义node的半径radius和同一层级两个node之间的距离nodeInterval,本次设置nodeInterval为radius的n倍,n可以自己给定。最后图幅所需要的宽度width=(treeWidth + 1) * nodeInterval。默认使用正方形画布,则height = width。最后用计算出来的height, width设置画布x轴和y轴的范围。
- 计算层间距disVec。首先计算树的深度maxDeepth,然后用图幅高度height除以树的深度即可得出每层之间的纵向距离,即disVec = height/(maxDeepth+1);
- 最后计算root的xy坐标:xStart, yStart。root 位于图幅正中心,距离页面顶部一个 disVec 距离。
- 函数返回值:返回根节点坐标、node的半径radius 和 树的纵向深度间隔disVec(用于递归画下一层节点)。因为是递归构建,所以当前子树也就是当前父节点一直在变化,所以需要事先计算好根节点,作为递归的传入参数。
def creatWindow(self, radius = 300):
'''
———————— O ————————
————O———— ————O————
-O--O--O- -O--O--O-
radius = 2 #节点的半径
nodeInterval = 4 * radius #定义为两个节点之间的距离,即--
'''
#定义节点间距
n = 1
nodeInterval = n * radius
# 定义图幅长宽,默认为正方形图幅,height=width
tree = self.tree
treeWidth = self.getNumLeafs(tree)
treeDepth = self.getMaxDepth(tree)
width = (treeWidth + 1) * nodeInterval
height = width
aspect = height/width
fig = plt.figure(figsize=(11, 11*aspect))
plt.xlim(0, width)
plt.ylim(0, height)
#定义层间距
disVec = height/(treeDepth+1)
#给定根节点坐标
xStart = int(width/2)
yStart = height - disVec
return fig, xStart, yStart, radius, width, disVec
重点:计算树的宽度和深度。
- 计算树的宽度:只要存在children,就进入下一层,直到遇到叶子节点. 每遇到一个叶子节点计数就加1,即numLeafs +=1. 遍历完整棵树之后,numLeafs就是树的宽度。
def getNumLeafs(self, tree):
numLeafs = 0
children = tree.children
numChildren = len(children)
if numChildren > 0:
for childNode in children:
branchNumLeafs = self.getNumLeafs(childNode)
numLeafs += branchNumLeafs
else:
numLeafs += 1
return numLeafs
- 计算树的深度:只要存在children,就进入一下层,同时深度值加1,即:深度 = 1 + 子树深度。当遇到叶子节点是,返回深度值1。同一个parent下面的每个child, 遍历计算深度后,只取他们中最大的深度值。
def getMaxDepth(self, tree):
''' 如果只有根节点,最大深度为1;
如果有子节点,maxDepth = 1 + getHeight(child), 1 为根节点自身深度
遍历所有子节点,取最深的那个深度
'''
maxDepth = 0
children = tree.children
numChildren = len(children)
if numChildren > 0:
for childNode in children:
branchDepth = 1 + self.getMaxDepth(childNode)
if branchDepth > maxDepth:
maxDepth = branchDepth
else:
maxDepth = 1
return maxDepth
2. 画递归树
思路:
整个递归树包含 节点 和 边 两类。当我们把每个节点看作一棵小树,一颗只包含边的小树。这个问题就简化为如何画出一个节点及其存在可能的边。依然是递归的方法。
过程:
- 画当前节点;
- 判断当前是否存在children,如果存在就画边,负责就是叶子节点,不画边。
- 更新节点,直到叶节点。
难点:确定子节点位置坐标。
'''
O
O O
-O--O--O- -O--O--O-
'''
如上红色父节点存在三个子节点,每个节点占据父节点空间的1/3, 即:
childWidth = parentWidth/numChildren,则它们x轴坐标分别为:
childWidth *(0+1/2);
childWidth *(1 + 1/2),;
childWidth *(2+1/2)。
因此可以得出,子节点位置坐标为childWidth *(i+1/2) =parentWidth/i * *(i+1/2) , i为子节点个数。
对于左子树,按照上述计算即可。但是对于右子树,要根据父节点的左侧范围计算坐标,上述例子右侧节点的左侧范围为图幅的1/2处,所以子节点的坐标为:parentWidth *1 + childWidth *(i+1/2).
最后综合左、右子树,最终的子节点的坐标计算公式为:parentWidth * j + parentWidth / i*(i+1/2)。其中,j为上一层父节点的编号, i为子节点个数。
画递归树的函数draw 一定要单独写一个函数,因为当前节点一直在变。
def draw(self, tree, xStart, yStart, radius, width, disVec, j=0):
#先画当前节点
self.drawNode(tree, xStart, yStart, radius)
#如果有子节点,递归画子节点
children = tree.children
numChildren = len(children)
if numChildren > 0:
childBrachWidth = width/numChildren
for i, childNode in enumerate(children):
xEnd = width * j + childBrachWidth * (i + 0.5)
yEnd = yStart - disVec
self.drawEdge(xStart, yStart, xEnd,yEnd)
self.draw(childNode, xEnd, yEnd, radius, childBrachWidth, disVec, i)
def drawNode(self, node, x, y, radius=20):
name = node.data.componentName
color = self.colors[node.data.label]
plt.scatter(x,y,s = radius, c=color, edgecolors=color, marker="o")
plt.text(x-radius/3, y-radius, name, fontsize=13)
def drawEdge(self, xStart, yStart, xEnd, yEnd):
x = (xStart, xEnd)
y = (yStart, yEnd)
plt.plot(x,y,'g-')
3. 最后将上述函数整理到show_Tree类中,并用showTree函数调用
class show_Tree:
def __init__(self, tree, colors = None):
self.tree = tree
self.colors = colors
def showTree(self):
tree = self.tree
#1. 构建窗口
_, xStart, yStart, radius, width, disVec = self.creatWindow()
#2. 画树
self.draw(tree, xStart, yStart, radius, width, disVec)
#3. 保存即显示
f = plt.gcf()
plt.legend()
#窗口最大化再保存比较清楚
plt.show()
f.savefig('splitTree.png')
plt.close()
def get_colour_name(requested_colour):
try:
closest_name = actual_name = webcolors.rgb_to_name(requested_colour)
except ValueError:
closest_name = closest_colour(requested_colour)
actual_name = None
return actual_name, closest_name
def randomColor(labels):
num = len(labels)
colors = [colorsys.hsv_to_rgb(i/num, 1, 1) for i in range(num)]
colors = [(int(c[0]) *255, int(c[1]) *255, int(c[2]) *255) for c in colors]
return colors
labels = []#### the category of component
colors = randomColor(labels)
colorDict = {key:get_colour_name(value)[0] for key,value in zip(labels, colors)}
p = show_Tree(tree, colorDict)
p.showTree()
结果展示: