机器学习实战--决策树(二)

                              使用Matplotlib注解绘制树形图

annotation是注解工具,注解功能可以对文字着色,并提供多种形状以供选择,还可以反转箭头。创建名为treePlotter.py的新文件。

使用文本注解绘制树节点:

 

#定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")
#绘制带箭头的注释
#nodeTxt:节点文本
#centerPt:文本框的中心位置
#parentPt:箭头线的起点
#nodeType:文本框类型
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    #xycoords和textcoords分别代表坐标系,va和ha分别指文字在框中横向和纵向的位置
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
                            xytext=centerPt,textcoords='axes fraction',\
                            va='center',ha='center',bbox=nodeType,arrowprops=arrow_args)
def createPlot():
    #新建画布并设置画布的背景
    fig = plt.figure(1,facecolor='white')
    #清空
    fig.clf()
    #frameon是否绘制坐标轴的矩形
    createPlot.ax1 = plt.subplot(111,frameon=False)
    plotNode('决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('叶节点',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

得到的结果如下:

此外绘制图形必须知道叶子节点个数,以便确定x的长度,还需要知道树有多少层,以便确定y'轴高度。定义两个函数getNumLeafs()和getTreeDepth():

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs +=1
    return numLeafs
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

 源码较为简单,不做深度解析。值得注意的是获取键之后不能直接用索引来获取某一个键,需要先用list函数序列化,否则会报错。

为了节省从数据中创建树字典的时间,使用函数retrieveTree输出预先存储的树信息。代码如下:

#构建测试数据集
def retrieveTree(i):
    listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},\
            {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
    return listOfTrees[i]

以下为测试过程: 

>>>myTree = treePlotter.retrieveTree(0)

>>>treePlotter.getNumLeafs(myTree)

3

接下来组合一起生成完整的决策树:

#根据字典绘制图形
def plotMidText(cntrPt,parentPt,txtString):
    xMid = (parentPt[0]+cntrPt[0])/2.0
    yMid = (parentPt[1]+cntrPt[1])/2.0
    createPlot.ax1.text(xMid,yMid,txtString)
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = myTree[firstStr]
    plotTree.y0ff = plotTree.y0ff-1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.x0ff = plotTree.x0ff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
            plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
    plotTree.y0ff = plotTree.y0ff + 1.0/plotTree.totalD
def createPlot(inTree):
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    axprops = dict(xticks=[],yticks=[])
    createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.x0ff = -0.5/plotTree.totalW
    plotTree.y0ff = 1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

测试的结果如下:

 

以上代码主要是定位节点的坐标,操作比较繁琐。具体原理请参考定位详解。关于 annotate函数的详解请参考annotate

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值