Matplotlib注解(主要解决《机器学习实战》第三章绘树图部分问题,运行环境:python3)

    笔者最近在学习《机器学习实战》,对这本书的表示由衷的喜爱,原因如下:1.系统讲解机器学习方法,2.将机器学习中的方法讲得简单易懂,3.一步一步教会了笔者如何构建这些方法的程序。对此,笔者再次表示对本书的喜爱和其作者及其译者的感谢。当然,笔者在学习这本书并非一帆风顺,这不,卡在了第三章决策树的绘图部分好些天,趁周末,赶紧做一做。修改了一些地方,方将代码跑通,下面给出代码(运行环境python3):

# encdoing:utf-8
import matplotlib.pyplot as plt

# 获取叶子节点
def getNumLeafs(intree):
    numLeafs = 0
    a = intree.keys()
    firstStr = [each for each in a][0]
    secondDict = intree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# 获取树的层数
def getTreeDepth(intree):
    maxDepth,thisDepth = 0,0
    a = intree.keys()
    firstStr = [each for each in a][0]
    secondDict = intree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getNumLeafs(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

# 生成符合树结构的dict
def retrieveTree(i):
    listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},{'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
    return listOfTrees[i]

# 在父子节点间填充文本信息
def plotMidText(cntrpt,parentPt,txtString):
    xmid = (parentPt[0] - cntrpt[0])/2.0 + cntrpt[0]
    ymid = (parentPt[1] - cntrpt[1])/2.0 + cntrpt[1]
    creatPlot_ax1.text(xmid,ymid,txtString)

# 绘制树结构函数
def plotTree(intree,parentPt,nodeTxt,plotTree_yOff = 1.0):
    plotTree_totalW = float(getNumLeafs(intree))
    plotTree_totalD = float(getTreeDepth(intree))
    plotTree_xOff = -0.5 / plotTree_totalW
    numLeafs = getNumLeafs(intree)
    depth = getTreeDepth(intree)
    a = intree.keys()
    firstStr = [each for each in a][0]
    cntrPt = (plotTree_xOff + (1.0 + float(numLeafs))/2.0/plotTree_totalW,plotTree_yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = intree[firstStr]
    plotTree_yOff = plotTree_yOff - 1.0/plotTree_totalW
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key),plotTree_yOff)
        else:
            plotTree_xOff = plotTree_xOff + 1.0/plotTree_totalW
            plotNode(secondDict[key],(plotTree_xOff,plotTree_yOff),cntrPt,leafNode)
            plotMidText((plotTree_xOff,plotTree_yOff),cntrPt,str(key))
    plotTree_yOff = plotTree_yOff + 1.0/plotTree_totalD

# 绘图
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    creatPlot_ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',va='center',ha='center',bbox=dict(boxstyle='round4'),arrowprops = dict(arrowstyle = '<-'))

if __name__=='__main__':
    # 定义文本框
    decisionNode = dict(boxstyple='swatooth',fc='0.8')
    leafNode = dict(boxstyle='round4',fc=0.8)
    mytree = retrieveTree(1) # 取出符合决策树结构的数据,可自定义
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    creatPlot_ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree_yOff = 1.0 # 设置默认参数
    plotTree(mytree, (0.5, 1.0), '') # 调用绘制树结构图函数
    plt.show() # 图片展示

请需要的读者,将此代码粘贴到文件内(注意粘贴后的代码格式)。下面作几点说明

1)代码标注红色部分为修改部分;

2)将《机器学习实战》书中P47页的createplot函数部分,分别家在其他函数中去,此部分用绿色标记;

3)代码的详细解释见书本,若有疑问请留言。

此处附上代码运行后结果图:

                                

并附上代码的githup链接:

https://gitee.com/someone317/backpropagation_algorithm_test/blob/master/drawTree.py

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值