注解树:将树中的信息添加到决策树图中。
Note:绘制图形的x轴有效范围是0.0到1.0,y轴有效范围也是0.0到1.0。
def getNumLeafs(myTree): #获取叶节点树
numLeafs = 0
# firstStr = myTree.keys()[0] #mytree的第一个特征值 python2写法
first = list(myTree.keys())
firstStr = first[0]
secondDict = myTree[firstStr] #mytree经过第一个特征值分类后的字典
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict': #判断是否分类过的数据为字典
numLeafs += getNumLeafs(secondDict[key]) #子节点数量+1 递归,判断一共多少个子节点
else: numLeafs += 1
return numLeafs
def getTreeDepth(myTree): #获取树深度
maxDepth = 0
# firstStr = myTree.keys()[0] #mytree的第一个特征值 python2写法
first = list(myTree.keys())
firstStr = first[0]
secondDict = myTree[firstStr] # mytree经过第一个特征值分类后的字典
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict': # 判断是否分类过的数据为字典
thisDepth = 1 + getTreeDepth(secondDict[key]) #数的深度+1
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth #最大深度
#二叉树的深度是指有多少层, 而二元决策树的深度是指经过多少层计算
#测试的树数据
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1:'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1:'yes'}}, 1: 'no'}}}}]
return listOfTrees[i]
def plotMidText(cntrPt, parentPt, txtString): #在坐标点cntrPt和parentPt连接线上的中点,显示文本txtString
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] #x坐标
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] #y坐标
createPlot.ax1.text(xMid, yMid, txtString) #在(xMid, yMid)处显示txtString
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree) #获得当前树叶节点个数
depth = getTreeDepth(myTree) #获得当前树深度
# firstStr = myTree.keys()[0] #mytree的第一个特征值 python2写法
first = list(myTree.keys())
firstStr = first[0] #第一个分类的特征值,即根节点
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
# plotTree.xOff和plotTree.yOff是用来追踪已经绘制的节点位置,plotTree.totalW为这个数的宽度,叶节点数
plotMidText(cntrPt, parentPt, nodeTxt) #显示节点
plotNode(firstStr, cntrPt, parentPt, decisionNode) #firstStr为需要显示的文本,cntrPt为文本的中心点,
# parentPt为箭头指向文本的起始点,decisionNode为文本属性
secondDict = myTree[firstStr] #子树
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #totalD是这个数的深度,深度移下一层
for key in secondDict.keys():
if type(secondDict[key]).__name__=="dict":
plotTree(secondDict[key], cntrPt, str(key)) #递归显示子树
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW #x坐标平移一个单位
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) #画叶节点
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) #显示箭头文本
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #下移一层深度
def createPlot(inTree):
fig = plt.figure(1, facecolor = 'white') #创建一个画布,背景为白色
fig.clf() #画布清空
axprops = dict(xticks=[], yticks=[]) #定义横纵坐标轴,无内容
#ax1是函数createPlot的一个属性,这个可以在函数里面定义也可以在函数定义后加入也可以
# createPlot.ax1 = plt.subplot(111, frameon = False, **axprops) #frameon表示是否绘制坐标轴矩形,无坐标轴
createPlot.ax1 = plt.subplot(111, frameon = False) #frameon表示是否绘制坐标轴矩形
plotTree.totalW = float(getNumLeafs(inTree)) #树的宽度
plotTree.totalD = float(getTreeDepth(inTree)) #树的深度
plotTree.xOff = -0.5/plotTree.totalW #x轴起始值,之前一开始定义了1.0+,有一个偏差
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
myTree = retrieveTree(0)
createPlot(myTree)
此树叶节点为3,深度为2。