Python3:《机器学习实战》之决策树算法(2)画个儿时的树

Python3:《机器学习实战》之决策树算法(2)画个儿时的树



前言:

  上一篇博文已经介绍了如何从数据集中创建树,然而字典的表示形式非常不易理解,而且直接绘制图形也比较困难。决策树的主要优点就是直观易于理解,如果不能将其直观的显示出来,优势便无从谈起,所以本片博文就介绍一下如何利用Matplotlib库来创建树形图。

Matplotlib注解:

  Matplotlib提供了一个非常有用的注解工具annotations,它可以在数据图像上添加文本注解。由于数据上面直接存在文本描述非常丑陋,因此工具内嵌支持带尖头的画线工具,使得我们可以在其他前挡的地方指向数据位置,并在此处添加描述信息,解释数据内容,如下图:

  打开文本编辑器,创建名为treePlotter.py的新文件,输入下面的程序代码。

代码实现:

'''
Created on Aug 14, 2017

@author: WordZzzz
'''

import matplotlib.pyplot as plt

#定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """
    Function:   绘制带箭头的注解

    Args:       nodeTxt:文本注解
                centerPt:箭头终点坐标
                parentPt:箭头起始坐标
                nodeType:文本框类型

    Returns:    无
    """
    #在全局变量createPlot0.ax1中绘图
    createPlot0.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

def createPlot0():
    """
    Function:   使用文本注解绘制树节点

    Args:       无

    Returns:    无
    """
    #创建一个新图形
    fig = plt.figure(1, facecolor='white')
    #清空绘图区
    fig.clf()
    #给全局变量createPlot0.ax1赋值
    createPlot0.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
    #绘制第一个文本注解
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    #绘制第二个文本注解
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    #显示最终绘制结果
    plt.show()

 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50

输出结果:

>>> import treePlotter
>>> treePlotter.createPlot0()
 
 
  • 1
  • 2

  程序结果如图所示,我们也可以改变函数plotNode(),观察图中x、y的位置如何变化。

构造注解树:

  绘制一颗完整的树需要技巧,虽然我们有坐标,但是如何放置所有的树节点却是个问题。所以我们需要知道有多少个叶节点来确定x轴长度;haixuyao知道有多少层来确定y轴的高度。

代码实现:

def getNumLeafs(myTree):
    """
    Function:   获取叶节点的数目

    Args:       myTree:树信息

    Returns:    numLeafs:叶节点的数目
    """
    #初始化叶节点数目
    numLeafs = 0
    #第一个关键字为第一次划分数据集的类别标签,附带的取值表示子节点的取值
    firstStr = myTree.keys()[0]
    #新的树,相当于脱了一层皮
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        #判断子节点是否为字典类型
        if type(secondDict[key]).__name__=='dict':
            #是的话表明该节点也是一个判断节点,递归调用getNumLeafs()函数
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    #返回叶节点数目
    return numLeafs

def getTreeDepth(myTree):
    """
    Function:   获取树的层数

    Args:       myTree:树信息

    Returns:    maxDepth:最大层数
    """
    #初始化最大层数
    maxDepth = 0
    #第一个关键字为第一次划分数据集的类别标签,附带的取值表示子节点的取值
    firstStr = myTree.keys()[0]
    #新的树,相当于脱了一层皮
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        #判断子节点是否为字典类型
        if type(secondDict[key]).__name__=='dict':
            #是的话表明该节点也是一个判断节点,递归调用getTreeDepth()函数
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    #返回最大层数
    return maxDepth
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48

  retrieveTree()主要用于测试,返回预定义的树结构。

def retrieveTree(i):
    """
    Function:   创建树

    Args:       i:要输出的树在里列表中的位置

    Returns:    listOfTrees[i]:输出预先存储的树
    """
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

输出结果:

>>> reload(treePlotter)
<module 'treePlotter' from 'E:\\机器学习实战\\mycode\\Ch03\\treePlotter.py'>
>>> treePlotter.retrieveTree(1)
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
>>> treePlotter.retrieveTree(0)
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> myTree = treePlotter.retrieveTree(0)
>>> treePlotter.getNumLeafs(myTree)
3
>>> treePlotter.getTreeDepth(myTree)
2
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

报(python2.x与python3.x的差异):

TypeError: 'dict_keys' object does not support indexing
 
 
  • 1

解决(强制类型转换): 
Probably this was written with python2.x (when d.keys() returned a list). With python3.x, d.keys() returns a dict_keys object which behaves a lot more like a set than a list. As such, it can’t be indexed.

The solution is to pass list(d.keys()) (or simply list(d)) to shuffle.

  我们需要重新编写createPlot()函数,在createPlot0()的基础上进行完善。

代码实现:

def plotMidText(cntrPt, parentPt, txtString):
    """
    Function:   在父子节点间填充文本信息

    Args:       cntrPt:树信息
                parentPt:父节点坐标
                txtString:文本注解

    Returns:    无
    """
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    """
    Function:   创建数据集和标签

    Args:       myTree:树信息
                parentPt:箭头起始坐标
                nodeTxt:文本注解

    Returns:    无
    """
    #计算树的宽
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    #计算树的高
    depth = getTreeDepth(myTree)
    #第一个关键字为第一次划分数据集的类别标签,附带的取值表示子节点的取值
    firstStr = list(myTree.keys())[0]     #the text label for this node should be this
    #下一个节点的位置
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    #计算父节点和子节点的中间位置,并在此处添加简单的文本信息
    plotMidText(cntrPt, parentPt, nodeTxt)
    #绘制此节点带箭头的注解
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    #新的树,相当于脱了一层皮
    secondDict = myTree[firstStr]
    #按比例减少全局变量plotTree.yOff
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    #
    for key in secondDict.keys():
        #判断子节点是否为字典类型
        if type(secondDict[key]).__name__=='dict':
            #是的话表明该节点也是一个判断节点,递归调用plotTree()函数 
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            #不是的话更新x坐标值
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            #绘制此节点带箭头的注解
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            #绘制此节点带箭头的注解
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    #按比例增加全局变量plotTree.yOff
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    """
    Function:   使用文本注解绘制树节点

    Args:       inTree:

    Returns:    无
    """
    #创建一个新图形
    fig = plt.figure(1, facecolor='white')
    #清空绘图区
    fig.clf()
    #创建一个字典
    axprops = dict(xticks=[], yticks=[])
    #给全局变量createPlot.ax1赋值
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
    #取得叶节点数目
    plotTree.totalW = float(getNumLeafs(inTree))
    #取得树最大层
    plotTree.totalD = float(getTreeDepth(inTree))
    #设置起点值
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    #绘制数
    plotTree(inTree, (0.5,1.0), '')
    #显示最终绘制结果
    plt.show()
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84

输出结果:

>>> reload(treePlotter)
<module 'treePlotter' from 'E:\\机器学习实战\\mycode\\Ch03\\treePlotter.py'>
>>> myTree = treePlotter.retrieveTree(0)
>>> treePlotter.createPlot(myTree)
 
 
  • 1
  • 2
  • 3
  • 4

输出效果如下图所示:

  接着按照命令更改字典,重新绘制树形图。

输出结果:

>>> myTree['no surfacing'][3] = 'maybe'
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}
>>> treePlotter.createPlot(myTree)
 
 
  • 1
  • 2
  • 3
  • 4

系列教程持续发布中,欢迎订阅、关注、收藏、评论、点赞哦~~( ̄▽ ̄~)~

完的汪(∪。∪)。。。zzz

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值