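I have recently been studying Machine Learning in Action (《机器学习实战》) and have grown genuinely fond of the book, for three reasons: 1. it explains machine learning methods systematically; 2. it makes those methods simple and easy to understand; 3. it walks me step by step through building programs for them. For that I thank both the author and the translators. My progress through the book has not been entirely smooth, though: I was stuck for several days on the plotting part of Chapter 3 (decision trees), so I used the weekend to work through it. After modifying a few places I got the code to run. The code is given below (runtime environment: Python 3):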
# encoding: utf-8
import matplotlib.pyplot as plt

# Count the leaf nodes
def getNumLeafs(intree):
    numLeafs = 0
    firstStr = list(intree.keys())[0]
    secondDict = intree[firstStr]
    for key in secondDict.keys():
        # A nested dict is a subtree; anything else is a leaf
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# Get the number of levels in the tree
def getTreeDepth(intree):
    maxDepth = 0
    firstStr = list(intree.keys())[0]
    secondDict = intree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # Recurse on depth (not on leaf count) for a subtree
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

# Return a sample dict that has the decision-tree structure
def retrieveTree(i):
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return listOfTrees[i]

# Write the branch label midway between the parent and child nodes
def plotMidText(cntrpt, parentPt, txtString):
    xmid = (parentPt[0] - cntrpt[0]) / 2.0 + cntrpt[0]
    ymid = (parentPt[1] - cntrpt[1]) / 2.0 + cntrpt[1]
    creatPlot_ax1.text(xmid, ymid, txtString)

# Draw the tree structure recursively
def plotTree(intree, parentPt, nodeTxt, plotTree_yOff=1.0, layout=None):
    # The root call sizes the layout from the whole tree; recursive calls
    # share it, so leaves advance left to right across the full width.
    if layout is None:
        totalW = float(getNumLeafs(intree))
        layout = {'totalW': totalW,
                  'totalD': float(getTreeDepth(intree)),
                  'xOff': -0.5 / totalW}
    numLeafs = getNumLeafs(intree)
    firstStr = list(intree.keys())[0]
    # Centre the decision node above the leaves it spans
    cntrPt = (layout['xOff'] + (1.0 + float(numLeafs)) / 2.0 / layout['totalW'],
              plotTree_yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = intree[firstStr]
    # Children sit one level lower; a level is 1/depth of the axes height
    plotTree_yOff = plotTree_yOff - 1.0 / layout['totalD']
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key), plotTree_yOff, layout)
        else:
            layout['xOff'] = layout['xOff'] + 1.0 / layout['totalW']
            plotNode(secondDict[key], (layout['xOff'], plotTree_yOff), cntrPt, leafNode)
            plotMidText((layout['xOff'], plotTree_yOff), cntrPt, str(key))

# Draw one node with an arrow coming from its parent
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    creatPlot_ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                           xytext=centerPt, textcoords='axes fraction',
                           va='center', ha='center',
                           bbox=nodeType, arrowprops=dict(arrowstyle='<-'))

if __name__ == '__main__':
    # Define the text-box styles
    decisionNode = dict(boxstyle='sawtooth', fc='0.8')
    leafNode = dict(boxstyle='round4', fc='0.8')
    mytree = retrieveTree(1)  # sample data with the decision-tree structure; replace with your own
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    creatPlot_ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree(mytree, (0.5, 1.0), '')  # draw the tree
    plt.show()  # display the figure
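Readers who need it can paste this code into a file (mind the indentation after pasting). A few notes:
1) the parts of the code marked in red are the parts I modified;
2) the createPlot function on page 47 of Machine Learning in Action has been split up and folded into the other functions; these parts are marked in green;
3) for a detailed explanation of the code, see the book; if you have any questions, please leave a comment.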
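Before plotting, it can help to confirm that the two helper functions return sensible values for the sample trees, since a wrong leaf count or depth throws off the whole layout. The following is a minimal sketch that needs no matplotlib; the names `count_leaves` and `tree_depth` are hypothetical stand-ins of mine (not from the book) that follow the same logic as getNumLeafs and getTreeDepth.

```python
# Sample trees copied from retrieveTree in the listing above.
TREES = [
    {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
    {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
]


def count_leaves(tree):
    """Count leaf nodes (hypothetical name; same logic as getNumLeafs)."""
    root = next(iter(tree))           # the single key at this level
    total = 0
    for child in tree[root].values():
        if isinstance(child, dict):   # a nested dict is a subtree
            total += count_leaves(child)
        else:                         # anything else is a class label
            total += 1
    return total


def tree_depth(tree):
    """Count decision levels (hypothetical name; same logic as getTreeDepth)."""
    root = next(iter(tree))
    deepest = 0
    for child in tree[root].values():
        depth = 1 + tree_depth(child) if isinstance(child, dict) else 1
        deepest = max(deepest, depth)
    return deepest


for i, tree in enumerate(TREES):
    print(i, count_leaves(tree), tree_depth(tree))
```

For the two sample trees this should report 3 leaves with depth 2, and 4 leaves with depth 3. If the depth comes out equal to the leaf count plus one, the depth function is probably recursing into the leaf counter by mistake.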
The figure produced by running the code is attached here:
The code is also available at the following Gitee link:
https://gitee.com/someone317/backpropagation_algorithm_test/blob/master/drawTree.py
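For readers who want to inspect the node coordinates without opening a figure window, the layout rule that plotTree is meant to follow (leaves spaced evenly along x, one level per 1/depth along y, each decision node centred over its leaves) can be sketched as a pure function. This is only an illustration under my own assumptions: `layout_positions`, `count_leaves`, and `tree_depth` are hypothetical names, and the coordinates are in the same axes-fraction units that plotNode uses.

```python
def count_leaves(tree):
    """Number of leaves under this node (hypothetical helper)."""
    root = next(iter(tree))
    return sum(count_leaves(c) if isinstance(c, dict) else 1
               for c in tree[root].values())


def tree_depth(tree):
    """Number of decision levels under this node (hypothetical helper)."""
    root = next(iter(tree))
    return max(1 + tree_depth(c) if isinstance(c, dict) else 1
               for c in tree[root].values())


def layout_positions(tree):
    """Return [(label, x, y), ...] in drawing order (hypothetical helper)."""
    total_w = float(count_leaves(tree))
    total_d = float(tree_depth(tree))
    state = {'x_off': -0.5 / total_w}   # shared cursor for leaf x positions
    placed = []

    def walk(subtree, y):
        root = next(iter(subtree))
        n = count_leaves(subtree)
        # Centre a decision node above the leaves it spans.
        x = state['x_off'] + (1.0 + n) / 2.0 / total_w
        placed.append((root, x, y))
        child_y = y - 1.0 / total_d     # children sit one level lower
        for child in subtree[root].values():
            if isinstance(child, dict):
                walk(child, child_y)
            else:
                state['x_off'] += 1.0 / total_w   # move to the next leaf slot
                placed.append((child, state['x_off'], child_y))

    walk(tree, 1.0)
    return placed


tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
for label, x, y in layout_positions(tree):
    print(f'{label:>12}  x={x:.3f}  y={y:.3f}')
```

The key design point is that the leaf cursor is shared across the whole recursion rather than reset in every call. With that, the root of this sample tree lands at (0.5, 1.0) and the leaves spread across the full width instead of stacking under the root.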