决策树03——使用matplotlib绘制树形图并测试算法

决策树02——决策树的构建中,我们将已经进行分类的数据存储在字典中,然而字典的表示形式非常不直观,也不容易理解,所以我们将字典中的信息绘制成树形图。

Matplotlib注解功能

  Matplotlib提供一个注解工具annotations,它可以在数据图形上添加文本注释。

  以下将使用Matplotlib的注解功能绘制树形图,它可以对文字着色,并提供多种形状以供选择,而且我们还可以反转箭头,将它指向文本框而不是数据点。

  新建名为treeplotter.py的新文件,将输入下面的程序代码:

# -*-coding=utf-8 -*-

#使用文本朱姐绘制树节点
import matplotlib.pyplot as plt

#定义文本框和箭头格式
#定义决策树决策结果的属性(决策节点or叶节点),用字典来定义
#下面的字典定义也可以写作 decisionNode = {boxstyle:’sawtooth‘,fc=’0.8‘}
decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")       #决策节点,boxstyle为文本框类型,sawtooth是锯齿形,fc是边框内填充的颜色
leafNode = dict(boxstyle = "round",fc="0.8")                #叶节点,定义决策树的叶子结点的描述属性
arrow_args = dict(arrowstyle = "<-")                         #箭头格式

#绘制带箭头的注释
def plotNode(nodeTxt,centerPt,parentPt,nodeType):           #nodeTxt是显示的文本,centerPt是文本的中心点,parentPt是箭头的起点坐标,nodeType是一个字典 注解的形状
    createPlot.ax1.annotate(nodeTxt,xy = parentPt, xycoords = 'axes fraction',  #xy为箭头的起始坐标,0,0 is lower left of axes and 1,1 is upper right
                            xytext = centerPt,textcoords = 'axes fraction', #xytext为注解内容的坐标
                            va = "center",ha = "center",bbox = nodeType,arrowprops = arrow_args) #bbox注解文本框的形状,arrowprops是指箭头的形状

def createPlot():
    fig = plt.figure(1,facecolor='white')  #类似于matlab的figure,定义一个画布,其背景为白色
    fig.clf()                 #把画布清空
    createPlot.ax1 = plt.subplot(111,frameon=False) # createPlot.ax1为全局变量,绘制图像的句柄,subplot为定义了一个绘图,111表示figure中的图有1行1列,即1个,最后的1代表第一个图,
    plotNode(U'决策节点',(0.5,0.1),(0.1,0.5), decisionNode)
    plotNode(U'叶节点',(0.8,0.1),(0.3,0.8), leafNode)
    plt.show()

注意:以上程序运行时会出现中文变成小方框的现象,将以下几行代码添加到文件的开始处。

from pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei']  #指定默认字体
mpl.rcParams['axes.unicode_minus'] = False

在命令行输入:

In[70]: import treePlotter
Backend TkAgg is interactive backend. Turning interactive mode on.
In[71]: treePlotter.createPlot()

这里写图片描述

构造注解树

  我们虽然有x, y坐标,但是如何放置所有的树节点却是个问题。我们必须知道有多少个叶节点,以便可以正确确定x轴的长度,我们还需要知道树有多少层,以便可以正确的确定y轴的高度。
  这里我们定义两个新函数getNumLeafs()和getTreeDepth(),来获取叶节点的输煤和树的层数。将下面的两个函数添加到treePlotter.py文件中。

#获取叶节点的数目和树的层次
def getNumLeafs(myTree):
    numLeaf = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ =='dict':         #测试节点的数据类型是否为字典 ,type(secondDict[key]) ==dict 也是可以的
            numLeaf += getNumLeafs(secondDict[key])
        else: numLeaf += 1
    return numLeaf

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else: thisDepth = 1
        if thisDepth > maxDepth :maxDepth = thisDepth
    return maxDepth

函数retrieveTree()输出预先存储的树信息,将 下面代码添加到文件treePlotter.py中:

def retrieveTree(i):
    listOfTrees = [{'no surfacing':{0:'0',1:{'flippers':{0:'no',1:'yes'}}}},
                   {'no surfacing': {0: '0', 1: {'flippers': {0: {'head':{0:'no',1:'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]

在命令行中输入:

In[2]: import treePlotter
Backend TkAgg is interactive backend. Turning interactive mode on.
In[3]: treePlotter.retrieveTree(0)
Out[3]: 
{'no surfacing': {0: '0', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
In[4]: myTree = treePlotter.retrieveTree(0)
In[5]: treePlotter.getNumLeafs(myTree)
Out[5]: 
3
In[6]: treePlotter.getTreeDepth(myTree)
Out[6]: 
2

将下面代码添加到treePlotter.py中,注意前面已经定义了createPlot(),此时我们需要更新前面的代码。

#plotTree函数
#在父子节点间填充文本信息
def plotMidText(cntrPt,parentPt,txtString):
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)

#自顶向下作图,绘制图形的x轴有效范围是0.0~1.0, y轴有效范围也是0.0~1.0
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs = getNumLeafs(myTree)    #secondDict[key]的叶节点的数量
    depth = getTreeDepth(myTree)      #secondDict[key]的树深度
    print 'numLeafs,depth:',numLeafs,',',depth
    firstStr = myTree.keys()[0]
    # 全局变量plotTree.totalW 存储树的宽度,全局变量PlotTree.totalD 存储树的深度,使用这两个变量计算树节点的摆放位置,这样可以将树绘制在水平方向和垂直方向的中心位置。
    cntrPt = (plotTree.xOff +(1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff) #注释1
    #标记子节点属性
    plotMidText(cntrPt,parentPt,nodeTxt)        #这一次循环中的cntrPt(即上式)为cbtrPt,parentPt为上一轮计算出的cntrPt
    plotNode(firstStr,cntrPt,parentPt,decisionNode)  #因还没画到叶节点,所以这里画的是决策节点,即此时筛选secondDict[key]还是字典
    secondDict = myTree[firstStr]
    #计算下一轮要用的y
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD    #下面的循环中要使用的y
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #注释2

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[],yticks=[])    #创建一个型为{'xticks': [], 'yticks': []}的字典
    createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW = float(getNumLeafs(inTree))      #树的宽度
    plotTree.totalD = float(getTreeDepth(inTree))     #输的深度
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; #定义节点位置的初始值
    plotTree(inTree,(0.5,1.0),'')     #(0.5,1.0)为初始化parentPt的值,注释3
    plt.show()

在命令行输入:

In[35]: reload(treePlotter)
Out[35]: 
<module 'treePlotter' from '/home/vickyleexy/PycharmProjects/Classification of contact lenses/treePlotter.py'>
In[36]: myTree = treePlotter.retrieveTree(0)
In[37]: myTree
Out[37]: 
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
In[38]: treePlotter.createPlot(myTree)
numLeafs,depth: 3 , 2
numLeafs,depth: 2 , 1

这里写图片描述

注释:
1.cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
  在这行代码中,首先由于整个画布根据叶子节点数和深度进行平均切分,并且x轴的总长度为1,即如同下图:
这里写图片描述
  其中方形为非叶子节点的位置,@是叶子节点的位置,因此每份即上图的一个表格的长度应该为1/plotTree.totalW,但是叶子节点的位置应该为@所在位置,则在开始的时候plotTree.xOff的赋值为-0.5/plotTree.totalW,即意为开始x位置为第一个表格左边的半个表格距离位置,这样作的好处为:在以后确定@位置时候可以直接加整数倍的1/plotTree.totalW,

  plotTree.xOff即为最近绘制的一个叶子节点的x坐标,在确定当前节点位置时每次只需确定当前节点有几个叶子节点,因此其叶子节点所占的总距离就确定了即为float(numLeafs)/plotTree.totalW*1(因为总长度为1),因此当前节点的位置即为其所有叶子节点所占距离的中间即一半为float(numLeafs)/2.0/plotTree.totalW*1,但是由于开始plotTree.xOff赋值并非从0开始,而是左移了半个表格,因此还需加上半个表格距离即为1/2/plotTree.totalW*1,则加起来便为(1.0 + float(numLeafs))/2.0/plotTree.totalW*1,因此偏移量确定,则x位置变为plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW.

2. plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
这行代码中是需要的,当分支最后一个不是字典的时候,字典循环完需要返回上一层继续进行函数
例如:

In[40]: myTree['no surfacing'][3] = 'maybe'
In[41]: myTree
Out[41]: 
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}
In[42]: treePlotter.createPlot(myTree)
numLeafs,depth: 4 , 2
numLeafs,depth: 2 , 1

这里写图片描述
 
3.plotTree(inTree,(0.5,1.0),'')
在这行代码中,对于plotTree函数参数赋值为(0.5, 1.0),因为开始的根节点并不用划线,因此父节点和当前节点的位置需要重合,利用2中的确定当前节点的位置便为(0.5, 1.0)

总结:利用这样的逐渐增加x的坐标,以及逐渐降低y的坐标能能够很好的将树的叶子节点数和深度考虑进去,因此图的逻辑比例就很好的确定了,这样不用去关心输出图形的大小,一旦图形发生变化,函数会重新绘制,但是假如利用像素为单位来绘制图形,这样缩放图形就比较有难度了

测试和存储分类器

程序比较测试数据与决策树上的数值,递归执行该过程直到进入叶子节点,最后将测试数据定义为叶子节点所属的类型。

#使用决策树的分类算法
def classify(inputTree,featLabels,testVec):    #testVec即为需要分类的数据
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)        #将标签字符串转换为索引
    print featIndex
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key],featLabels,testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

在命令行输入:

In[19]: reload(trees)
Out[19]: 
<module 'trees' from '/home/vickyleexy/PycharmProjects/Classification of contact lenses/trees.py'>
In[20]: myDat,labels = trees.createDataSet()
In[21]: myDat
Out[21]: 
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
In[22]: labels
Out[22]: 
['no surfacing', 'flippers']
In[23]: myTree = trees.createTree(myDat,labels)
最好的特征,最好的信息增益: 0 , 0.419973094022
最好的特征,最好的信息增益: 0 , 0.918295834054
In[24]: myDat
Out[24]: 
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
In[25]: labels
Out[25]: 
['flippers']
In[26]: myDat,labels = trees.createDataSet()
In[27]: trees.classify(myTree,labels,[1,1])
0
1
Out[27]: 
'yes'
In[28]: trees.classify(myTree,labels,[1,0])
0
1
Out[28]: 
'no'

决策树的存储

为了节省时间,最好能够在每次执行分类时调用已经构造好的决策树,使用Python的pickle模块可以在磁盘上保存对象,并在需要的时候读取出来。

#使用pickle模块存储决策树
def storeTree(inputTree,filename):
    import pickle
    fw = open(filename,'w')
    pickle.dump(inputTree,fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename)
    return pickle.load(fr)

在命令行中输入:

In[29]: trees.storeTree(myTree,'classifierStorage.txt')
In[30]: trees.grabTree('classifierStorage.txt')
Out[30]: 
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值