上一节实现了决策树,但只是使用包含树结构信息的嵌套字典来实现,其表示形式较难理解,显然,绘制直观的二叉树图是十分必要的。Python没有提供自带的绘制树工具,需要自己编写函数,结合Matplotlib库创建自己的树形图。这一部分的代码多而复杂,涉及二维坐标运算;书里的代码虽然可用,但函数和各种变量非常多,感觉非常凌乱,同时大量使用递归,因此只能反复研究,反反复复用了一天多时间,才差不多搞懂,因此需要备注一下。
一.绘制属性图
这里使用Matplotlib的注解工具annotations实现决策树绘制的各种细节,包括生成节点处的文本框、添加文本注释、提供对文字着色等等。在画一整颗树之前,最好先掌握单个树节点的绘制。一个简单实例如下:
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 04 01:15:01 2015
@author: Herbert
"""
import matplotlib.pyplot as plt
nonLeafNodes = dict(boxstyle = "sawtooth", fc = "0.8")
leafNodes = dict(boxstyle = "round4", fc = "0.8")
line = dict(arrowstyle = "<-")
def plotNode(nodeName, targetPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeName, xy = parentPt, xycoords = \
'axes fraction', xytext = targetPt, \
textcoords = 'axes fraction', va = \
"center", ha = "center", bbox = nodeType, \
arrowprops = line)
def createPlot():
fig = plt.figure(1, facecolor = 'white')
fig.clf()
createPlot.ax1 = plt.subplot(111, frameon = False)
plotNode('nonLeafNode', (0.2, 0.1), (0.4, 0.8), nonLeafNodes)
plotNode('LeafNode', (0.8, 0.1), (0.6, 0.8), leafNodes)
plt.show()
createPlot()
输出结果:
该实例中,plotNode()
函数用于绘制箭头和节点,该函数每调用一次,将绘制一个箭头和一个节点。后面对于该函数有比较详细的解释。createPlot()
函数创建了输出图像的对话框并对齐进行一些简单的设置,同时调用了两次plotNode()
,生成一对节点和指向节点的箭头。
绘制整颗树
这部分的函数和变量较多,为方便日后扩展功能,需要给出必要的标注:
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 04 01:15:01 2015
@author: Herbert
"""
import matplotlib.pyplot as plt
# 部分代码是对绘制图形的一些定义,主要定义了文本框和剪头的格式
nonLeafNodes = dict(boxstyle = "sawtooth", fc = "0.8")
leafNodes = dict(boxstyle = "round4", fc = "0.8")
line = dict(arrowstyle = "<-")
# 使用递归计算树的叶子节点数目
def getLeafNum(tree):
num = 0
firstKey = tree.keys()[0]
secondDict = tree[firstKey]