使用Matplotlib注解绘制树形图
annotation是注解工具,注解功能可以对文字着色,并提供多种形状以供选择,还可以反转箭头。创建名为treePlotter.py的新文件。
使用文本注解绘制树节点:
#定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")
#绘制带箭头的注释
#nodeTxt:节点文本
#centerPt:文本框的中心位置
#parentPt:箭头线的起点
#nodeType:文本框类型
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
#xycoords和textcoords分别代表坐标系,va和ha分别指文字在框中横向和纵向的位置
createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
xytext=centerPt,textcoords='axes fraction',\
va='center',ha='center',bbox=nodeType,arrowprops=arrow_args)
def createPlot():
#新建画布并设置画布的背景
fig = plt.figure(1,facecolor='white')
#清空
fig.clf()
#frameon是否绘制坐标轴的矩形
createPlot.ax1 = plt.subplot(111,frameon=False)
plotNode('决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
plotNode('叶节点',(0.8,0.1),(0.3,0.8),leafNode)
plt.show()
得到的结果如下:
此外绘制图形必须知道叶子节点个数,以便确定x的长度,还需要知道树有多少层,以便确定y'轴高度。定义两个函数getNumLeafs()和getTreeDepth():
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs +=1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
源码较为简单,不做深度解析。值得注意的是获取键之后不能直接用索引来获取某一个键,需要先用list函数序列化,否则会报错。
为了节省从数据中创建树字典的时间,使用函数retrieveTree输出预先存储的树信息。代码如下:
#构建测试数据集
def retrieveTree(i):
listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},\
{'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
return listOfTrees[i]
以下为测试过程:
>>>myTree = treePlotter.retrieveTree(0)
>>>treePlotter.getNumLeafs(myTree)
3
接下来组合一起生成完整的决策树:
#根据字典绘制图形
def plotMidText(cntrPt,parentPt,txtString):
xMid = (parentPt[0]+cntrPt[0])/2.0
yMid = (parentPt[1]+cntrPt[1])/2.0
createPlot.ax1.text(xMid,yMid,txtString)
def plotTree(myTree,parentPt,nodeTxt):
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict = myTree[firstStr]
plotTree.y0ff = plotTree.y0ff-1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.x0ff = plotTree.x0ff + 1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
plotTree.y0ff = plotTree.y0ff + 1.0/plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1,facecolor='white')
fig.clf()
axprops = dict(xticks=[],yticks=[])
createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.x0ff = -0.5/plotTree.totalW
plotTree.y0ff = 1.0
plotTree(inTree,(0.5,1.0),'')
plt.show()
测试的结果如下:
以上代码主要是定位节点的坐标,操作比较繁琐。具体原理请参考定位详解。关于 annotate函数的详解请参考annotate。