用matplotlib注解绘制决策树,机器学习实战第三章

这段代码纠结了很久,也调了很久,算是弄懂了,感觉还是应该写下来,以后可以看看,也方便有同样疑惑的同学节省时间

首先贴上treePlotter.py的代码,这个自定义模块实现了所有的函数用于画决策树

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle='sawtooth', fc='10')
leafNode = dict(boxstyle='round4',fc='0.8')
arrow_args = dict(arrowstyle='<|-')



def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',\
                             xytext=centerPt,textcoords='axes fraction',\
                             va='center', ha='center',bbox=nodeType,arrowprops\
                             =arrow_args)


def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict:
        if(type(secondDict[key]).__name__ == 'dict'):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return  numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict:
        if(type(secondDict[key]).__name__ == 'dict'):
            thisDepth = 1+getTreeDepth((secondDict[key]))
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def retrieveTree(i):
    #预先设置树的信息
    listOfTree = [{'no surfacing':{0:'no', 1:{'flipper':{0:'no', 1:'yes'}}}},
                  {'no surfacing':{0:'no', 1:{'flipper':{0:{'head':{0:'no', 1:'yes'}},1:'no'}}}},
                  {'a1':{0:'b1', 1:{'b2':{0:{'c1':{0:'d1',1:'d2'}}, 1:'c2'}}, 2:'b3'}}]
    return listOfTree[i]

def createPlot(inTree):
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    axprops = dict(xticks = [], yticks=[])
    createPlot.ax1 = plt.subplot(111,frameon = False,**axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW;plotTree.yOff = 1.0
    plotTree(inTree,(0.5,1.0), '')
    plt.show()

def plotMidText(cntrPt, parentPt,txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,\
              plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict:
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),\
                     cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
这段代码的核心内容是plotTree()函数,也是纠结最久的一块,尤其是计算plotTree.xOff的值

下面贴上主模块的代码

import treePlotter as tp

myTree = tp.retrieveTree(2)
tp.createPlot(myTree)
函数retriveTree() 构造了三个决策树,用字典表示,多构造几个树可以多调试几次,对函数更加清晰

运行结果如下



下面分段解析模块treePlotter.py

第一部分

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle='sawtooth', fc='10')
leafNode = dict(boxstyle='round4',fc='0.8')
arrow_args = dict(arrowstyle='<|-')



def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',\
                             xytext=centerPt,textcoords='axes fraction',\
                             va='center', ha='center',bbox=nodeType,arrowprops\
                             =arrow_args)

上面定义的字典decisionNode和leafNode都是作为参数传递给函数nodeType,进而传递给内置函数annotate()的bbox参数

决定了注释文本框的特征。annotate函数本来是用来给图做注释的,这里用来画决策树,所以可以看出arrow_args字典中把

箭头类型定义成反过来的箭头,即"<|-",当然定义成"<-"也一样,为什么呢,因为普通注释是: ’注释文本‘——>’图上某点‘,而决策树是:’某个出发点(父节点的坐标)‘——>’子节点(包含文本框)‘

关于annotate函数的详细解释,点击打开链接


第二部分:

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict:
        if(type(secondDict[key]).__name__ == 'dict'):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return  numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict:
        if(type(secondDict[key]).__name__ == 'dict'):
            thisDepth = 1+getTreeDepth((secondDict[key]))
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth
getNumLeafs()函数获得决策树的叶节点数,fistStr获得字典的键,代表树根,也代表样本集中的某个特征,secondDict则是字典的值,可以是另一个字典,代表分支,也可以是一个单值,代表叶子。例如

{'no surfacing':{0:'no', 1:{'flipper':{0:'no', 1:'yes'}}}}
fistStr就是'no surfacing',secondDict是{0 : 'no', 1:{'flipper' : {0 : 'no', 1:'yes'}}}

for循环就是递归了。对于叶子节点和分支分别执行不同的操作,求叶子节点的话,就是其子树的叶子节点数与直接叶子节点数之和,求深度的话,如果某个节点是叶子节点,那么深度是一,如果是子树根节点,那么是子树的深度加一,然后取最大值即可,这也意味着,最低层的元素深度是0,例如上面那个图,深度是3


第三部分:

def retrieveTree(i):
    #预先设置树的信息
    listOfTree = [{'no surfacing':{0:'no', 1:{'flipper':{0:'no', 1:'yes'}}}},
                  {'no surfacing':{0:'no', 1:{'flipper':{0:{'head':{0:'no', 1:'yes'}},1:'no'}}}},
                  {'a1':{0:'b1', 1:{'b2':{0:{'c1':{0:'d1',1:'d2'}}, 1:'c2'}}, 2:'b3'}}]
    return listOfTree[i]

这个就是提前构造一些决策树,方便用于测试


第四部分

def createPlot(inTree):
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    axprops = dict(xticks = [], yticks=[])
    createPlot.ax1 = plt.subplot(111,frameon = False,**axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW;plotTree.yOff = 1.0
    plotTree(inTree,(0.5,1.0), '')
    plt.show()
1. fig = plt.figure(1,facecolor = 'white')定义了一个框架,序号是1,背景色是白色

2.xticks是一个列表,其中的元素就是x轴上将显示的坐标,yticks是y轴上显示的坐标,这里空列表则不显示坐标。

3.

createPlot.ax1 = plt.subplot(111,frameon = False,**axprops)

这里定义一个子图窗口,第一个参数xyz含义是,将框架划分为x行y列窗口,ax1代表其第z个窗口。frameon = False将隐藏坐标轴

4. plotTree.totalW是决策树的叶子树,也代表宽度,plotTree.totalD是决策树的深度,xOff和yOff将和下面的函数一起解释,

然后调用核心函数plotTree(),之后显示

第五部分:

def plotMidText(cntrPt, parentPt,txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)
这里还有一个小函数用于在两个坐标点cntrPt,parentPt的中点写一段文本,内容为txtString,其实我觉得完全可以

xMid = (parentPt[0] + cntrPt[0])/2.0  这样不会容易理解一点吗        /笑哭 /笑哭

第六部分:

主菜来了

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,\
              plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict:
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),\
                     cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

先解决那两个式子

plotTree.xOff = -0.5/plotTree.totalW;plotTree.yOff = 1.0

plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW
注意一个总的原则:每个节点,都要在以其为根的子树的叶节点的中间,且所有叶节点在x方向是均匀分布的。也就是说,我可以这么来确定每个节点的x坐标位置,即首先根据叶节点总个数,分配好每个叶节点的位置,然后根据叶节点确定其他非叶节点。例如若以节点a为根的子树的三个叶节点分别是0.3,0.4,0.5(显然这样的话,叶子的距离是0.1,总共有十个叶节点),那么a的x坐标为0.4

这里首先明白xOff代表的是刚刚画完的叶节点的x坐标,注意是叶节点,其他非叶节点也要用xOff来算,就是第二个式子,但是不改变xOff的值,只有真的画了一个叶节点,才改变xOff的值,在函数15行的地方可以看到。

yOff代表当前处理的决策树的层,最高层时yOff是1,每到达决策树下一层,yOff都要减少一个层距,一直到0,层距就是根据决策树深度等距离划分y轴.例如决策树深度为3,那么层距为1/3

画出的决策树,顶层到1,底层到0,但是左右两边并不是挨着边界,而是有一段距离,这里我将1/plotTree.totalW称为"叶距",即叶子节点数等分划分x轴后的每一段大小,称二分之一个"叶距"为"半叶距"。 那么决策树的最左边的叶节点距离图的左边界相隔一个"半叶距", 最右边的叶节点距离图的右边界相隔一个"半叶距"。将坐标轴显示出来的图如下:


这里"叶距"就是1/5 即0.2,"半叶距"就是0.1了。

(1)先来看第二个式子是怎么回事.我把第二个式子改成

plotTree.xOff+(1.0+float(numLeafs))*(1/2.0)*(1/plotTree.totalW)
用中文进一步表示为

已画叶节点的x坐标+(1.0+float(numLeafs))*半叶距

假设现在我的xOff指代叶节点b1的x坐标,即0.1,并且当前递归到了b2,也就是说该画b2了。那么怎么确定b2的x坐标(y坐标好确定,反正每下一层,yOff就减少1/3)。带入上面这个图,xOff为0.1,float(numLeafs)=3(因为b2有三个叶节点d1,d2,c2)

b2要在d1,d2,d3中间,那么b2与d2的x坐标相同了。d2与d1相差2个"半叶距",d1与b1相差2个"半叶距",所以总共相差4个"半叶距",而1+float(numLeafs)是4,带入上式正好可以算出b2的x坐标。

这是巧合吗,当然不是。假设有四个叶节点d1,d2,d3,d4,那么b2在d2,d3中间。所以b2与d2差一个"半叶距"。d2与b1相差四个“半叶距”,总共相差五个"半叶距",且1+float(numLeafs) = 1+4 = 5,符合。

不管b2有多少个叶节点,(1+float(numLeafs)都表示正在画的节点与已画叶节点之间的“半叶距”个数,这个递推式可以用归纳法证明,直接理解也行。所以上面的式子都可以算出b2的x坐标。

如果要画叶节点该怎么办呢,用15行的语句,如下。

plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
加一个叶距就ok了。注意要更新xOff的值了,因为画了一个新的叶节点。

那么怎么解释xOff的初始化呢,即createPlot()函数中的下列语句

plotTree.xOff = -0.5/plotTree.totalW
很好理解,我之前说了,最左边的叶节点与图的左边界,即x=0的直线,相差一个“半叶距”,那么为了使得
plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW

这个式子第一次也适用,则初始的xOff与第一个叶节点也要相差一个“叶距”,所以xOff相对y轴左移了一个“半叶距”

剩下的部分都很容易懂了,注意yOff在进入下一层时要减一个层距,跳出时要加回来。事实上这个决策树的画图顺序是深度优先搜索,所以会有返回上一层,如果不更新全局变量yOff,会出错。这段代码如下

 secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict:
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),\
                     cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

很对称,是吧,从secondDict诞生那刻起,yOff就得减去层距,跳出for循环后,就得加回来。这样深度优先搜索完后,一个数就画出来了


  • 22
    点赞
  • 40
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值