#3-5使用文本注解绘制树节点
decisionNode = dict(boxstyle ="sawtooth", fc ="0.8")#创建一个字典
leafNode = dict(boxstyle = "round4", fc = "0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt,xycoords='axes fraction',xytext=centerPt,\
textcoords='axes fraction',va="center", ha="center",bbox=nodeType,arrowprops=arrow_args)
def createPlot(inTree):
fig = plt.figure(1,facecolor='white')
fig.clf()#清除当前窗口
#xticks是一个列表,其中的元素就是x轴上将显示的坐标,yticks是y轴上显示的坐标,
# 这里空列表则不显示坐标
axprops = dict(xticks = [], yticks = [])
#这里定义一个子图窗口,第一个参数xyz含义是,将框架划分为x行y列窗口,
# ax1代表其第z个窗口。frameon = False将隐藏坐标轴
createPlot.ax1 = plt.subplot(111, frameon = False)
#plotTree.totalW是决策树的叶子树,也代表宽度,plotTree.totalD是决策树的深度
plotTree.totalW = float(getNumleafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW;
plotTree.yOff = 1.0;
plotTree(inTree,(0.5,1.0),'')
plt.show()
#3-6获取叶节点的数目和树的层数
def getNumleafs(myTree):#getNumleafs函数遍历整颗树,累计叶子节点的个数,并返回该数值。
numLeafs = 0#初始化叶节点个数
firstsides = list(myTree.keys())# 从myTree的所有节点获取第一个节点(根节点)
firstStr = firstsides[0]
secondDict = myTree[firstStr]
# 遍历根key的value(value包含根key包含的余下所有的子节点)
# 上一级的value包含下一级的key,因此通过递归,可以不断取到下一层的value
for key in secondDict.keys():
if type(secondDict[key]) == 'dict':#判断子节点是否是字典类型
numLeafs += getNumleafs(secondDict[key])
#如果是,说明该节点也是一个判断节点,需要递归调用getNumleafs函数
else:# 如果获取的vlaue不再是字典,说明已经是最后一个子节点,进行一次加1操作
numLeafs += 1
return numLeafs
# 树的层数与获取叶节点的步骤相似,区别在于
# 叶节点数每遍历一次,如果遍历到叶子节点,那么将计数加一,累计叶子节点的个数;
# 树层数的计数在递归的过程中,如果遍历到叶子节点,就会将计数值置为1,只保留max的计数。
# 将这一层的深度记为1
def getTreeDepth(myTree): # 获取树的层数
maxDepth = 0 # 初始化一个记录最大深度的变量
firstsides = list(myTree.keys())
firstStr = firstsides[0]#将字典转化为列表函数
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]) == 'dict':#判断子节点是否是字典类型
thisDepth = 1 + getTreeDepth(secondDict[key])
else:# 如果没有遍历到dict,只有一层
thisDepth = 1
# 每一个key对用的子节点串(每一条路径)都会有一个最大值,记录其中最大的那个
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
def retrieveTree(i):
listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
{'no surfacing':{0:'no',1:{'flippers': {0:{'head':{0:'no',1:'yes'}},1:'no'}}}}
]
return listOfTrees[i]
#3-7plotTree函数
def plotMidText(cntrPt, parentPt, txtString):#子节点,父节点,文本标签信息
#函数计算父节点和子节点的中间位置,并在此处添加简单的文本标签信息
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumleafs(myTree)#计算叶节点数目
depth = getTreeDepth(myTree)#获取树的层数
firstsides = list(myTree.keys())
#计算树的宽和高
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
plotMidText(cntrPt, parentPt,nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yoff = plotTree.yoff -1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]) == 'dict': # 判断子节点是否是字典类型
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff)),cntrPt,leafNode)
plotMidText((plotTree.xOff, plotTree.yOff),cntrPt,str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
本章难度比较大。具体先参考: