import matplotlib.pyplot as plt
#解决中文字和坐标问题
plt.rcParams['font.sans-serif'] = ['Simhei']
plt.rcParams['axes.unicode_minus'] = False
#定义文本框和箭头格式,必须以字典方式定义,后面程序要用
decisionNode = dict( boxstyle = 'sawtooth', fc = '0.8' ) #boxstyle为文本框类型,sawtooth为锯齿形
leafNode = dict( boxstyle = 'round4', fc = '0.8' ) #round4为长方圆形,fc是边框线粗细
arrow_args = dict( arrowstyle = '<-' ) #arrowstyle为箭头的样式
# 文本 文本位置 起点 框的样式和粗细
def plotNode(nodeTxt, centerpt, parentPt, nodeType):
fig = plt.figure(1) #括号中不写1会默认创建1,2,3等递增不同的图,本程序就会出现3个图
ax1 = fig.add_subplot(111)
ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerpt,\
textcoords='axes fraction', va='center', ha='center', bbox=nodeType,\
arrowprops=arrow_args)
#最后一个位置定义的箭头样式
#也可以:
#plt.figure(1)
#plt.subplot(111)
#plt.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerpt,\
# textcoords='axes fraction', va='center', ha='center', bbox=nodeType,\
# arrowprops=arrow_args)
def createPlot():
fig = plt.figure(1, facecolor='white') #定义一个画布,背景为白色,不定义颜色默认白色
fig.clf() #清空画布
ax1 = fig.add_subplot(111, frameon=False) #frameon表示是否绘制坐标轴矩形
plotNode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode('叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()
获取叶节点个数和树的深度:
def getNumLeafs(myTree): #获取叶节点数目
numLeafs = 0
firstStr = list( myTree.keys() )[0] #把树转换成关键字列表,此时列表中只有一个关键字,因为是第一个分支点,python3中myTree.keys()为dict_keys型,不是list,必须要用list()将其转化为列表型才可以
secondDict = myTree[firstStr] #获取关键字(第一个问题)下的内容,至少有一个回答和一个结果,所以内容至少是{0:1}
for key in secondDict.keys(): #遍寻第一个问题的所有回答,即第一个关键字下的字典的关键字
if type(secondDict[key]).__name__ == 'dict': #判断下一级是不是还是字典, .__name__作用是将类型名称变为str
numLeafs += getNumLeafs(secondDict[key]) #叶节点的数目等于所有最后一级的总数目
else: #比如第一个问题有2个分支,1个分支到底了+1,另一个分支又分出2个分支,2个分支都到底了,+2,一共是3
numLeafs += 1 #按程序步骤是,第一个关键字不符合if,+1,第二个关键字进入getNumLeafs(secondDict[key]),两个分支都不符合if,return2
return numLeafs #即getNumLeafs(secondDict[key])是2,最后结果是3
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list( myTree.keys() )[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth( secondDict[key] )
else:
thisDepth = 1
if thisDepth > maxDepth: #key寻遍所有关键字,第一个关键字可能层数只有1,但是第二个可能是2层
maxDepth = thisDepth #加入一个比较,取层数最多的那个就是树的层数
return maxDepth
知识要点:
import matplotlib.pyplot as plt
plt.figure(1)
plt.subplot(111)
plt.annotate('zheli', xy=(1,1), xytext=(5,5), bbox=dict( boxstyle = 'sawtooth', fc = '0.8' ), arrowprops=dict( arrowstyle = '<-' ))
plt.axis([0,10,0,10])
plt.show()
显示内容是‘zheli’,起点xy为(1,1),文本位置xytext为(5,5),文本框的样式bbox和箭头样式arrowprops都要求是字典,xy=,xytext=这些必须要写,没有位置
如果把箭头变成:arrowprops=dict( arrowstyle = '->' )),则图形是:
plt.text(2,3,'111') #在(2,3)点的位置放上文本‘111’