Machine Learning in Action: 3.2 Plotting Tree Diagrams in Python with Matplotlib Annotations

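Contents

1. Matplotlib Annotations

1.1 Drawing Tree Nodes with Text Annotations

2. Constructing the Annotated Tree Plot

2.1 Getting the Number of Leaf Nodes and the Depth of the Tree

2.2 The plottree Function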


1. Matplotlib Annotations

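Matplotlib provides a very useful annotation tool, annotations (the annotate method), which adds text annotations to a plot. Below we use this annotation feature to draw a tree diagram. An annotation can color its text and draw it inside boxes of many different shapes, and we can also reverse the arrow so that it points at the text box instead of at the data point.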
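A minimal sketch of the arrow-direction point above, assuming Matplotlib 3.x and the non-interactive Agg backend (so it runs without a display). It draws the same annotation twice, once with arrowstyle "->" and once with "<-"; the coordinates and labels are made up for illustration.

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(6, 3))
annotations = []
for ax, style in zip(axes, ["->", "<-"]):
    ann = ax.annotate(
        f"arrowstyle={style!r}",
        xy=(0.2, 0.2), xycoords="axes fraction",        # the annotated point
        xytext=(0.7, 0.8), textcoords="axes fraction",  # where the text box sits
        va="center", ha="center",
        bbox=dict(boxstyle="round4", fc="0.8"),         # same box style as the leaf nodes
        # "->" draws the arrowhead at xy; "<-" draws it at the text box instead
        arrowprops=dict(arrowstyle=style),
    )
    annotations.append(ann)

buf = io.BytesIO()
fig.savefig(buf, format="png")  # render in memory instead of calling plt.show()
plt.close(fig)
png_bytes = buf.getvalue()
```

With "->" the arrowhead lands on the annotated point; with "<-" it lands on the text box. That is why the tree code uses "<-": the arrow starts at the parent node (xy) and points at the child's box (xytext).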

1.1 Drawing Tree Nodes with Text Annotations

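import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['font.sans-serif'] = ['SimHei']  # CJK-capable font; only needed if node labels contain Chinese text

# Define the text-box and arrow formats
decisionnode = dict(boxstyle="sawtooth", fc="0.8")  # box for decision (internal) nodes
leafnode = dict(boxstyle="round4", fc="0.8")        # box for leaf nodes
arrow_args = dict(arrowstyle="<-")                  # "<-" puts the arrowhead at the text box

# Draw one node as an annotation with an arrow coming from its parent
def plotnode(nodetxt, centerpt, parentpt, nodetype):
    createplot.ax1.annotate(nodetxt,
                           xy=parentpt,               # the arrow starts at the parent node
                           xycoords='axes fraction',
                           xytext=centerpt,           # this node's text box
                           textcoords='axes fraction',
                           va='center',
                           ha='center',
                           bbox=nodetype,
                           arrowprops=arrow_args)

def createplot():
    fig = plt.figure(1, facecolor='white')  # create a new figure
    fig.clf()  # clear the drawing area
    # store the axes as a function attribute so that plotnode can reach it
    createplot.ax1 = plt.subplot(111, frameon=False)
    plotnode('decision node', (0.5, 0.1), (0.1, 0.5), decisionnode)
    plotnode('leaf node', (0.8, 0.1), (0.3, 0.8), leafnode)
    plt.show()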

View the tree-node plot:

createplot()

Output: [figure: the decision node and the leaf node drawn as annotated text boxes]
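The two node shapes used above, sawtooth and round4, are two of Matplotlib's built-in box styles, and "<-" is one of its built-in arrow styles. A small sketch, assuming Matplotlib 3.x, for listing the names that can be passed as boxstyle and arrowstyle:

```python
from matplotlib.patches import ArrowStyle, BoxStyle

# get_styles() returns a dict mapping style names to style classes
box_names = sorted(BoxStyle.get_styles())
arrow_names = sorted(ArrowStyle.get_styles())

print("box styles:  ", box_names)
print("arrow styles:", arrow_names)
```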

2. Constructing the Annotated Tree Plot

2.1 Getting the Number of Leaf Nodes and the Depth of the Tree

def getnumleafs(mytree):
    numleafs = 0
    firststr = list(mytree.keys())[0]  # root feature: the only key of the dict
    seconddict = mytree[firststr]  # {feature value: subtree or class label}
    for key in seconddict.keys():
        if isinstance(seconddict[key], dict):  # subtree: count its leaves recursively
            numleafs += getnumleafs(seconddict[key])
        else: numleafs += 1  # a class label is one leaf
    return numleafs
def gettreedepth(mytree):
    maxdepth = 0
    firststr = list(mytree.keys())[0]
    seconddict = mytree[firststr]
    for key in seconddict.keys():
        if isinstance(seconddict[key], dict):  # subtree: one level plus its own depth
            thisdepth = 1 + gettreedepth(seconddict[key])
        else: thisdepth = 1  # a leaf adds one level
        if thisdepth > maxdepth:
            maxdepth = thisdepth
    return maxdepth

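Note: in the book, lines 3 and 12 of the code above read firststr = mytree.keys()[0]. That is Python 2 code. In Python 2, myTree.keys() returns a list, so indexing it with [0] works; in Python 3, myTree.keys() returns a dict_keys view object, which does not support indexing, so the same line raises TypeError: 'dict_keys' object is not subscriptable.

Solution: change mytree.keys()[0] to list(mytree.keys())[0].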
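A quick check of the difference described above, run under Python 3. The tree is the first stored tree used in this article; next(iter(...)) is shown as an extra alternative that is not from the book. Since Python 3.7, dicts preserve insertion order, so "the first key" is well defined.

```python
tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

keys = tree.keys()
print(type(keys).__name__)  # dict_keys

try:
    keys[0]  # Python 2 style indexing
    indexing_works = True
except TypeError as exc:
    indexing_works = False
    print("TypeError:", exc)

# Three equivalent ways to get the first (here: the only) key in Python 3
first_a = list(tree.keys())[0]  # the fix used in this article
first_b = list(tree)[0]         # iterating a dict already yields its keys
first_c = next(iter(tree))      # no intermediate list is built
print(first_a, first_b, first_c)
```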

To save time, the function retiretree returns pre-stored trees, which spares us from building a tree from data every time we test the code:

def retiretree(i):
    # two hand-built trees in the nested-dict format {feature: {feature value: subtree or class label}}
    listoftrees = [{'no surfacing': {0 : 'no', 1 : {'flippers' : {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0 : 'no', 1 : {'flippers' : {0:{'head':{0:'no', 1:'yes'}}, 1 : 'no'}}}}]
    return listoftrees[i]

View the tree structure, the number of leaf nodes and the tree depth:

print("retiretree(1):", retiretree(1))
print("retiretree(0):", retiretree(0))
mytree = retiretree(0)
print("Number of leaf nodes:", getnumleafs(mytree))
print("Tree depth:", gettreedepth(mytree))

Output: [image of the printed trees, leaf count and depth]
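As a cross-check of the two recursive functions, here is a compact sketch of the same counting rules applied to both stored trees. The names count_leaves and tree_depth are hypothetical, and the generator-expression style is a choice of this sketch rather than the book's code: a class label counts as one leaf, and each nested dict adds one level of depth.

```python
def count_leaves(tree):
    """Number of leaves: every non-dict value below the root is one leaf."""
    branches = tree[next(iter(tree))]  # {feature value: subtree or class label}
    return sum(count_leaves(v) if isinstance(v, dict) else 1
               for v in branches.values())


def tree_depth(tree):
    """Number of decision levels on the longest root-to-leaf path."""
    branches = tree[next(iter(tree))]
    return max(1 + tree_depth(v) if isinstance(v, dict) else 1
               for v in branches.values())


tree0 = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
tree1 = {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}

for name, t in [("tree0", tree0), ("tree1", tree1)]:
    print(name, "leaves:", count_leaves(t), "depth:", tree_depth(t))
```

For the first tree this gives 3 leaves and depth 2; for the second, 4 leaves and depth 3. Adding a root branch 3: 'maybe' to the first tree, as done at the end of this article, raises its leaf count to 4 while the depth stays 2.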

2.2 The plottree Function

def plotmidtext(cntrpt, parentpt, txtstring):
    # label the edge at the midpoint between the child (cntrpt) and its parent
    xmid = (parentpt[0]-cntrpt[0])/2.0 + cntrpt[0]
    ymid = (parentpt[1]-cntrpt[1])/2.0 + cntrpt[1]
    createplot.ax1.text(xmid, ymid, txtstring, va="center", ha="center", rotation=30)

def plottree(mytree, parentpt, nodetxt):
    numleafs = getnumleafs(mytree)  # the width of this subtree decides where its root goes
    firststr = list(mytree.keys())[0]
    # centre the decision node horizontally over the leaves it spans
    cntrpt = (plottree.xoff + (1.0 + float(numleafs))/2.0/plottree.totalw, plottree.yoff)
    plotmidtext(cntrpt, parentpt, nodetxt)  # edge label: the feature value leading here
    plotnode(firststr, cntrpt, parentpt, decisionnode)
    seconddict = mytree[firststr]
    plottree.yoff = plottree.yoff - 1.0/plottree.totald  # children are drawn one level lower
    for key in seconddict.keys():
        if isinstance(seconddict[key], dict):
            plottree(seconddict[key], cntrpt, str(key))  # subtree: recurse
        else:
            plottree.xoff = plottree.xoff + 1.0/plottree.totalw  # leaf: advance to the next slot
            plotnode(seconddict[key], (plottree.xoff, plottree.yoff), cntrpt, leafnode)
            plotmidtext((plottree.xoff, plottree.yoff), cntrpt, str(key))
    plottree.yoff = plottree.yoff + 1.0/plottree.totald  # move back up once the children are drawn

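def createplot(intree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])  # hide the axis ticks
    createplot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # The next two values are used to place the nodes so the tree is centred horizontally and vertically
    plottree.totalw = float(getnumleafs(intree))  # width of the tree (number of leaves)
    plottree.totald = float(gettreedepth(intree))  # height of the tree (depth)
    # xoff and yoff track the position of the last drawn node and where the next node should go
    plottree.xoff = -0.5/plottree.totalw
    plottree.yoff = 1.0
    plottree(intree, (0.5, 1.0), '')  # the root is drawn at the top centre
    plt.show()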
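The placement arithmetic in plottree and createplot can be checked without drawing anything. This sketch uses hypothetical names (count_leaves, tree_depth, layout), and a plain dict stands in for the function attributes xoff and yoff; it repeats the same steps and returns each node's label and (x, y) in axes-fraction coordinates. The reasoning: the axes width is split into totalw equal slots, and starting xoff at -0.5/totalw makes the first leaf land at 0.5/totalw, the centre of the first slot. A decision node spanning n leaves is placed at xoff + (1 + n)/(2 * totalw), the midpoint of its first leaf (xoff + 1/totalw) and its last leaf (xoff + n/totalw). Each level sits 1/totald below its parent.

```python
def count_leaves(tree):
    branches = tree[next(iter(tree))]
    return sum(count_leaves(v) if isinstance(v, dict) else 1 for v in branches.values())


def tree_depth(tree):
    branches = tree[next(iter(tree))]
    return max(1 + tree_depth(v) if isinstance(v, dict) else 1 for v in branches.values())


def layout(tree):
    """Return [(label, (x, y)), ...] in drawing order, using plottree's arithmetic."""
    total_w = float(count_leaves(tree))
    total_d = float(tree_depth(tree))
    state = {"xoff": -0.5 / total_w, "yoff": 1.0}  # same start values as createplot
    placed = []

    def place(subtree):
        feature = next(iter(subtree))
        n = count_leaves(subtree)
        # decision node: midway between its first leaf (xoff + 1/W) and last leaf (xoff + n/W)
        placed.append((feature, (state["xoff"] + (1.0 + n) / 2.0 / total_w, state["yoff"])))
        state["yoff"] -= 1.0 / total_d  # children sit one level lower
        for value in subtree[feature].values():
            if isinstance(value, dict):
                place(value)
            else:
                state["xoff"] += 1.0 / total_w  # each leaf takes the next slot
                placed.append((value, (state["xoff"], state["yoff"])))
        state["yoff"] += 1.0 / total_d  # back up after the children

    place(tree)
    return placed


tree0 = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
for label, (x, y) in layout(tree0):
    print(f"{label:>12}  x={x:.3f}  y={y:.3f}")
```

For the first stored tree this puts the root at x = 0.5, the three leaves at x = 1/6, 1/2 and 5/6, and the flippers node at x = 2/3, midway between its two leaves. The deepest leaves sit at y = 0, on the bottom edge of the axes.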

View the tree:

mytree = retiretree(0)
createplot(mytree)

Output: [figure: tree plot of retiretree(0)]

Modify the dictionary and redraw the tree:

mytree['no surfacing'][3] = 'maybe'  # add a root branch for feature value 3 ending in the leaf 'maybe'
createplot(mytree)

Output: [figure: tree plot with the added 'maybe' branch]
