【Python机器学习】使用Matplotlib注解绘制树形图——构造注解树

绘制一颗完整的树需要一些技巧,我们虽然有x、y做表,但是如何放置所有的树节点是个问题。

我们必须知道有多少个叶节点,以便可以正确确定x轴的长度;我们还需要知道树有多少层,以便可以正确确定y轴的高度。

这里我们定义两个新函数,来获取叶节点的数目和树的层数:

def getNumLeafs(myTree):
    numLeafs=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        #测试节点的数据类型是否为字典
        if type(secondDict[key]).__name__=='dict':
            numLeafs=numLeafs+getNumLeafs(secondDict[key])
        else:
            numLeafs=numLeafs+1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else:
            thisDepth=1
        if thisDepth>maxDepth:
            maxDepth=thisDepth
    return maxDepth

这两个函数有相同的结构。这里使用的数据结构说明了如何在Python字典类型中存储树信息。

第一个关键字是第一次划分数据集的类别标签,附带的数值表示子节点的取值。从第一个关键字处罚,我们可以遍历整棵树的所有子节点。使用Python提供的type()函数可以判断子节点是否为字典类型。如果子节点是字典类型,则该节点也是一个判断节点,需要递归调用getNumLeafs()函数。getNumLeafs()函数遍历整棵树,累计叶子节点的个数,则返回该数值。

第二个函数getTreeDepth()计算遍历过程中遇到判断节点的个数。该函数的终止条件是叶子节点,一旦到达叶子节点,则从递归调用中返回,并将计算树深度的变量加一。

预设树信息:

def retrieveTree(i):
    listOfTrees=[{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
                 {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
    return listOfTrees[i]

绘制函数:

def plotMidText(cntrPt,parentPt,txtString):
    #在父子节点间填充文本信息
    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
    createPlot.axl.text(xMid,yMid,txtString)
def plotTree(myTree,parentPt,nodeTxt):
    #计算宽与高
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=list(myTree.keys())[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    #标记子节点属性值
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    #减少y偏移
    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD

函数createPlot()是我们的主函数,它调用了plotTree(),函数plotTree又依次调用了plotMidText()。绘制树状图的很多工作就是在plotTree()函数中完成的,它首先计算树的宽和高。全局变量plotTree.totslW存储树的宽度,plotTree.totslD存储树的深度,我们使用这两个变量计算树节点的摆放位置,这样可以将树绘制在水平方向和垂直方向的中心位置。

plotTree()也是一个递归函数。

树的宽度用于计算放置判断节点的位置,主要的计算原则是将它放在所有叶子节点的中间,而不仅仅是它子节点的中间。同时我们使用plotTree.xOff和plotTree.yOff追踪已经绘制的节点位置,以及放置下一个节点的恰当位置。

之后,绘制子节点具有的特征值,或者沿着此分支向下的数据实例必须具有的特征值。使用函数plotMidText()计算父节点和子节点的中间位置,并在此处添加简单的文本标签信息。

然后,按比例减少全局变量plotTree.yOff,并标注此处将要绘制子节点,这些节点既可以是叶子节点也可以是判断节点,此处需要只保存绘制图形的轨迹。因为我们是自顶向下绘制图形,因此需要依次递减y坐标值,而不是递增y坐标值。然后程序采用函数getNumLeafs()和getTreeDepth()以相同的方式递归遍历整棵树,如果节点是叶子节点则在图形上画出叶子节点,否则递归调用plotTree()函数。在绘制了所有子节点之后,增加全局变量Y的偏移。

最后是createPlot(),它创建绘图区,计算树状图的全局尺寸,并调用递归函数plotTree()。

绘制效果,但是没有坐标轴标签:

如果变更字典,重新绘制:

  • 5
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值