Machine_Learning_2019_Task 9 绘制树图形
要求
利用 Python 结合 Matplotlib 绘制树图形
- Matplotlib 注释
- 构造注解树
导入matplotlib
import matplotlib.pyplot as plt
绘制树形图:定义文本框和箭头格式以及树结点格式的常量
# bbox style for internal (decision) nodes: sawtooth border, face colour "0.8"
# (a matplotlib greyscale string).
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
# bbox style for leaf nodes: "round4" border, same face colour.
leafNode = {"boxstyle": "round4", "fc": "0.8"}
# Arrow style used as ``arrowprops`` when annotating a node from its parent.
arrow_args = {"arrowstyle": "<-"}
构造注解树:利用 Python 字典存储树,获取叶子结点数目
def getNumLeafs(myTree):
    """Return the number of leaf nodes in a nested-dict decision tree.

    The tree is a dict whose first key is the label this node splits on; the
    value under that key maps each branch value either to a sub-tree (a dict)
    or to a leaf value (anything that is not a dict).

    Args:
        myTree: non-empty tree dict; only its first key is used as the root.

    Returns:
        int: how many non-dict values are reachable from the root.

    Raises:
        IndexError: if ``myTree`` is empty.
    """
    # The first key is the label this node splits the data set on.
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    # Dict children (including dict subclasses such as OrderedDict) are
    # sub-trees and are counted recursively; every other child is one leaf.
    return sum(
        getNumLeafs(child) if isinstance(child, dict) else 1
        for child in secondDict.values()
    )
计算树的深度
def getTreeDepth(myTree):
    """Return the number of decision levels in a nested-dict decision tree.

    Same tree shape as ``getNumLeafs``: the first key is the split label and
    its value maps branch values to sub-trees (dicts) or leaves (non-dicts).
    A node whose children are all leaves has depth 1; a node with no
    children at all has depth 0.

    Args:
        myTree: non-empty tree dict; only its first key is used as the root.

    Returns:
        int: length of the longest root-to-leaf chain of decision nodes.

    Raises:
        IndexError: if ``myTree`` is empty.
    """
    # The first key is the split label; its value holds the branches.
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    # Each branch contributes one level, plus the depth below it when the
    # branch is itself a sub-tree (any dict, including dict subclasses).
    return max(
        (1 + getTreeDepth(child) if isinstance(child, dict) else 1
         for child in secondDict.values()),
        default=0,
    )
绘制带箭头的注解,执行实际的绘图功能
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw a boxed node label at ``centerPt`` with an arrow to ``parentPt``.

    Both points are in axes-fraction coordinates of ``createPlot.ax1``.
    ``nodeType`` is the bbox style dict for the label (the file defines
    ``decisionNode`` and ``leafNode`` for this); the arrow style comes from
    the module-level ``arrow_args``.
    """
    axes = createPlot.ax1
    frame = 'axes fraction'
    axes.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords=frame,
        xytext=centerPt,
        textcoords=frame,
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` at the midpoint between ``cntrPt`` and ``parentPt``.

    Used to label the edge from a parent node to a child with the branch
    value. The text is centred on the midpoint and rotated 30 degrees, and
    is drawn on ``createPlot.ax1``.
    """
    def halfway(axis):
        # Midpoint along one coordinate, computed as start + half the span.
        return (parentPt[axis] - cntrPt[axis]) / 2.0 + cntrPt[axis]

    mid_x, mid_y = halfway(0), halfway(1)
    createPlot.ax1.text(mid_x, mid_y, txtString,
                        va="center", ha="center", rotation=30)
绘制树形图的递归函数
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` as a child of the node at ``parentPt``.

    Layout state is kept in attributes on this function, initialised by
    ``createPlot`` before the first call:
      * ``plotTree.totalW`` - total number of leaves (horizontal slots);
      * ``plotTree.totalD`` - tree depth (vertical levels);
      * ``plotTree.xOff``   - x of the most recently placed leaf slot;
      * ``plotTree.yOff``   - y of the level currently being drawn.
    All coordinates are axes fractions. Drawing goes through ``plotNode``
    and ``plotMidText``.

    Args:
        myTree: tree dict whose first key is this node's split label.
        parentPt: (x, y) of the parent node the edge starts from.
        nodeTxt: branch value labelling the edge into this node.
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree)[0]
    # Centre this decision node over the slots its own leaves will occupy:
    # they are placed at xOff + 1/W, ..., xOff + numLeafs/W.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Draw top-down: move one level lower for this node's children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            # Sub-tree: recurse with this node as the parent.
            plotTree(child, cntrPt, str(key))
        else:
            # Leaf: advance to the next horizontal slot and draw it there.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore this level's y so the caller resumes at its own height.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
创建一个新的图形并清空绘图
# 主函数
def createPlot(inTree):
    """Draw ``inTree`` on matplotlib figure 1 and show it.

    Side effects:
      * clears figure 1 (white background);
      * stores a frameless, tick-free axes on ``createPlot.ax1``, which
        ``plotNode`` and ``plotMidText`` draw on;
      * initialises the layout attributes on ``plotTree``, starts the
        recursive drawing and calls ``plt.show()``.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    # Hide the frame and tick marks so only the tree annotations are visible.
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    # One horizontal slot per leaf, one vertical level per decision layer.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Begin half a slot left of the axes so the first leaf lands at 0.5/totalW.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    # The root is anchored at the top centre and has no incoming edge label.
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
输出预先存储的树信息,避免每次测试都需要重新创建树
def retrieveTree(i):
    """Return canned example tree number ``i`` (0 or 1) for quick testing.

    New dict objects are built on every call, so a caller may mutate the
    result without affecting later calls. Indexing follows list semantics:
    negative indices work, and an out-of-range ``i`` raises IndexError.
    """
    # Tree 0: splits on 'no surfacing', then on 'flippers'.
    small_tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    # Tree 1: like tree 0, but flippers == 0 leads to a further 'head' split.
    head_split = {'head': {0: 'no', 1: 'yes'}}
    large_tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: head_split, 1: 'no'}}}}
    return [small_tree, large_tree][i]
Visualization 可视化结果