介绍
上节我们已经通过Matplotlib绘制了简单的标识图,本节我们将针对实际决策树进行绘制。
绘制需要
- 求树的宽度和深度
- 计算根节点和分支节点的连线中点 (我们选择在中点处进行绘制信息)
- 通过递归调用 展开绘制
- 递归过程中我们需要进行判断:结点是叶子还是另一颗子树的根节点
树宽、树深
绘制一棵完整的树需要一些技巧。我们虽然有x、y坐标,但是如何放置所有的树节点却是个问题。我们必须知道有多少个叶节点,以便可以正确确定x轴的长度.我们还需要确定树的深度,以便于确定y轴的高度。
"""
函数说明:
得到树的叶子结点个数
Parameters:
myTree:决策树
Return:
numLeafs:叶子结点个数
"""
def getNumLeafs(myTree):
numLeafs = 0 # 结点数目初始化
# firstStr = myTree.keys()[0]
# TypeError: 'dict_keys' object does not support indexing
# 原因:这是由于python3.6版本改进引起的。
# 解决方案:
# temp_keys = list(myTree.keys())
# firstStr = temp_keys[0]
# 在这里 只能取到第一个Key值 其他的key值嵌套在字典里 该方法识别不了 不过正是我们想要的
temp_keys = list(myTree.keys()) # mytree: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
firstStr = temp_keys[0] # 这里我们取到决策树的第一个key值
secondDict = myTree[firstStr] # 由于树的嵌套字典格式 我们通过第一个key得到了其value部分的另一个字典
for key in secondDict.keys(): # 取出第二字典的key 0和1
if type(secondDict[key]).__name__ == 'dict':
# 判断是否相应key的value是不是字典 是字典就不是叶子结点
# 继续调用本函数拆分该字典直到不是字典 即为叶子结点 进行记录
numLeafs += getNumLeafs(secondDict[key])
else: # 不是字典直接记录为叶子结点
numLeafs += 1
return numLeafs
"""
函数说明:
得到树的深度
Parameters:
myTree:决策树
Return:
maxDepth:树高
"""
def getTreeDepth(myTree):
maxDepth = 0
# firstStr = myTree.keys()[0]
firstStr = next(iter(myTree)) # 这里有第二种方法可以取到到第一个key值
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
# 与记录叶子结点数目类似 一个字典算是一层的代表(因为字典必有分支)
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
# 是叶子结点也给他记作一层 但是要注意 当同一层不单单只是叶子结点 有下层分支时
# 会将本来有两层的计数重置为一层 为了避免这一错误 我们有了下面的if判断
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
"""
函数说明:
模拟树的创建(自定义的方式)
为了避免每次调用都要通过数据集创建树的麻烦
Parameters:
i:哪个树
我们可以多模拟几个数来检测函数的可行性
Return:
listOfTrees[i]:第i个树
"""
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
查看结果:
print(getNumLeafs(retrieveTree(0))) # 3
print(getNumLeafs(retrieveTree(1))) # 4
print(getTreeDepth(retrieveTree(0))) # 2
print(getTreeDepth(retrieveTree(1))) # 3
当然既然是模拟树,大家可以创造一些更为复杂的决策树进行测试。
计算根节点和分支节点的连线中点
"""
函数说明:
在父子结点间填充文本信息
Parameters:
cntrPt,parentPt:用于计算标注位置(我们取父子连线的中点作为标注位置)
txtString:标注内容
Return:
None
"""
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
# (parentPt[0]+cntrPt[0])/2.0 考虑:为什么不直接这样?
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
逻辑绘制
"""
函数说明:
使用文本注解绘制树节点
parameters:
nodeTxt:注释文段
centerPt:文本中心坐标
parentPt:箭头尾部坐标
nodeType:注释文本类型
Return:
无返回 执行annotate()画布
"""
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
"""
函数说明:
绘制决策树---筹备
Parameters:
myTree:决策树
parentPt:父节点位置(在上节中是箭头尾部 箭头头部是子节点)父节点----->子节点
nodeTxt:标注信息
Special:
numLeafs:当前结点的叶子节点数(是在变的)
tatalW:树的总叶子数
Return:
None
"""
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree) # 得到叶子结点计算树的宽度
depth = getTreeDepth(myTree) # 得到树深度
firstStr = list(myTree.keys())[0] # 得到根结点(父节点)注释内容
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)# 根节点位置
# 第一次看非常之疑惑 plotTree.xOff、plotTree.totalW、plotTree.yOff 三个突兀的东西就这样出现了
# 一开始以为是定义函数对象的调用 又想了想不太对劲 自己调用自己啥的也没有对三变量定义的过程啊 还是重复这样
# 搜了很多没发现什么雷同的 看书上解释是一种全局变量 但还是不理解 毕竟是第一次见
# 换了个方向搜索 仍然无果而终 最后用type()检测变量 的确是个变量 好吧 难受的心路历程
# 由于按顺序去看的函数没有先看下面的执行函数 发现 执行函数中确实有提前定义这几个变量 全局变量石锤
# 现在唯一的疑惑就是 为什么可以这样定义?有什么意义?
# 猜测:由于执行函数会调用多个函数来实现总的绘图 所以我们需要用函数名.变量 这样的形式来区分应用于哪个函数
# 当然这种变量也可以放在需求函数里面定义 但是由于此处变量需要inTree变量来计算值所以就干脆放在执行函数里面了
# 执行函数的参数恰好就是inTree (图个方便?)
# 巧妙的分析一波 哈哈啊哈哈
plotMidText(cntrPt, parentPt, nodeTxt) # 画连线标注
plotNode(firstStr, cntrPt, parentPt, decisionNode) # 画结点 画线
secondDict = myTree[firstStr] # 解析下一个字典(根节点)/叶子
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD # 计算下一个结点的y
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict': # 子树
plotTree(secondDict[key], cntrPt, str(key))
else: # 叶子节点
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
# 递归完后需要回退到上层,绘制当前树根节点的其他分支节点。
执行绘制
"""
函数说明:
绘制决策树---执行
Parameters:
inTree:决策树
Return:
None 展示画布
"""
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[]) # 定义x,y轴为空 为后面不显示轴作准备
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # 传入刚才定义参数 构造子图
# createPlot.ax1 = plt.subplot(111, frameon=False) # 有轴的
plotTree.totalW = float(getNumLeafs(inTree))
# tatalW:树的宽度初始化 = 叶子节点
plotTree.totalD = float(getTreeDepth(inTree))
# tatalD:树的深度 = 树高
plotTree.xOff = -0.5/plotTree.totalW # 为开始x位置为第一个表格左边的半个表格距离位置
plotTree.yOff = 1.0 # y位置1
# 使用两个全局变量plotTree.xOff、plotTree.yOff追踪已经绘制的节点位置
# 这部分代码直接去看很难理解 之后会有注解
plotTree(inTree, (0.5, 1.0), '') # 调用函数开始绘图 一开始标注为空 因为第一个就是根结点
plt.show()
执行查看
if __name__ == '__main__':
'''
print(getNumLeafs(retrieveTree(0))) # 3
print(getNumLeafs(retrieveTree(1))) # 4
print(getTreeDepth(retrieveTree(0))) # 2
print(getTreeDepth(retrieveTree(1))) # 3
'''
mytree = retrieveTree(1)
createPlot(mytree)
注释:
到这里代码部分就结束了,但是对于决策树的每一个结点位置的计算也是比较难理解的,为此我找了几篇博客进行学习,下面有链接,大家可以继续学习一下:
https://www.cnblogs.com/fantasy01/p/4595902.html
https://www.cnblogs.com/hithink/p/6245993.html
总结
本节内容的难点主要在于对结点的位置计算的理解,耗费了很多时间,我们在本次绘制中是从上往下绘制的,因此y在下减,结点从左至右,x也随着结点的扩建而向右移动,但在最后计算完毕后还是回到了初始值的位置以便于下一次的计算,还有一些细小的方面需要在代码中一遍一遍的串就会清楚一些,虽然算是完成了绘制,但是想要再完整的绘制一次还是会昏头,嗯。。。。。(又一次发现了自己的渣崽本质)