目录
1. Matplotlib注释
Matplotlib提供了一个非常有用的注解工具annotations,它可以在数据图形上添加文本注解。下面将使用注解功能绘制树形图,它可以对文字着色并提供多种形状,同时我们也可以反转箭头的符号,使其指向文本框,而不是数据点。
1.1 使用文本注释绘制树节点
import matplotlib.pyplot as plt
from pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei']
"定义文本框和箭头格式"
decisionnode = dict(boxstyle = "sawtooth", fc = "0.8")
leafnode = dict(boxstyle = "round4", fc = "0.8")
arrow_args = dict(arrowstyle = "<-")
"绘制带箭头的注解"
def plotnode(nodetxt, centerpt, parentpt, nodetype):
createplot.ax1.annotate(nodetxt,
xy = parentpt,
xycoords = 'axes fraction',
xytext = centerpt,
textcoords = 'axes fraction',
va = 'center',
ha = 'center',
bbox = nodetype,
arrowprops = arrow_args)
def createplot():
fig = plt.figure(1, facecolor='white') # 创建新图形
fig.clf() # 清空绘图区
createplot.ax1 = plt.subplot(111, frameon = False)
plotnode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionnode)
plotnode('叶节点', (0.8, 0.1), (0.3, 0.8), leafnode)
plt.show()
查看树节点图:
createplot()
输出结果:
2. 构造注解图
2.1 获取叶节点的数目和树的层数
def getnumleafs(mytree):
numleafs = 0
firststr = list(mytree.keys())[0]
seconddict = mytree[firststr]
for key in seconddict.keys():
if type(seconddict[key]).__name__ == 'dict':
numleafs += getnumleafs(seconddict[key])
else:numleafs +=1
return numleafs
def gettreedepth(mytree):
maxdepth = 0
firststr = list(mytree.keys())[0]
seconddict = mytree[firststr]
for key in seconddict.keys():
if type(seconddict[key]).__name__ == 'dict':
thisdepth = 1 + getnumleafs(seconddict[key])
else:thisdepth = 1
if thisdepth > maxdepth:
maxdepth = thisdepth
return maxdepth
注意:第3和12行代码在书中为:firststr = mytree.keys()[0],这个是Python2中的写法,python2中形如myTree.keys()[0]这样的写法是没有问题的,因为myTree.keys()返回的是一个list;而在python3中myTree.key()返回的则是dick_keys类型,故而出错。
解决方法:将mytree.keys()[0] 改为 list(mytree.keys())[0]
为节省时间, 函数retrievetree输出预先存储的书信息,避免每次测试代码时都要从数据中创建书的麻烦:
def retiretree(i):
listoftrees = [{'no surfacing': {0 : 'no', 1 : {'flippers' : {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0 : 'no', 1 : {'flippers' : {0:{'head':{0:'no', 1:'yes'}}, 1 : 'no'}}}}]
return listoftrees[i]
查看树结构、叶节点数目、树的层数:
print("retiretree(1):", retiretree(1))
print("retiretree(0):", retiretree(0))
mytree = retiretree(0)
print("叶节点个数:", getnumleats(mytree))
print("树的层数:", gettreedepth(mytree))
输出结果:
2.2 plottree函数
def plotmidtext(cntrpt, parentpt, txtstring):
xmid = (parentpt[0]-cntrpt[0])/2.0 + cntrpt[0]
ymid = (parentpt[1]-cntrpt[1])/2.0 + cntrpt[1]
createplot.ax1.text(xmid, ymid, txtstring, va="center", ha="center", rotation=30)
def plottree(mytree, parentpt, nodetxt):
numleafs = getnumleafs(mytree)
depth = gettreedepth(mytree)
firststr = list(mytree.keys())[0]
cntrpt = (plottree.xoff + (1.0 + float(numleafs))/2.0/plottree.totalw, plottree.yoff)
plotmidtext(cntrpt, parentpt, nodetxt)
plotnode(firststr, cntrpt, parentpt, decisionnode)
seconddict = mytree[firststr]
plottree.yoff = plottree.yoff - 1.0/plottree.totald
for key in seconddict.keys():
if type(seconddict[key]).__name__=='dict':
plottree(seconddict[key],cntrpt,str(key))
else:
plottree.xoff = plottree.xoff + 1.0/plottree.totalw
plotnode(seconddict[key], (plottree.xoff, plottree.yoff), cntrpt, leafnode)
plotmidtext((plottree.xoff, plottree.yoff), cntrpt, str(key))
plottree.yoff = plottree.yoff + 1.0/plottree.totald
def createplot(intree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createplot.ax1 = plt.subplot(111, frameon=False, **axprops)
"使用下两行代码计算树节点的摆放位置,将树绘制在水平方向和垂直方向的中心位置"
plottree.totalw = float(getnumleafs(intree)) # 树的宽度
plottree.totald = float(gettreedepth(intree)) # 树的高度
plottree.xoff = -0.5/plottree.totalw; plottree.yoff = 1.0 # 追踪已经绘制的节点位置,以及放置下一个节点的恰当位置
plottree(intree, (0.5,1.0), '')
plt.show()
查看树结构:
mytree = retiretree(0)
createplot(mytree)
输出结果:
变更字典,重新绘制树形图:
mytree['no surfacing'][3] = 'maybe'
createplot(mytree)
输出结果: