绘制一颗完整的树需要一些技巧,我们虽然有x、y做表,但是如何放置所有的树节点是个问题。
我们必须知道有多少个叶节点,以便可以正确确定x轴的长度;我们还需要知道树有多少层,以便可以正确确定y轴的高度。
这里我们定义两个新函数,来获取叶节点的数目和树的层数:
def getNumLeafs(myTree):
numLeafs=0
firstStr=list(myTree.keys())[0]
secondDict=myTree[firstStr]
for key in secondDict.keys():
#测试节点的数据类型是否为字典
if type(secondDict[key]).__name__=='dict':
numLeafs=numLeafs+getNumLeafs(secondDict[key])
else:
numLeafs=numLeafs+1
return numLeafs
def getTreeDepth(myTree):
maxDepth=0
firstStr=list(myTree.keys())[0]
secondDict=myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth=1+getTreeDepth(secondDict[key])
else:
thisDepth=1
if thisDepth>maxDepth:
maxDepth=thisDepth
return maxDepth
这两个函数有相同的结构。这里使用的数据结构说明了如何在Python字典类型中存储树信息。
第一个关键字是第一次划分数据集的类别标签,附带的数值表示子节点的取值。从第一个关键字处罚,我们可以遍历整棵树的所有子节点。使用Python提供的type()函数可以判断子节点是否为字典类型。如果子节点是字典类型,则该节点也是一个判断节点,需要递归调用getNumLeafs()函数。getNumLeafs()函数遍历整棵树,累计叶子节点的个数,则返回该数值。
第二个函数getTreeDepth()计算遍历过程中遇到判断节点的个数。该函数的终止条件是叶子节点,一旦到达叶子节点,则从递归调用中返回,并将计算树深度的变量加一。
预设树信息:
def retrieveTree(i):
listOfTrees=[{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
{'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
return listOfTrees[i]
绘制函数:
def plotMidText(cntrPt,parentPt,txtString):
#在父子节点间填充文本信息
xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
createPlot.axl.text(xMid,yMid,txtString)
def plotTree(myTree,parentPt,nodeTxt):
#计算宽与高
numLeafs=getNumLeafs(myTree)
depth=getTreeDepth(myTree)
firstStr=list(myTree.keys())[0]
cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
#标记子节点属性值
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict=myTree[firstStr]
#减少y偏移
plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
函数createPlot()是我们的主函数,它调用了plotTree(),函数plotTree又依次调用了plotMidText()。绘制树状图的很多工作就是在plotTree()函数中完成的,它首先计算树的宽和高。全局变量plotTree.totslW存储树的宽度,plotTree.totslD存储树的深度,我们使用这两个变量计算树节点的摆放位置,这样可以将树绘制在水平方向和垂直方向的中心位置。
plotTree()也是一个递归函数。
树的宽度用于计算放置判断节点的位置,主要的计算原则是将它放在所有叶子节点的中间,而不仅仅是它子节点的中间。同时我们使用plotTree.xOff和plotTree.yOff追踪已经绘制的节点位置,以及放置下一个节点的恰当位置。
之后,绘制子节点具有的特征值,或者沿着此分支向下的数据实例必须具有的特征值。使用函数plotMidText()计算父节点和子节点的中间位置,并在此处添加简单的文本标签信息。
然后,按比例减少全局变量plotTree.yOff,并标注此处将要绘制子节点,这些节点既可以是叶子节点也可以是判断节点,此处需要只保存绘制图形的轨迹。因为我们是自顶向下绘制图形,因此需要依次递减y坐标值,而不是递增y坐标值。然后程序采用函数getNumLeafs()和getTreeDepth()以相同的方式递归遍历整棵树,如果节点是叶子节点则在图形上画出叶子节点,否则递归调用plotTree()函数。在绘制了所有子节点之后,增加全局变量Y的偏移。
最后是createPlot(),它创建绘图区,计算树状图的全局尺寸,并调用递归函数plotTree()。
绘制效果,但是没有坐标轴标签:
如果变更字典,重新绘制: