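Once the tree is built, I found that in its raw data form it is hard to read the classification decisions off it directly, especially when many features take part in the classification. I work on credit risk decisioning and deal with countless feature variables, which is a real headache. The book uses the annotation feature of the matplotlib library to draw the tree diagram. I came across this library when I first learned Python; it has so many parameters that it gets annoying, but let's give it a try.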
Draw the tree nodes as text annotations:
from matplotlib import pyplot as plt

# Box styles for the two node types, and the arrow style
decisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')

# Draw one node at cntrPt with an arrow coming from parentPt
def plotNode(nodetext, cntrPt, parentPt, nodeType):
    # ax1 is stored as an attribute of createPlot; annotate() draws the node
    createPlot.ax1.annotate(nodetext, xy=parentPt, xycoords='axes fraction',
                            xytext=cntrPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType,
                            arrowprops=arrow_args)

# Try drawing a couple of nodes
def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()  # clear the figure
    # createPlot.ax1 is a function attribute (not a true global variable),
    # so plotNode can reach the axes through the function object
    createPlot.ax1 = plt.subplot(111, frameon=False)  # no frame around the axes
    plotNode('decision node', (0.5, 0.8), (0.1, 0.5), decisionNode)
    plotNode('leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
Next, two recursive functions count the leaf nodes and the number of levels in the tree:
# Count the leaf nodes
def calcLeafs(myTree):
    numofLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            # a dict is a subtree: add up its leaves recursively
            numofLeafs += calcLeafs(secondDict[key])
        else:
            numofLeafs += 1
    return numofLeafs

# Count the levels of the tree (how many feature splits lie on the longest path)
def calcDepth(myTree):
    numofDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + calcDepth(secondDict[key])
        else:
            thisDepth = 1
        # keep the deepest branch seen so far
        if thisDepth > numofDepth:
            numofDepth = thisDepth
    return numofDepth
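As a quick sanity check of the two counters, here is a minimal sketch. The names `count_leaves`, `tree_depth` and `sample_tree` are my own for illustration and are not from the book's code; the sketch assumes the tree is stored as nested dicts of the form {feature: {value: subtree_or_label}}, which is what the functions above expect.

```python
def count_leaves(tree):
    """Count leaf labels in a nested-dict tree: {feature: {value: subtree_or_label}}."""
    branches = next(iter(tree.values()))
    return sum(count_leaves(v) if isinstance(v, dict) else 1
               for v in branches.values())


def tree_depth(tree):
    """Number of feature splits along the longest root-to-leaf path."""
    branches = next(iter(tree.values()))
    return max((1 + tree_depth(v) if isinstance(v, dict) else 1
                for v in branches.values()), default=0)


# Hypothetical sample tree: two features, three leaves
sample_tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

print(count_leaves(sample_tree))  # 3
print(tree_depth(sample_tree))    # 2
```

The leaf count later determines how the x axis is divided, and the depth determines how the y axis is divided.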
For the final drawing step, add a helper that writes text halfway between a parent node and its child:
def plotMidText(cntrPt, parentPt, textstring):
    # midpoint of the segment between the child (cntrPt) and the parent (parentPt)
    xmid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    ymid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xmid, ymid, textstring)
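createPlot is the real boss, our main function, and everything else is a helper. One more helper is in charge of drawing the tree, and it pulls the earlier helpers together: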
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = calcLeafs(myTree)
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    # Coordinates of the current node: x is the centre of the slots
    # occupied by the leaves of this subtree
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    # Draw the current decision node and the branch label leading to it
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    # The y axis is divided evenly by depth; step down one level for the children
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # each leaf advances x by one slot before it is drawn
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Step back up one level so the caller keeps drawing at its own depth
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
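One thing took me a long time to understand: the coordinate cntrPt of the current node. The book divides the x axis evenly by the total number of leaves, so each leaf gets a slot of width 1/plotTree.totalW and sits in the middle of its slot. plotTree.xOff is initialized to -0.5/plotTree.totalW, that is, half a slot to the left of the origin, and every leaf moves xOff one full slot to the right before it is drawn, so the first leaf lands at 0.5/plotTree.totalW. A decision node with numLeafs leaves below it should sit at the centre of those leaves. Measured from the current xOff, that centre is float(numLeafs)/2.0/plotTree.totalW plus the half slot that was shifted away at initialization, which gives (1.0 + numLeafs)/2.0/plotTree.totalW. All coordinates are fractions of the axes, so the figure is drawn to scale.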
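To make the x bookkeeping concrete, the sketch below recomputes only the x coordinates, without any drawing. The helper names (`count_leaves`, `layout_x`), the `state` dict that stands in for plotTree.xOff, and the sample tree are all assumptions of mine for illustration, not part of the book's code.

```python
def count_leaves(tree):
    """Count leaf labels in a nested-dict tree: {feature: {value: subtree_or_label}}."""
    branches = next(iter(tree.values()))
    return sum(count_leaves(v) if isinstance(v, dict) else 1
               for v in branches.values())


def layout_x(tree, total_w, state, out):
    """Append (name, x) for every node, mirroring the xOff updates in plotTree."""
    feature = next(iter(tree))
    n = count_leaves(tree)
    # Centre of the n leaf slots that start just right of the current offset
    out.append((feature, state["x_off"] + (1.0 + n) / 2.0 / total_w))
    for child in tree[feature].values():
        if isinstance(child, dict):
            layout_x(child, total_w, state, out)
        else:
            state["x_off"] += 1.0 / total_w  # advance one slot per leaf
            out.append((child, state["x_off"]))
    return out


# Hypothetical sample tree with three leaves
tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
w = float(count_leaves(tree))
coords = layout_x(tree, w, {"x_off": -0.5 / w}, [])
for name, x in coords:
    print(f"{name:>12}: x = {x:.3f}")
# no surfacing: x = 0.500
#           no: x = 0.167
#     flippers: x = 0.667
#           no: x = 0.500
#          yes: x = 0.833
```

With three leaves, the leaves land at 1/6, 3/6 and 5/6, the root sits at the centre of all three (0.5), and the 'flippers' node sits at the centre of its own two leaves (0.667).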
Then createPlot draws and displays the tree (this definition replaces the earlier test version):
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])  # no tick marks on either axis
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # total width (leaf count) and depth are used to scale the layout
    plotTree.totalW = float(calcLeafs(inTree))
    plotTree.totalD = float(calcDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0  # start at 1.0: the tree is drawn top-down and y decreases
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()