Machine Learning in Action: Plotting a Decision Tree

The code in Machine Learning in Action (Peter Harrington) has a few small problems, so I rewrote all of the code in the book and am sharing it here.


Block Ⅰ

import matplotlib.pyplot as plt #for plotting
import matplotlib #for rcParams, used to set the style of the plot window

Block Ⅱ

Define global variables (a minimal demo of how they are used follows the list)

decisionNode #box style for decision nodes
leafNode #box style for leaf nodes
arrow_args #arrow style
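
As a quick, standalone illustration of my own (not part of the treePlotter module), the three style dicts above are consumed by matplotlib's annotate: the node style goes to the bbox keyword and the arrow style to arrowprops.

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

ax = plt.subplot(111, frameon=False, xticks=[], yticks=[])
# a decision node at (0.5, 0.8); the arrow head sits on the node, coming from (0.5, 1.0)
ax.annotate("no surfacing", xy=(0.5, 1.0), xycoords='axes fraction',
            xytext=(0.5, 0.8), textcoords='axes fraction',
            va="center", ha="center", bbox=decisionNode, arrowprops=arrow_args)
# a leaf node at (0.2, 0.5), connected to the decision node above
ax.annotate("no", xy=(0.5, 0.8), xycoords='axes fraction',
            xytext=(0.2, 0.5), textcoords='axes fraction',
            va="center", ha="center", bbox=leafNode, arrowprops=arrow_args)
plt.show()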


Block Ⅲ

Define functions

retrieveTree #manually build trees of different sizes, to simplify the problem and to test the other functions
getNumLeafs #get the number of leaf nodes of the tree
getTreeDepth #get the depth of the tree, i.e. the number of decisionNodes along the deepest branch
plotNode #plot a node; the nodeType parameter distinguishes decisionNode from leafNode
plotMidText #annotate each dict key (the feature value) at the midpoint of the arrow
plotTree #plot the decision tree recursively

Block Ⅳ

Test code

Executed when __name__ == '__main__', i.e. when the module is run as the main program


Resulting figure:



Key ideas:

1. The whole tree is generated recursively; when testing the code, start from a tree with only one decisionNode (obtained by calling retrieveTree(0)).

2. The plot axis is passed in as a parameter so that everything is drawn on the same axes. Peter achieves the same effect by attaching an axis attribute to the plotting function at call time, which is a bit more convoluted.

3. The whole figure is drawn inside the rectangle bounded by (0,0) and (1,1). The first decisionNode is centered at (0.5, 1), and the vertical and horizontal spacing are controlled by the number of decisionNodes and leafNodes respectively; both spacings are fixed values (see the worked example after this list). I don't know whether this matches Peter's approach, because his code made my head spin, so I didn't read it and rewrote everything from scratch.

4. plotTree first plots the decisionNode and the arrow between it and its parentNode; in the special case where the node's center coordinates equal the parent's coordinates (the root), matplotlib simply does not draw the arrow.
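
To make the spacing concrete, here is a small worked example of my own (not taken from the book) for the deepest predefined tree, retrieveTree(2), which has 4 leaf nodes and depth 3; it assumes the functions defined in the treePlotter listing below.

dictTree = retrieveTree(2)             # 4 leaf nodes, depth 3
fBrchLen = 1 / getNumLeafs(dictTree)   # horizontal spacing: 1/4 = 0.25
fTrunkLen = 1 / getTreeDepth(dictTree) # vertical spacing:   1/3 ≈ 0.33
tplCntrPt = (0.5, 1)                   # root decisionNode centered at the top of the unit square
# plotTree then places the root's two branches at
#   (0.5 + (0 - 0.5) * 0.25, 1 - 1/3) = (0.375, 0.667)
#   (0.5 + (1 - 0.5) * 0.25, 1 - 1/3) = (0.625, 0.667)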


The treePlotter source code is as follows:

# -*- coding: utf-8 -*-
"""
treePlotter.py
~~~~~~~~~~

A module with functions to plot decision tree.

Created on Thu Mar 23 17:26:57 2017

Run on Python 3.6

@author: Luo Shaozhuo

refer to 'MachineLearninginAction'

"""
#==============================================================================
# import
#==============================================================================
import matplotlib.pyplot as plt
import matplotlib


#==============================================================================
# Global variables
#==============================================================================
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-") 


#==============================================================================
# functions
#==============================================================================
def retrieveTree(i=0):
    """
    return a predefined tree
    ~~~~~~~~~~
    i: must be 0, 1 or 2; larger values return a taller tree
    ~~~~~~~~~~
    dictTree
    """
    listOfTrees =[{'no surfacing': {0: 'no', 1: 'yes'}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]


def getNumLeafs(dictTree):
    """
    return the number of leaf nodes
    ~~~~~~~~~~
    dictTree: a dictionary depicting a decision tree
    ~~~~~~~~~~
    nNumLeaf: number of leaf nodes
    """
    nNumLeaf = 0
    for key in dictTree.keys():
        if type(dictTree[key]) == dict:
            nNumLeaf += getNumLeafs(dictTree[key])
        else:
            nNumLeaf += 1
    return nNumLeaf


def getTreeDepth(dictTree):
    """
    return the tree depth
    ~~~~~~~~~~
    dictTree: a dictionary depicting a decision tree
    ~~~~~~~~~~
    nMaxDepth: tree depth
    """
    nMaxDepth = 0
    strRootKey = list(dictTree.keys())[0]
    dictTrunk = dictTree[strRootKey]
    for key in dictTrunk.keys():
        if type(dictTrunk[key]) == dict:
            nCurDepth = 1 + getTreeDepth(dictTrunk[key])
        else:
            nCurDepth = 1
        if nCurDepth > nMaxDepth:
            nMaxDepth = nCurDepth
    return nMaxDepth


def plotNode(pltAxis, strNodeTxt, tplCntrPt, tplPrntPt, nodeType):
    """
    plot a decision node or a leaf node depending on nodeType.
    ~~~~~~~~~~
    pltAxis: plot axis
    strNodeTxt: text in node box
    tplCntrPt: center coordinates of box
    tplPrntPt: starting coordinates of arrow
    nodeType: leafNode or decisionNode
    ~~~~~~~~~~
    N/A
    """
    pltAxis.annotate(strNodeTxt, xy=tplPrntPt, xycoords='axes fraction',
                     xytext=tplCntrPt, textcoords='axes fraction',
                     va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def plotMidText(pltAxis, cntrPt, parentPt, txtString):
    """
    add feature value in the middle of arrow
    ~~~~~~~~~~
    cntrPt: center coordinates of the child node
    parentPt: coordinates of the parent node
    txtString: feature value to print at the midpoint of the arrow
    ~~~~~~~~~~
    N/A
    """
    xMid = (parentPt[0]+cntrPt[0])/2.0
    yMid = (parentPt[1]+cntrPt[1])/2.0
    pltAxis.text(xMid, yMid, txtString)


def plotTree(dictTree, pltAxis, fTrunkLen, fBrchLen, tplCntrPt, tplPrntPt, strNodeTxt):
    """
    plot the tree recursively
    ~~~~~~~~~~
    dictTree: decision tree (the branches below the current node)
    pltAxis: axis used for plotting
    fTrunkLen: difference of y coordinates between two levels of nodes
    fBrchLen: difference of x coordinates between two sibling nodes
    tplCntrPt: center coordinates of the current node
    tplPrntPt: coordinates of the parent node
    strNodeTxt: text in node box
    ~~~~~~~~~~
    N/A
    """
    # plot the current decisionNode and the arrow from its parent (no arrow for the root)
    plotNode(pltAxis, strNodeTxt, tplCntrPt, tplPrntPt, decisionNode)
    # plot its branches
    tplPrntPt = tplCntrPt
    nNumKey = len(dictTree.keys())
    fMean = (nNumKey - 1) / 2  # offset that centers the branches under the parent
    for i,key in enumerate(dictTree.keys()):
        tplCntrPt = (tplPrntPt[0]+(i-fMean)*fBrchLen, tplPrntPt[1]-fTrunkLen)
        plotMidText(pltAxis, tplCntrPt, tplPrntPt, key)
        if type(dictTree[key]) == dict:
            strNodeTxt = list(dictTree[key].keys())[0]
            plotTree(dictTree[key][strNodeTxt], pltAxis, fTrunkLen, fBrchLen, tplCntrPt, tplPrntPt,strNodeTxt)
        else:
            strNodeTxt = dictTree[key]
            plotNode(pltAxis, strNodeTxt, tplCntrPt, tplPrntPt, leafNode)


if __name__ == '__main__':
    dictTree = retrieveTree(2)
    matplotlib.rcParams['toolbar'] = 'none'  # hide the toolbar of the plot window
    pltAxis = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    fBrchLen = 1/getNumLeafs(dictTree)    # horizontal spacing between sibling nodes
    fTrunkLen = 1/getTreeDepth(dictTree)  # vertical spacing between levels
    tplCntrPt = (0.5, 1)                  # root decisionNode centered at the top
    tplPrntPt = tplCntrPt                 # root has no parent, so no arrow is drawn for it
    strNodeTxt = list(dictTree.keys())[0]
    plotTree(dictTree[strNodeTxt], pltAxis, fTrunkLen, fBrchLen, tplCntrPt, tplPrntPt, strNodeTxt)
    plt.show()
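
Assuming the listing above is saved as treePlotter.py somewhere on the import path, a quick sanity check of the helper functions (my own example, not from the book) looks like this:

import treePlotter

tree = treePlotter.retrieveTree(2)
print(treePlotter.getNumLeafs(tree))   # 4  (leaf nodes: 'no', 'no', 'yes', 'no')
print(treePlotter.getTreeDepth(tree))  # 3  ('no surfacing' -> 'flippers' -> 'head')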


