机器学习算法二——决策树(2)(使用matplotlib绘制树形图)

在Python中使用Matplotlib注解绘制树形图

本节将学习如何编写代码绘制如下图所示的决策树。
在这里插入图片描述
1、Matplotlib注解
Matplotlib提供了一个注解工具annotations,非常有用,可以在数据图形上添加文本注释。注解通常用于解释数据的内容。
在这里插入图片描述

#使用文本注解绘制树节点
import matplotlib.pyplot as plt

#定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8") #fc表示框填充色,0-1表示黑到白
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

#绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):#(节点的文字标注,节点中心位置,节点起点位置,节点属性)
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def createPlot():
        fig = plt.figure(1, facecolor='white') #创建一个新图形,背景颜色为白色
        fig.clf()  #并清空绘图区
        createPlot.ax1 = plt.subplot(111, frameon=False) #在figure1里面创建一个11列的子figure,并返回第一个实例
        plotNode('a decision node', (0.5,0.1),(0.1,0.5), decisionNode)
        plotNode('a leaf node', (0.8,0.1),(0.3,0.8),leafNode)#在绘图区上绘制两个代表不同类型的树节点
        plt.show()

def main():
    createPlot()
    
if __name__ == "__main__":
    main()

运行结果:
在这里插入图片描述

(1)createPlot.ax1.annotate()函数
原型: class matplotlib.axes.Axes()的成员函数annotate()。
作用: 为绘制的图上指定的数据点xy添加一个注释nodeTxt,xytext指定注释的位置,xycoords指定点xy坐标的类型,textcoords指定xytext的类型,xycoords和textcoords的取值如下:
‘figure points’:表示坐标原点在图的左下角的数据点
‘figure pixels’:表示坐标原点在图的左下角的像素点
‘figure fraction’:此时取值是小数,范围是([0, 1], [0, 1]),在图的最左下角时xy是(0,0), 最右上角是(1, 1),其他位置按相对图的宽高的比例取小数值
‘axes points’:表示坐标原点在图中坐标的左下角的数据点
‘axes pixels’:表示坐标原点在图中坐标的左下角的像素点
‘axes fraction’:类似‘figure fraction’,只不过相对图的位置改成是相对坐标轴的位置
‘data’:此时使用被注释的对象所采用的坐标系(这是默认设置),被注释的对象就是调用annotate这个函数那个实例,这里是ax1,是Axes类,采用ax1所采用的坐标系
‘offset points’:表示相对xy的偏移(以点的个数计),不过一般这个是用textcoords
‘polar’:极坐标类型,在直角坐标系下面也可以用,此时坐标含义为(theta, r)

参数arrowprops:连接数据点和注释的箭头类型,该参数是dictionary类型,含有一个名为arrowstyle的键,一旦指定该键就会创建一个class matplotlib.patches.FancyArrowPatch类的实例,该键取值可以是一个可用的arrowstyle名字的字符串,也可以是可用的class matplotlib.patches.ArrowStyle类的实例。
具体arrowstyle名字的字符串可以参考 http://matplotlib.org/api/patches_api.html#matplotlib.patches.FancyArrowPatch里面的class matplotlib.patches.FancyArrowPatch类的arrowstyle参数设置。

2、构造注解树
上面学习了如何绘制树节点,然而绘制一棵完整的树需要一些技巧。比如,如何放置所有的树节点就是个问题,下面将学习如何绘制整棵树。
我们需要知道有多少个叶节点,以便确定x轴的长度;
我们需要知道树有多少层,以便确定y轴的高度;
下面用两个函数 getNumLeafs()getTreeDepth() 分别获取叶节点的数目和树的层数。

#获取叶节点的数目
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]#第一个关键字是第一次划分数据集的类别标签
    secondDict = myTree[firstStr] #第一个关键字后的字典
    for key in secondDict.keys(): #从第一个关键字出发,遍历整棵树的所有子节点
        if type(secondDict[key]).__name__ == 'dict': #测试节点的数据类型是否为字典
            numLeafs += getNumLeafs(secondDict[key]) #是字典类型,则递归调用getNumLeafs
        else:
            numLeafs += 1 #非字典类型则为叶子节点
    return numLeafs

#获取树的层数
def getTreeDepth(myTree): #计算遍历过程中遇到判断节点的个数
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict': #若为非叶子节点,则递归计算树的深度
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

注: {‘no surfacing’:{0:‘no’, 1:{‘flipper’:{0:‘no’, 1:‘yes’}}}}
fistStr就是’no surfacing’,secondDict是{0 : ‘no’, 1:{‘flipper’ : {0 : ‘no’, 1:‘yes’}}}

以下函数retrieveTree输出预先存储的树信息,避免了每次测试代码时都要从数据中创建树的麻烦:

#函数retrieveTree输出预先存储的树信息
def retrieveTree(i):
    listOfTrees = [{'no surfacing':{0:'no', 1:{'flippers':{0:'no',1:'yes'}}}},
                   {'no surfacing':{0:'no', 1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
    return listOfTrees[i]

执行结果:

print(retrieveTree(1))
myTree = retrieveTree(0)
print(getNumLeafs(myTree))
print(getTreeDepth(myTree))
>>>{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
4
3

(2) python2.x与python3.x之d.keys()返回类别的区别
python2.x dict.keys()返回list类型,可直接使用索引获取元素:myTree.keys()[0]
python3.x dict.keys()返回dict_keys类型,其性质类似集合(set) 而不是列表(list),因此不能使用索引获取其元素:list(myTree.keys())[0]
(3) type()函数
描述:如果你只有第一个参数则返回对象的类型,三个参数返回新的类型对象。
语法:type(object)
type(name, bases, dict)
参数:name ------ 类的名称
bases ------ 基类的元组
dict ------ 字典,类内定义的命名空间变量
返回值:一个参数返回对象类型,三个参数返回新的类型对象
实例
在这里插入图片描述
现在用前面学到的方法组合在一起,绘制一棵完整的树。在本文件中添加如下代码:

#plotTree函数
def plotMidText(cntrPt, parentPt, txtString): #在父子节点间填充文本信息
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString) #在cntrPt和parentPt中点写一段文本,内容为txtString

def plotTree(myTree, parentPt, nodeTxt): #计算宽与高
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0] #第一个关键字是第一次划分数据集的类别标签
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) #其他非叶节点也用xOff算,但是不改变xOff值,只有真的画了一个叶节点,才改变xOff的值
    plotMidText(cntrPt, parentPt, nodeTxt) #标记子节点属性值
    plotNode(firstStr, cntrPt, parentPt, decisionNode)#(节点文本内容,子节点坐标,父节点坐标,节点属性)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #减少y偏移,1.0/plotTree.totalD是层距
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white') #定义了一个框架,序号是1,背景色是白色
    fig.clf()
    axprops = dict(xticks=[0.0 ,0.2, 0.4, 0.6, 0.8, 1.0], yticks=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) #列表 xticks是x轴上将显示的坐标,yticks是y轴上将显示的坐标,空列表则不显示坐标
    createPlot.ax1 = plt.subplot(111, frameon=True, **axprops)#定义一个子图窗口(11列第1个窗口) 隐藏坐标轴
    plotTree.totalW = float(getNumLeafs(inTree)) #是决策树的叶子数,也代表宽度
    plotTree.totalD = float(getTreeDepth(inTree)) #是决策树的深度
    plotTree.xOff = -0.5/plotTree.totalW; #xOff代表的是刚刚画完的叶节点的x坐标,注意是叶节点
    plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()
    
def main():
    myTree = retrieveTree(0)
    createPlot(myTree)
    
if __name__ == "__main__":
    main()

这段代码中,我们首先分析createPlot(inTree)函数,全局变量plotTree.totalW存储树的宽度,plotTree.totalD存储树的深度。
树的宽度用于计算放置判断节点的位置,主要的计算原则是将它放在所有叶子节点的中间。
使用两个全局变量plotTree.xOff和plotTree.yOff追踪已经绘制的节点位置,以及放置下一个节点的恰当位置。
注意:绘制图形的x轴和y轴有效范围都是0.0到1.0。

plotTree.xOff = -0.5/plotTree.totalW;
plotTree.yOff = 1.0;

这是对plotTree.xOff和plotTree.yOff的初始化。1.0是总宽度,plotTree.totalW是叶子节点总数,1.0/plotTree.totalW则为两个相邻叶子节点间的距离,所以-0.5/plotTree.totalW为横坐标位于范围左侧半叶距处,纵坐标为页面最上端。

接下来需要重点分析:

plotTree(inTree, (0.5,1.0), '')

即分析plotTree(myTree, parentPt, nodeTxt)函数,这里给出母节点的坐标为(0.5,1.0),也就是整棵树的起点坐标。
注意,plotTree是个递归函数。

cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)

这是计算子节点的坐标,这里plotTree.xOff的值仍是刚刚初始化的值还未改变(即页面左侧半叶距),我们可以将(1.0 + float(numLeafs))/2.0/plotTree.totalW写成(1.0/2.0)(1/plotTree.totalW)+(1.0/2.0) float(numLeafs) * (1/plotTree.totalW),而(1.0/2.0)(1/plotTree.totalW)就是半叶距,(1.0/2.0) float(numLeafs) * (1/plotTree.totalW)的值为0.5,即页面中间的位置,这样横坐标就为0.5,纵坐标为1.0。

plotMidText(cntrPt, parentPt, nodeTxt) #在父子节点间填充文本信息
plotNode(firstStr, cntrPt, parentPt, decisionNode)#(节点文本内容,子节点坐标,父节点坐标,节点属性)

此时,父子节点是同一个点,即(0.5,1.0),而节点间的填充信息nodeTxt正好为‘ ’。接着用plotNode画出节点。

 plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #减少y偏移

1.0/plotTree.totalD是层距,每向下一层,纵坐标减少一个层距。

for key in secondDict.keys():
       if type(secondDict[key]).__name__ == 'dict':
           plotTree(secondDict[key],cntrPt,str(key))
       else:
           plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
           plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
           plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))

遍历下一层字典中的内容:
若为字典,即非叶节点,则调用plotTree递归画出节点,可以分析,节点坐标为

cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)

注意此时的树myTree已经变成了secondDict[key],自然numLeafs也发生了变化,是以key为顶点的树下的叶子数。cntrPt的横坐标在刚画完的节点的横坐标(即其父节点)的基础上,加上半叶距,加上叶子数占总叶子数距离的一半,节点文本即位键值key。

若为叶子节点,则在初始值的基础上加上一个叶间距的距离更新plotTree.xOff值,以后每发现一个叶子节点,就更新一次plotTree.xOff的值,并画出节点以及父子节点间的文本。

说明半叶距存在的原因:由于我们要考虑整棵树的布局,而在画图的过程中,保证整棵树处于页面中央,而又留有一定的页边距。若图宽度为1,共5个叶子节点,则以0.0坐标为起点,则5个叶子节点的横坐标为[0.0 0.2 0.4 0.6 0.8] , 左边顶格而右边空0.2,因此,若将整张图右移半叶距即0.1,问题就可解决。

 plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

经过分析可以知道,此决策树的画图顺序是深度优先搜索,所以会有返回上一层。yOff在进入下一层时要减一个层距,跳出时要加回来。

tree0的树形图:
在这里插入图片描述

def retrieveTree(i):
    listOfTrees = [{'no surfacing':{0:'no', 1:{'flippers':{0:'no',1:'yes'}},3:'maybe'}},
                   {'no surfacing':{0:'no', 1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
    return listOfTrees[i]

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值