网上的版本好像好久都没更新了treePlotter是没有人用了么。今天学习的时候发现有些地方已经改了,我改的是在python 3.6 上的运行版本,需要导入matplotlib.pyplot
import matplotlib.pyplot as plt
# 定义决策树决策结果属性
descisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')
# myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
# nodeTxt为要显示的文本,centerNode为文本中心点, nodeType为箭头所在的点, parentPt为指向文本的点
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va='center', ha='center', bbox=nodeType, arrowprops=arrow_args)
# def createPlot():
# fig = plt.figure(1, facecolor='white')
# fig.clf()
# # createPlot.ax1为全局变量,绘制图像句柄
# # frameon表示是否绘制坐标轴矩形
# createPlot.ax1 = plt.subplot(111, frameon=False)
# plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), descisionNode)
# plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
# plt.show()
# 这个是用来测试的
# -----------分割线-------------
# 获取树的叶子数和树的深度
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0] # 这个是改的地方,原来myTree.keys()返回的是dict_keys类,不是列表,运行会报错。有好几个地方这样
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
# ---------分割线-------------
# 制图
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = {'xticks': None, 'yticks': None}
createPlot.ax1 = plt.subplot(111, frameon=False)
plotTree.totalW = float(getNumLeafs(inTree)) # 全局变量宽度 = 叶子数目
plotTree.totalD = float(getTreeDepth(inTree)) # 全局变量高度 = 深度
plotTree.xOff = -0.5/plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
# cntrPt文本中心点, parentPt指向文本中心的点
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, descisionNode)
seconDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in seconDict.keys():
if type(seconDict[key]).__name__ == 'dict':
plotTree(seconDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(seconDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va='center', ha='center', rotation=30)
# createPlot(myTree)
这个treePlotter导入了就可以把原来得到的决策树模型导入啦,而且要注意是以字典形式导入,所以保存和导入文件的时候最好用json。
发布5分钟之后,突然发现已经有人改过了,那就只算是个学习笔记吧 - -