机器学习之旅(二):决策树

机器学习之旅(二):决策树


决策树工作原理

  决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。由于这种决策分支画成图形很像一棵树的枝干,故称决策树。在机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系。Entropy = 系统的凌乱程度,使用算法ID3, C4.5和C5.0生成树算法使用熵。这一度量是基于信息学理论中熵的概念。
决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。
分类树(决策树)是一种十分常用的分类方法。他是一种监管学习,所谓监管学习就是给定一堆样本,每个样本都有一组属性和一个类别,这些类别是事先确定的,那么通过学习得到一个分类器,这个分类器能够对新出现的对象给出正确的分类。这样的机器学习就被称之为监督学习。

决策树的优缺点

  - 优点:计算复杂度不高,输出结果易于理解,对中间值得缺失不敏感,可以处理不相关特征数据。
  - 缺点:可能会产生过度匹配问题。
  - 使用数据范围:数值型和标称型。

决策树的一般流程

  (1)收集数据:可以使用任何方法。
  (2)准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。
  (3)分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
  (4)训练算法:构造树的数据结构。
  (5)测试算法:使用经验树计算错误率。
  (6)使用算法:此步骤可以适用于任何监督学习算法,二使用决策树可以更好地理解数据的内在含义。


编写代码计算经验熵

  在编写代码之前,我们先对数据集进行属性标注。

  - 年龄:0代表青年,1代表中年,2代表老年;
  - 有工作:0代表否,1代表是;
  - 有自己的房子:0代表否,1代表是;
  - 信贷情况:0代表一般,1代表好,2代表非常好;
  - 类别(是否给贷款):no代表否,yes代表是。

  确定这些之后,我们就可以创建数据集,并计算经验熵了,代码编写如下:

#!/usr/bin/env python
# -*- coding:utf-8 -*-

from math import log

"""
函数说明:创建测试数据集

Parameters:
    无
Returns:
    dataSet - 数据集
    labels - 分类属性
Author:
    Jack Cui
"""
def createDataSet():
    dataSet = [[0, 0, 0, 0, 'no'],         #数据集
            [0, 0, 0, 1, 'no'],
            [0, 1, 0, 1, 'yes'],
            [0, 1, 1, 0, 'yes'],
            [0, 0, 0, 0, 'no'],
            [1, 0, 0, 0, 'no'],
            [1, 0, 0, 1, 'no'],
            [1, 1, 1, 1, 'yes'],
            [1, 0, 1, 2, 'yes'],
            [1, 0, 1, 2, 'yes'],
            [2, 0, 1, 2, 'yes'],
            [2, 0, 1, 1, 'yes'],
            [2, 1, 0, 1, 'yes'],
            [2, 1, 0, 2, 'yes'],
            [2, 0, 0, 0, 'no']]
    labels = ['年龄', '有工作', '有自己的房子', '信贷情况']        #分类属性
    return dataSet, labels                #返回数据集和分类属性

"""
函数说明:计算给定数据集的经验熵(香农熵)

Parameters:
    dataSet - 数据集
Returns:
    shannonEnt - 经验熵(香农熵)
Author:
    Jack Cui
"""
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)               #返回数据集的行数
    labelCounts = {}                        #保存每个label出现次数的字典
    for featVec in dataSet:                 #对每组特征向量进行统计
        currentLabel = featVec[-1]          #提取label信息
        if currentLabel not in labelCounts.keys():      #如果label没有放入统计次数的字典,就添加进去
            labelCounts[currentLabel] = 0               #label初始值为0
        labelCounts[currentLabel] += 1      #label计数
    shannonEnt = 0.0                        #香农熵
    for key in labelCounts:                 #计算香农熵
        prob = float(labelCounts[key])/numEntries       #选择该label的概率
        shannonEnt -= prob * log(prob,2)    #利用公式计算
    return shannonEnt                       #返回香农熵

if __name__ == '__main__':
    dataSet, features = createDataSet()
    print(dataSet)
    print(calcShannonEnt(dataSet))

  程序运行结果得出如下图所示,经验熵等于0.9709505944546686

这里写图片描述

编写代码创建决策树

  接下来创建决策树,代码编写如下

#!/usr/bin/env python
# -*- coding:utf-8 -*-

from math import log
import operator

"""
函数说明:计算给定数据集的经验熵(香农熵)

Parameters:
    dataSet - 数据集
Returns:
    shannonEnt - 经验熵(香农熵)
Author:
    Jack Cui
"""
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)               #返回数据集的行数
    labelCounts = {}                        #保存每个label出现次数的字典
    for featVec in dataSet:                 #对每组特征向量进行统计
        currentLabel = featVec[-1]          #提取label信息
        if currentLabel not in labelCounts.keys():      #如果label没有放入统计次数的字典,就添加进去
            labelCounts[currentLabel] = 0               #label初始值为0
        labelCounts[currentLabel] += 1      #label计数
    shannonEnt = 0.0                        #香农熵
    for key in labelCounts:                 #计算香农熵
        prob = float(labelCounts[key])/numEntries       #选择该label的概率
        shannonEnt -= prob * log(prob,2)    #利用公式计算
    return shannonEnt                       #返回香农熵

"""
函数说明:创建测试数据集

Parameters:
    无
Returns:
    dataSet - 数据集
    labels - 分类属性
Author:
    Jack Cui
"""
def createDataSet():
    dataSet = [[0, 0, 0, 0, 'no'],         #数据集
            [0, 0, 0, 1, 'no'],
            [0, 1, 0, 1, 'yes'],
            [0, 1, 1, 0, 'yes'],
            [0, 0, 0, 0, 'no'],
            [1, 0, 0, 0, 'no'],
            [1, 0, 0, 1, 'no'],
            [1, 1, 1, 1, 'yes'],
            [1, 0, 1, 2, 'yes'],
            [1, 0, 1, 2, 'yes'],
            [2, 0, 1, 2, 'yes'],
            [2, 0, 1, 1, 'yes'],
            [2, 1, 0, 1, 'yes'],
            [2, 1, 0, 2, 'yes'],
            [2, 0, 0, 0, 'no']]
    labels = ['年龄', '有工作', '有自己的房子', '信贷情况']        #分类属性
    return dataSet, labels                #返回数据集和分类属性

"""
函数说明:按照给定特征划分数据集

Parameters:
    dataSet - 待划分的数据集
    axis - 划分数据集的特征
    value - 需要返回的特征的值
Returns:
    无
Author:
    Jack Cui
"""
def splitDataSet(dataSet,axis,value):
    retDataSet = []                 #创建返回的数据集列表
    for featVec in dataSet:         #遍历数据集
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]             #去掉axis特征
            reducedFeatVec.extend(featVec[axis+1:])     #将符合条件的添加到返回的数据集
            retDataSet.append(reducedFeatVec)
    return retDataSet               #返回划分后的数据集

"""
函数说明:选择最优特征

Parameters:
    dataSet - 数据集
Returns:
    bestFeature - 信息增益最大的(最优)特征的索引值
Author:
    Jack Cui
"""
def chooseBestFeatureToSolit(dataSet):
    numFeatures = len(dataSet[0]) - 1           #特征数量
    baseEntropy = calcShannonEnt(dataSet)       #计算数据集的香农熵
    bestInfoGain = 0.0                          #信息增益
    bestFeature = -1                            #最优特征的索引值
    for i in range(numFeatures):                #遍历所有特征
        featList = [example[i] for example in dataSet]      #获取dataSet的第i个所有特征
        uniqueVals = set(featList)      #创建set集合{},元素不可重复
        newEntropy = 0.0                #经验条件熵
        for value in uniqueVals:        #计算信息增益
            subDataSet = splitDataSet(dataSet,i,value)      #subDataSet划分后的子集
            prob = len(subDataSet) / float(len(dataSet))    #计算子集的概率
            newEntropy += prob * calcShannonEnt(subDataSet) #根据公式计算经验条件熵
        infoGain = baseEntropy - newEntropy                 #信息增益
        # print("第%d个特征的增益为%.3f"%(i,infoGain))        #打印每个特征的信息增益
        if (infoGain > bestInfoGain):                       #计算信息增益
            bestInfoGain = infoGain                         #更新信息增益,找到最大的信息增益
            bestFeature = i                                 #记录信息增益最大的特征的索引值
    return bestFeature                                      #返回信息增益最大的特征的索引值

"""
函数说明:统计classList中出现此处最多的元素(类标签)

Parameters:
    classList - 类标签列表
Returns:
    sortedClassCount[0][0] - 出现此处最多的元素(类标签)
Author:
    Jack Cui
"""
def majorityCnt(classList):
    classCount = {}
    for vote in classList:                  #统计classList中每个元素出现的次数
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    #根据字典的值降序排序
    return sortedClassCount                 #返回classList中出现次数最多的元素

"""
函数说明:创建决策树

Parameters:
    dataSet - 训练数据集
    labels - 分类属性标签
    featLabels - 存储选择的最优特征标签
Returns:
    myTree - 决策树
Author:
    Jack Cui
"""
def createTree(dataSet,labels,featLabels):
    classList = [example[-1] for example in dataSet]            #取分类标签(是否放贷:yes or no)
    if classList.count(classList[0]) == len(classList):         #如果类别完全相同则停止继续划分
        return classList[0]
    if len(dataSet[0]) == 1:                #遍历完所有特征时返回出现次数最多的类标签
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSolit(dataSet)                #选择最优特征
    bestFeatLabel = labels[bestFeat]        #最优特征标签
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel:{}}             #根据最优特征的标签生成树
    del(labels[bestFeat])                   #删除已经使用的特征标签
    featValues = [example[bestFeat] for example in dataSet]     #得到训练集中所有最优特征的属性值
    uniqueVals = set(featValues)            #去掉重复的属性值
    for value in uniqueVals:                #遍历特征,创建决策树
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),labels,featLabels)
    return myTree

"""
函数说明:使用决策树分类

Parameters:
    inputTree - 已经生成的决策树
    featLabels - 存储选择的最优特征标签
    testVec - 测试数据列表,顺序对应最优特征标签
Returns:
    classLabel - 分类结果
Author:
    Jack Cui
"""
def classify(inputTree,featLabels,testVec):
    firstStr = next(iter(inputTree))            #获取决策树节点
    secondDict = inputTree[firstStr]            #下一个字典
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key],featLabels,testVec)
            else:
                classLabel = secondDict[key]
    return classLabel


if __name__ == '__main__':
    dataSet, labels = createDataSet()
    featLabels = []
    myTree = createTree(dataSet, labels, featLabels)
    print(myTree)
    testVec = [0,1]                                        #测试数据
    result = classify(myTree, featLabels, testVec)
    if result == 'yes':
        print('放贷')
    if result == 'no':
        print('不放贷')

  递归创建决策树时,递归有两个终止条件:第一个停止条件是所有的类标签完全相同,则直接返回该类标签;第二个停止条件是使用完了所有特征,仍然不能将数据划分仅包含唯一类别的分组,即决策树构建失败,特征不够用。此时说明数据纬度不够,由于第二个停止条件无法简单地返回唯一的类标签,这里挑选出现数量最多的类别作为返回值。
  
这里写图片描述
  从结果可以看到生成的决策树为:
  {‘有自己的房子’: {0: {‘有工作’: {0: ‘no’, 1: ‘yes’}}, 1: ‘yes’}}
  上述代码中还增加了一个分类功能classify函数,用于决策树分类。输入测试数据[0,1],它代表没有房子,但是有工作,得出结果为放贷,结果符合预期。
  

编写代码存储决策树

  构造出来的决策树我们需要通过一种方式将其存储起来,方便后续调用已经构造好的决策树,这里我们选择pickle方式来存储和载入决策树和分类节点,代码如下

#!/usr/bin/env python
# -*- coding:utf-8 -*-

from math import log
import operator
import pickle

"""
函数说明:计算给定数据集的经验熵(香农熵)

Parameters:
    dataSet - 数据集
Returns:
    shannonEnt - 经验熵(香农熵)
Author:
    Jack Cui
"""
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)               #返回数据集的行数
    labelCounts = {}                        #保存每个label出现次数的字典
    for featVec in dataSet:                 #对每组特征向量进行统计
        currentLabel = featVec[-1]          #提取label信息
        if currentLabel not in labelCounts.keys():      #如果label没有放入统计次数的字典,就添加进去
            labelCounts[currentLabel] = 0               #label初始值为0
        labelCounts[currentLabel] += 1      #label计数
    shannonEnt = 0.0                        #香农熵
    for key in labelCounts:                 #计算香农熵
        prob = float(labelCounts[key])/numEntries       #选择该label的概率
        shannonEnt -= prob * log(prob,2)    #利用公式计算
    return shannonEnt                       #返回香农熵

"""
函数说明:创建测试数据集

Parameters:
    无
Returns:
    dataSet - 数据集
    labels - 分类属性
Author:
    Jack Cui
"""
def createDataSet():
    dataSet = [[0, 0, 0, 0, 'no'],         #数据集
            [0, 0, 0, 1, 'no'],
            [0, 1, 0, 1, 'yes'],
            [0, 1, 1, 0, 'yes'],
            [0, 0, 0, 0, 'no'],
            [1, 0, 0, 0, 'no'],
            [1, 0, 0, 1, 'no'],
            [1, 1, 1, 1, 'yes'],
            [1, 0, 1, 2, 'yes'],
            [1, 0, 1, 2, 'yes'],
            [2, 0, 1, 2, 'yes'],
            [2, 0, 1, 1, 'yes'],
            [2, 1, 0, 1, 'yes'],
            [2, 1, 0, 2, 'yes'],
            [2, 0, 0, 0, 'no']]
    labels = ['年龄', '有工作', '有自己的房子', '信贷情况']        #分类属性
    return dataSet, labels                #返回数据集和分类属性

"""
函数说明:按照给定特征划分数据集

Parameters:
    dataSet - 待划分的数据集
    axis - 划分数据集的特征
    value - 需要返回的特征的值
Returns:
    无
Author:
    Jack Cui
"""
def splitDataSet(dataSet,axis,value):
    retDataSet = []                 #创建返回的数据集列表
    for featVec in dataSet:         #遍历数据集
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]             #去掉axis特征
            reducedFeatVec.extend(featVec[axis+1:])     #将符合条件的添加到返回的数据集
            retDataSet.append(reducedFeatVec)
    return retDataSet               #返回划分后的数据集

"""
函数说明:选择最优特征

Parameters:
    dataSet - 数据集
Returns:
    bestFeature - 信息增益最大的(最优)特征的索引值
Author:
    Jack Cui
"""
def chooseBestFeatureToSolit(dataSet):
    numFeatures = len(dataSet[0]) - 1           #特征数量
    baseEntropy = calcShannonEnt(dataSet)       #计算数据集的香农熵
    bestInfoGain = 0.0                          #信息增益
    bestFeature = -1                            #最优特征的索引值
    for i in range(numFeatures):                #遍历所有特征
        featList = [example[i] for example in dataSet]      #获取dataSet的第i个所有特征
        uniqueVals = set(featList)      #创建set集合{},元素不可重复
        newEntropy = 0.0                #经验条件熵
        for value in uniqueVals:        #计算信息增益
            subDataSet = splitDataSet(dataSet,i,value)      #subDataSet划分后的子集
            prob = len(subDataSet) / float(len(dataSet))    #计算子集的概率
            newEntropy += prob * calcShannonEnt(subDataSet) #根据公式计算经验条件熵
        infoGain = baseEntropy - newEntropy                 #信息增益
        # print("第%d个特征的增益为%.3f"%(i,infoGain))        #打印每个特征的信息增益
        if (infoGain > bestInfoGain):                       #计算信息增益
            bestInfoGain = infoGain                         #更新信息增益,找到最大的信息增益
            bestFeature = i                                 #记录信息增益最大的特征的索引值
    return bestFeature                                      #返回信息增益最大的特征的索引值

"""
函数说明:统计classList中出现此处最多的元素(类标签)

Parameters:
    classList - 类标签列表
Returns:
    sortedClassCount[0][0] - 出现此处最多的元素(类标签)
Author:
    Jack Cui
"""
def majorityCnt(classList):
    classCount = {}
    for vote in classList:                  #统计classList中每个元素出现的次数
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    #根据字典的值降序排序
    return sortedClassCount                 #返回classList中出现次数最多的元素

"""
函数说明:创建决策树

Parameters:
    dataSet - 训练数据集
    labels - 分类属性标签
    featLabels - 存储选择的最优特征标签
Returns:
    myTree - 决策树
Author:
    Jack Cui
"""
def createTree(dataSet,labels,featLabels):
    classList = [example[-1] for example in dataSet]            #取分类标签(是否放贷:yes or no)
    if classList.count(classList[0]) == len(classList):         #如果类别完全相同则停止继续划分
        return classList[0]
    if len(dataSet[0]) == 1:                #遍历完所有特征时返回出现次数最多的类标签
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSolit(dataSet)                #选择最优特征
    bestFeatLabel = labels[bestFeat]        #最优特征标签
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel:{}}             #根据最优特征的标签生成树
    del(labels[bestFeat])                   #删除已经使用的特征标签
    featValues = [example[bestFeat] for example in dataSet]     #得到训练集中所有最优特征的属性值
    uniqueVals = set(featValues)            #去掉重复的属性值
    for value in uniqueVals:                #遍历特征,创建决策树
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),labels,featLabels)
    return myTree

"""
函数说明:使用决策树分类

Parameters:
    inputTree - 已经生成的决策树
    featLabels - 存储选择的最优特征标签
    testVec - 测试数据列表,顺序对应最优特征标签
Returns:
    classLabel - 分类结果
Author:
    Jack Cui
"""
def classify(inputTree,featLabels,testVec):
    firstStr = next(iter(inputTree))            #获取决策树节点
    secondDict = inputTree[firstStr]            #下一个字典
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key],featLabels,testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

"""
函数说明:存储决策树

Parameters:
    inputTree - 已经生成的决策树
    filename - 决策树的存储文件名
Returns:
    无
Author:
    Jack Cui
"""
def storeTree(inputTree,filename):
    with open(filename,'wb') as fw:
        pickle.dump(inputTree,fw)

"""
函数说明:读取决策树

Parameters:
    filename - 决策树的存储文件名
Returns:
    pickle.load(fr) - 决策树字典
Author:
    Jack Cui
"""
def grabTree(filename):
    fr = open(filename,'rb')
    return pickle.load(fr)

if __name__ == '__main__':
    dataSet, labels = createDataSet()
    featLabels = []
    myTree = createTree(dataSet,labels,featLabels)
    storeTree(myTree,'classifierStorage.txt')
    storeTree(featLabels,'featLabels.txt')

  运行上述代码后可以在当前Python文件相同目录下找到两个名为classifierStorage.txtfeatLabels.txt的txt文件。然后下次需要调用决策树的时候使用pickle.load方法载入即可,参见下面例子

#仅修改main函数
if __name__ == '__main__':
    myTree = grabTree('classifierStorage.txt')
    featLabels = grabTree('featLabels.txt')
    print(myTree)
    print(featLabels)
    testVec = [0, 1]
    result = classify(myTree, featLabels, testVec)
    if result == 'yes':
        print('放贷')
    if result == 'no':
        print('不放贷')

  运行结果如下图
这里写图片描述
  从上述结果中,我们可以看到,我们顺利加载了存储决策树和分类节点的二进制文件并顺利预测结果。
  

编写代码创建树形图

  决策树的主要优点就是直观易于理解,如果不能将其直观地显示出来,就无法发挥其优势。在这里我们将使用Matplotlib库来创建树形图。首先创建一个新的Python文件,编写以下代码

#!/usr/bin/env python
# -*- coding:utf-8 -*-

import trees
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt

"""
函数说明:获取决策树叶子结点的数目

Parameters:
    myTree - 决策树
Returns:
    numLeafs - 决策树的叶子结点的数目
Author:
    Jack Cui
"""
def getNumLeafs(myTree):
    numLeafs = 0                        #初始化叶子
    firstStr = next(iter(myTree))
    #python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,
    #可以使用list(myTree.keys())[0]
    secondDict = myTree[firstStr]       #获取下一组字典
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':    #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

"""
函数说明:获取决策树的层数

Parameters:
    myTree - 决策树
Returns:
    maxDepth - 决策树的层数
Author:
    Jack Cui
"""
def getTreeDepth(myTree):
    maxDepth = 0                        #初始化决策树深度
    firstStr = next(iter(myTree))
    #python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,
    #可以使用list(myTree.keys())[0]
    secondDict = myTree[firstStr]       #获取下一个字典
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':    #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth        #更新层数
    return maxDepth

"""
函数说明:绘制结点

Parameters:
    nodeTxt - 结点名
    centerPt - 文本位置
    parentPt - 标注的箭头位置
    nodeType - 结点格式
Returns:
    无
Author:
    Jack Cui
"""
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    arrow_args = dict(arrowstyle="<-")                #定义箭头格式
    font = FontProperties(fname=r"C:\windows\fonts\simsun.ttc", size=14)             #设置中文字体
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt,     #绘制结点
                            textcoords='axes fraction', va="center", ha="center",
                            bbox=nodeType, arrowprops=arrow_args, FontProperties=font)

"""
函数说明:标注有向边属性值

Parameters:
    cntrPt、parentPt - 用于计算标注位置
    txtString - 标注的内容
Returns:
    无
Author:
    Jack Cui
"""
def plotMidText(cntrPt,parentPt,txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]              #计算标注位置
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

"""
函数说明:绘制决策树

Parameters:
    myTree - 决策树(字典)
    parentPt - 标注的内容
    nodeTxt - 结点名
Returns:
    无
Author:
    Jack Cui
"""
def plotTree(myTree,parentPt,nodeTxt):
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")                  #设置结点格式
    leafNode = dict(boxstyle="round4", fc="0.8")                        #设置叶结点格式
    numLeafs = getNumLeafs(myTree)                                      #获取决策树叶结点数目,决定数的宽度
    depth = getTreeDepth(myTree)                                        #获取决策树层数
    firstStr = next(iter(myTree))                                       #下个字典
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)   #中心位置
    plotMidText(cntrPt, parentPt, nodeTxt)                              #标注有向边属性值
    plotNode(firstStr, cntrPt, parentPt, decisionNode)                  #绘制结点
    secondDict = myTree[firstStr]                                       #下一个字典,也就是继续绘制子结点
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD                 #y偏移
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':                    #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
            plotTree(secondDict[key], cntrPt, str(key))                 #不是叶结点,递归调用继续绘制
        else:                                                           #如果是叶结点,绘制叶结点,并标注有向边属性
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

"""
函数说明:创建绘制面板

Parameters:
    inTree - 决策树(字典)
Returns:
    无
Author:
    Jack Cui
"""
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')                          #创建fig
    fig.clf()                                                       #清空fig
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)     #去掉x、y轴
    plotTree.totalW = float(getNumLeafs(inTree))                    #获取决策树叶结点数目
    plotTree.totalD = float(getTreeDepth(inTree))                   #获取决策树层数
    plotTree.xOff = -0.5/plotTree.totalW                            #x偏移
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')                                #绘制决策树
    plt.show()                                                      #显示绘制结果

if __name__ == '__main__':
    dataSet, labels = trees.createDataSet()
    featLabels = []
    myTree = trees.createTree(dataSet, labels, featLabels)
    print(myTree)
    createPlot(myTree)

  输出结果如下图
  
这里写图片描述  

  到这里为止,我们已经学习了如果构造决策树以及绘制树形图的方法。
  
感谢代码原作者:Jack-Cui
原文链接:http://blog.csdn.net/c406495762/article/details/75663451

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值