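Python 3: Machine Learning in Action, Decision Trees (2): Drawing the Trees of Our Childhood
Preface:
The previous post showed how to build a tree from a dataset. However, the dictionary representation is hard to read and awkward to draw directly. The main advantage of a decision tree is that it is intuitive and easy to understand; if we cannot display it visually, that advantage is lost. This post therefore shows how to draw a tree diagram with the Matplotlib library.
Matplotlib annotations:
Matplotlib provides a very useful annotation tool, annotations, which adds text notes to a plot. Text placed directly on top of the data looks ugly, so the tool has built-in support for drawing arrows: we can put the description at a convenient spot elsewhere and point the arrow at the data location, as shown in the figure below:
Open a text editor, create a new file named treePlotter.py, and enter the code below.
Code: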
'''
Created on Aug 14, 2017
@author: WordZzzz
'''
import matplotlib.pyplot as plt

# Box and arrow styles shared by all nodes
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """
    Function: draw an annotation with an arrow
    Args:     nodeTxt: annotation text
              centerPt: center of the text box (where the arrow ends)
              parentPt: where the arrow starts
              nodeType: text box style
    Returns:  None
    """
    createPlot0.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                             xytext=centerPt, textcoords='axes fraction',
                             va="center", ha="center", bbox=nodeType,
                             arrowprops=arrow_args)

def createPlot0():
    """
    Function: draw tree nodes using text annotations
    Args:     none
    Returns:  None
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    # Store the axes on the function object so plotNode() can reach it
    createPlot0.ax1 = plt.subplot(111, frameon=False)
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
Output:
>>> import treePlotter
>>> treePlotter.createPlot0()
The result is shown in the figure. We can also modify the plotNode() calls and watch how the x and y positions in the figure change.
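To see how the two coordinate arguments interact, here is a minimal standalone sketch of a single annotate() call. It is not part of treePlotter.py; the coordinates are arbitrary illustrative values, and the Agg backend and in-memory buffer are assumptions made only so the sketch runs without a display.

```python
import io

import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_axis_off()

# xy is the point the arrow is attached to; xytext is where the text box sits.
# Both are given as fractions of the axes, so (0, 0) is bottom-left, (1, 1) top-right.
note = ax.annotate(
    "a leaf node",
    xy=(0.2, 0.8), xycoords="axes fraction",
    xytext=(0.7, 0.3), textcoords="axes fraction",
    va="center", ha="center",
    bbox=dict(boxstyle="round4", fc="0.8"),
    arrowprops=dict(arrowstyle="<-"),  # "<-" puts the arrowhead at the text end
)

# Render into memory; in an interactive session, drop the Agg line and call plt.show().
buf = io.BytesIO()
fig.savefig(buf, format="png")
```

Changing xy moves the tail of the arrow, while changing xytext moves the box and the arrowhead with it.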
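Building the annotated tree:
Drawing a complete tree takes some care. We have coordinates, but deciding where to place every node is the real problem. We need the number of leaf nodes to size the x axis, and we also need the number of levels to size the y axis.
Code: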
def getNumLeafs(myTree):
    """
    Function: count the leaf nodes
    Args:     myTree: the tree (nested dict)
    Returns:  numLeafs: number of leaf nodes
    """
    numLeafs = 0
    # list() is required in Python 3: dict.keys() returns a view that cannot be indexed
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # A dict child is a subtree; anything else is a leaf
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    """
    Function: count the levels of the tree
    Args:     myTree: the tree (nested dict)
    Returns:  maxDepth: maximum depth
    """
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
retrieveTree() is mainly for testing; it returns a predefined tree structure.
def retrieveTree(i):
    """
    Function: return a prebuilt tree
    Args:     i: index of the tree in the list
    Returns:  listOfTrees[i]: the stored tree at that index
    """
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]
Output:
>>> from importlib import reload  # reload() is not a builtin in Python 3
>>> reload(treePlotter)
<module 'treePlotter' from 'E:\\机器学习实战\\mycode\\Ch03\\treePlotter.py'>
>>> treePlotter.retrieveTree(1)
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
>>> treePlotter.retrieveTree(0)
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> myTree = treePlotter.retrieveTree(0)
>>> treePlotter.getNumLeafs(myTree)
3
>>> treePlotter.getTreeDepth(myTree)
2
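As a cross-check on these numbers, the sketch below computes both values in a single pass with an explicit stack instead of recursion. It is an alternative written for illustration, not the book's method; the name leaves_and_depth is hypothetical, and it assumes the same nested-dict layout ({feature: {value: subtree_or_label}}) that retrieveTree() returns.

```python
def leaves_and_depth(tree):
    """Return (leaf count, depth) for a nested-dict decision tree."""
    leaves, max_depth = 0, 0
    stack = [(tree, 1)]  # (subtree, level of its decision node)
    while stack:
        node, level = stack.pop()
        root = next(iter(node))  # the single feature name at this node
        for child in node[root].values():
            if isinstance(child, dict):
                stack.append((child, level + 1))
            else:
                # A leaf hangs one step below a decision node at `level`
                leaves += 1
                max_depth = max(max_depth, level)
    return leaves, max_depth


# The two trees stored in retrieveTree()
tree0 = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
tree1 = {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}

print(leaves_and_depth(tree0))  # (3, 2)
print(leaves_and_depth(tree1))  # (4, 3)
```

The first result matches the getNumLeafs() and getTreeDepth() output shown above for retrieveTree(0).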
Error (a difference between Python 2.x and Python 3.x):
The book's version of these functions takes the first key with myTree.keys()[0], which fails under Python 3:
TypeError: 'dict_keys' object does not support indexing
Fix (convert to a list):
The book's code was written for Python 2.x, where d.keys() returned a list. In Python 3.x, d.keys() returns a dict_keys view, which behaves more like a set than a list and cannot be indexed.
The solution is to index list(d.keys()) (or simply list(d)) instead, for example list(myTree.keys())[0].
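The difference is easy to reproduce in isolation. The following sketch uses a small illustrative dictionary (not one of the stored trees) and shows the failure next to three equivalent ways of getting the first key; the variable names are made up for this example.

```python
tree = {'no surfacing': {0: 'no', 1: 'yes'}}

# Indexing the view directly raises TypeError in Python 3
try:
    tree.keys()[0]
except TypeError as err:
    message = str(err)  # exact wording varies between Python 3 releases

# Any of these returns the first key
first_a = list(tree.keys())[0]
first_b = list(tree)[0]
first_c = next(iter(tree))  # avoids building a list at all

print(first_a, first_b, first_c)
```

Since each decision node in these trees is a dictionary with exactly one key, "the first key" is simply the feature name at that node.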
Now we need to write a new createPlot() function that builds on createPlot0().
Code:
def plotMidText(cntrPt, parentPt, txtString):
    """
    Function: write text halfway between a parent and a child node
    Args:     cntrPt: child node coordinates
              parentPt: parent node coordinates
              txtString: annotation text
    Returns:  None
    """
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    """
    Function: recursively draw a tree and its subtrees
    Args:     myTree: the tree (nested dict)
              parentPt: where the arrow starts (parent node coordinates)
              nodeTxt: text for the edge leading to this node
    Returns:  None
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree.keys())[0]
    # Center this node horizontally above the leaves of its subtree
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Move down one level before drawing the children
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # Each leaf takes the next slot along the x axis
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Move back up once all children are drawn
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree):
    """
    Function: draw a whole tree using text annotations
    Args:     inTree: the tree to draw (nested dict)
    Returns:  None
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # Total width and depth set the spacing between nodes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
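The offset bookkeeping in plotTree() is easier to follow with the drawing stripped away. The sketch below reproduces the same arithmetic but only records where each node would go. It is an illustration, not a replacement for the code above: layout, count_leaves and tree_depth are hypothetical names, and the function attributes are swapped for a small state dictionary.

```python
def count_leaves(tree):
    """Number of leaves in a {feature: {value: subtree_or_label}} tree."""
    root = next(iter(tree))
    return sum(count_leaves(c) if isinstance(c, dict) else 1
               for c in tree[root].values())


def tree_depth(tree):
    """Number of decision levels in the tree."""
    root = next(iter(tree))
    return max(1 + tree_depth(c) if isinstance(c, dict) else 1
               for c in tree[root].values())


def layout(tree):
    """Return [(label, (x, y)), ...] in the order plotTree() would draw nodes."""
    total_w = float(count_leaves(tree))
    total_d = float(tree_depth(tree))
    state = {"x": -0.5 / total_w, "y": 1.0}  # same starting offsets as createPlot()
    positions = []

    def walk(subtree):
        root = next(iter(subtree))
        n = count_leaves(subtree)
        # Center the decision node above the leaves of its own subtree
        center_x = state["x"] + (1.0 + n) / 2.0 / total_w
        positions.append((root, (center_x, state["y"])))
        state["y"] -= 1.0 / total_d          # go down one level
        for child in subtree[root].values():
            if isinstance(child, dict):
                walk(child)
            else:
                state["x"] += 1.0 / total_w  # next leaf slot
                positions.append((child, (state["x"], state["y"])))
        state["y"] += 1.0 / total_d          # come back up

    walk(tree)
    return positions


tree0 = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
for label, (x, y) in layout(tree0):
    print(f"{label:>12}  x={x:.3f}  y={y:.3f}")
```

For this tree the three leaves land at x = 1/6, 1/2 and 5/6, the root sits at (0.5, 1.0), and the flippers node sits at (2/3, 0.5), centered above its two leaves.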
Output:
>>> reload(treePlotter)
<module 'treePlotter' from 'E:\\机器学习实战\\mycode\\Ch03\\treePlotter.py'>
>>> myTree = treePlotter.retrieveTree(0)
>>> treePlotter.createPlot(myTree)
The result is shown in the figure below:
Next, modify the dictionary with the commands below and redraw the tree.
Output:
>>> myTree['no surfacing'][3] = 'maybe'
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}
>>> treePlotter.createPlot(myTree)
More posts in this series are on the way; feel free to subscribe, follow, bookmark, comment, and like ~~( ̄▽ ̄~)~
That's all, woof (∪。∪)。。。zzz