《机器学习实战》决策树可视化详解

最近在艰难的啃食机器学习,没有什么基础,又想把每一句都弄懂,这里和大家分享下,中间有一部分因为版本不同(小白用的3.0,书上好像2.0),已经做了修正。还留有一些疑问没有想明白,希望可以和大家交流下:1)书上的plotTree.totalW=float(getNumLeafs(inTree))作为全局变量使用,但是我的程序识别不出来它是全局变量,所以我的办法是在plotTree中又重新定义了一遍plotTree.totalW=float(getNumLeafs(myTree)),不知道还有没有什么别的好一点的方法
2)画图的坐标定位,用脑子想,真的好难整,为什么一开始cntrPt,parentPt坐标定位都是(0.5,1.0)而且parentPt好像一直是(0.5,1.0)希望大家可以多多指教呀

import matplotlib.pyplot as plt

#树信息
def receieveTree(i):
    listOfTrees=[{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
                 {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
    return listOfTrees[i]

#程序3-5,使用文本注解绘制树节点
decisionNode=dict(boxstyle="sawtooth",fc='0.8')
#创建一个字典,注解框的边缘是波浪线,fc是颜色深度
leafNode=dict(boxstyle='round4',fc='0.8')
arrow_args=dict(arrowstyle="<-")
#箭头样式



#执行实际的绘图功能
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    #nodeTxt节点注释文本,centerPt节点中心坐标,parentPt起始坐标,nodeType节点类型
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',
                            va='center',ha="center",bbox=nodeType,arrowprops=arrow_args)
    #annotate,提供注解工具;coords表示坐标;axes fraction表示(0,0)是左下角,(1,1)是右上角

#程序3-6获取叶节点的数目和树的层数
#首先获取总的叶子节点数
def getNumLeafs(myTree):
    numLeafs=0
    firstStr=list(myTree.keys())[0]
    #获取字典的一个key,也就是判断节点标签,list(myTree.keys())[0]以列表的形式返回,且为列表的第一个元素作为关键字
    secondDict=myTree[firstStr]
    #获取key的值,有可能是字典类型的,也有可能是非字典类型的
    for key in secondDict.keys():
        # 当secondDict中含有关键字key时
        if type(secondDict[key])==dict :
            #如果secondDict关键字的值是字典类型的
            numLeafs=1+getNumLeafs(secondDict[key])
            #叶子节点数+1,因为secondDict[key]是字典类型的,所以可以将其作为树,继续遍历叶子节点
        else: numLeafs+=1
        #叶子节点数+1
    return numLeafs
    #返回值很重要,递归函数继续进行的关键



#再次要获取树的深度
def getTreeDepth(myTree):
    maxDepth=0
    firstStr=list(myTree.keys())[0]
    #获取树的关键字,也就是父节点(判断节点)
    secondDict=myTree[firstStr]
    #获取关键字的值
    for key in secondDict.keys():
        # 当secondDict中含有关键字key时
        if type(secondDict[key])==dict:
            #如果获取的值是一个字典类型,深度+1,将secondDict[key]作为树,继续进行遍历
            thisDepth=1+getTreeDepth(secondDict[key])
        else:
            thisDepth=1
        if thisDepth>maxDepth:
            maxDepth=thisDepth
        #获取最大的深度,
    return maxDepth
    #返回值是遍历的关键



#程序3-7 plotTree函数
#创建绘图区,计算树形图的全局尺寸,并调用递归函数plotTree
def createPlot(inTree):
    fig=plt.figure(1,facecolor='white',)
    #创建一个新图形,facecolor背景色
    fig.clf()
    #清空绘图区
    axprops = dict(xticks=[], yticks=[])
    #定义横纵坐标轴,无内容
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
    #创建一个子图,并在子图上继续绘制subplot(nrows, ncols, index, **kwargs),将图形分为1行,1列,index指定获取第1个区域
    #ax1表示一个子图,ax2 = ax1.twinx() 让两个子图的X坐标轴一样,参考链接https://blog.csdn.net/htuhxf/article/details/82986440
    #frameon是否绘制边框
    plotTree.totalW=float(getNumLeafs(inTree))
    #获取树的宽度
    plotTree.totalD=float(getTreeDepth(inTree))
    #获取树的深度
    plotTree.x=-0.5/plotTree.totalW;plotTree.y=1.0
    #使用plotTree.x以及plotTree.y来追踪节点的位置,以及放置下一个节点的位置
    #为了防止图形右偏,将图形首先横坐标初始化,例如:totalW=3,第一个坐标为1/3,此时希望将文本标签能够在1/3处均匀分布,所以1/3-0.5/3,纵坐标初始化从1开始,自上向下画
    plotTree(inTree,(0.5,1.0),'')
    #调用plotTree函数
    plt.show()

#在父子节点间填充文本信息
def plotMidText(cntrPt,parentPt,txtString):
    #cntrPt文本中心点,parentPt指向文本中心点【横坐标,纵坐标】,文本字符
    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)
    #文本字符,说明坐标

def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    #获取叶子节点数和树的深度
    firstStr=list(myTree.keys())[0]
    #获取树的第一个关键字
    plotTree.totalW = float(getNumLeafs(myTree))
    # 获取树的宽度,叶子节点的数目
    plotTree.totalD = float(getTreeDepth(myTree))
    # 获取树的深度,判断节点的数目
    cntrPt=(plotTree.x+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y)
    #cntrPt文本中心点,判断点的位置
    plotMidText(cntrPt,parentPt,nodeTxt)
    #把键画在分支上
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    #执行绘图功能
    secondDict=myTree[firstStr]
    #关键字的值
    plotTree.y=plotTree.y-1.0/plotTree.totalD
    #需要依次递减Y的坐标,自上往下画
    for key in secondDict.keys():
        if type(secondDict[key])==dict:
            #如果是字典则为判断节点
            plotTree(secondDict[key],cntrPt,str(key))
            #继续进行遍历
        else:
            #否则画出叶子节点
            plotTree.x=plotTree.x+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.x,plotTree.y),cntrPt,leafNode)
            plotMidText((plotTree.x,plotTree.y),cntrPt,str(key))
    plotTree.y=plotTree.y+1.0/plotTree.totalD
    #在绘制了所有的叶子节点后,增加全局变量Y的偏移

画出来的图是这个样子滴
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值