【机器学习实战学习笔记】之 3 决策树01(含Matplotlib模块介绍)

目录

一、决策树简介

1、引入

1.二十个问题的游戏

2、决策树

1.工作原理

2.算法特点

二、决策树的构造

1、步骤

2、伪代码函数createBranch()

3、决策树的一般流程

4、算法分析

5、信息增益

6、划分数据集

7、递归构建决策树

三、使用Matplotlib注解绘制树形图

1、决策树

2、Matplotlib注解

3、构造注解树


本学习笔记参考书目《机器学习实战》第三章。

本章所有本书对应代码及数据集下载请点击(下载链接)。

本文中博主自己写的代码,如有需要,请点击(GitHub页面)。

由于本章内容较多,分两部分分享,第一部分包括决策树的概念、构造和绘制。

一、决策树简介

1、引入

1.二十个问题的游戏

A思考某个事物,其他人对A提20个问题,A回答对或错。根据回答结果缩小猜测物体的范围。这个游戏原理与决策树原理类似。用户输入一系列数据,然后给出游戏答案。

2、决策树

1.工作原理

矩形代表判断模块;椭圆代表终止模块,表示得出结论可以终止运行;从判断模块引出的左右箭头是分支。

以下图为例,通过不断地提问及回答问题,不断接近最终的结果。

流程图形式的决策树

 

2.算法特点

1.K-近邻算法缺点:无法给出数据的内在含义。

2.决策树的优点:

(1)数据形式非常容易理解。

(2)决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,机器学习算法最终将使用这些机器从数据集中创造的规则。

(3)计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。

(4)决策树给出结果往往可以匹敌在当前领域具有几十年工作经验的人类专家。

3.决策树的缺点:可能会产生过度匹配问题。

4.适用数据类型:数值型与标称型。

二、决策树的构造

1、步骤

1.评估特征,将原始数据集划分为几个数据子集。

2.寻找决定性特征,划分出最好结果。

3.划分完成后,数据子集会分布在第一个决策点的所有分支上。

4.判断分支下的数据类型:

(1)属于同一种类型:说明正确分类,无需进一步对数据集进行分割,

(2)不属于同一种类型,需要重复划分数据子集的过程,划分方法与划分原始数据集方法相同。直到所有具有相同类型的数据均在一个数据子集内。

2、伪代码函数createBranch()

if so return 类标签;
else 
    寻找划分数据集的最好特征
    划分数据集
    创建分支节点
        for 每个划分的子集
            调节函数createBranch并增加返回结果到分支节点中
    return 分支节点

注:伪代码是一个递归函数。

3、决策树的一般流程

(1)收集数据:可以使用任何算法。

(2)准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。

(3)分析数据:可以使用任何方法,构造树完成之后,应该检查图形是否符合预期。

(4)训练算法:构造树的数据结构。

(5)测试算法:使用经验树计算错误率。

(6)使用算法:适用于任何监督学习算法,决策树可以更好地理解数据的内在含义。

4、算法分析

1.数据划分

1.由于依据某个属性划分数据会产生四个可能的值,所以将数据分为四块,并创建四个不同分支。

2.使用ID3算法划分数据集。

3.要考虑的问题:

(1)ID3算法如何划分数据集;

(2)何时停止划分数据集;

(3)第一次选择哪个特征作为划分的参考属性。

2.示例

1.数据示例

表3 - 1  海洋生物数据

2.数据分析

1.有五种海洋生物,生物分为两类:鱼类和非鱼类。

2.两类数据特征:不浮出水面是否可以生存、是否有脚蹼。

5、信息增益

1.划分数据集的大原则

将无序的数据变得更加有序。

2.信息增益相关概念

1.信息增益:组织杂乱无章数据的一种方法就是使用信息论度量信息,在划分数据之前使用信息论量化度量信息的内容,划分数据集之前之后信息发生的变化称为信息增益

2.熵:集合信息的度量方式称为香农熵或者简称为。不确定性越大,熵越高。

个人理解:将乱序数据转化为有序数据前后变化为信息增益,数据的信息的混乱程度叫)。

3.信息:如果待分类的事务可能划分在多个分类之中’ 则符号xi的信息定义为:

4.信息期望值:算所有类别所有可能值包含的信息期望值(n为分类的数目):

3.Python代码:计算给定数据集的香农熵

1.代码

#coding=utf-8

from math import log

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    
    #为所有可能的字创建字典 开始
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    #为所有可能的字创建字典 结束
    
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2) #log(prob,2)为以二为底求对数
    return shannonEnt

def createDataSet():
    dataSet = [[1,1,'yes'],
               [1,1,'yes'],
               [1,0,'no'],
               [0,1,'no'],
               [0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataSet,labels

myData,labels = createDataSet()
print(myData)
print(calcShannonEnt(myData))

注:原书中代码有缩进问题。

2.运行结果

3.代码分析

熵越高,混合的数据也越多,通过添加更多分类观察熵,当定义的数据为如下数据时:

def createDataSet():
    dataSet = [[1,1,'maybe'],
               [1,1,'yes'],
               [1,0,'no'],
               [0,1,'no'],
               [0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataSet,labels

得到的结果是:

熵的值变大,说明混合数据变多了。

6、划分数据集

1.度量划分数据集的熵

1.目的

判断当前是否正确地划分了数据集。

2.方法

对每个特征划分数据集的结果计算一次信息熵,然后判断划分数据集的最好的特征。

2.Python代码:按照给定特征划分数据集

1.extend()与append()的区别

通过一段代码,我们可以分析对比两个方法的不同,代码如下:

a = [1,2,3]
b = [4,5,6]

c = [1,2,3]
d = [4,5,6]

a.append(b)
c.extend(d)

print('a.append(b) =  : {}'.format(a))
print('c.extend(d) =  : {}'.format(c))

执行结果如下:

2.Python代码

#coding=utf-8

from math import log

def splitDataSet(dataSet,axis,value):
    retDataSet = [] #创建一个新的列表对象
    for featVec in dataSet:
        if featVec[axis] == value:
            #将符合特征的数据抽取出来
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

def createDataSet():
    dataSet = [[1,1,'yes'],   #第一次测试时使用
    #dataSet = [[1,1,'maybe'], #第二次测试时使用
               [1,1,'yes'],
               [1,0,'no'],
               [0,1,'no'],
               [0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataSet,labels

myData,labels = createDataSet()

print('myData : \n{}'.format(myData))
print('splitDataSet(myData,0,1) : \n{}'.format(splitDataSet(myData,0,1)))
print('splitDataSet(myData,0,0) : \n{}'.format(splitDataSet(myData,0,0)))

执行结果如下:

3.Python代码:最好的数据集划分方式

1.Python代码

#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):
        #创建唯一的分类标签列表 开始
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        #创建唯一的分类标签列表 结束
        newEntropy = 0.0
        #计算每种划分方式的信息熵 开始
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet,i,value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        #计算每种划分方式的信息熵 结束
        infoGain = baseEntropy - newEntropy
        if(infoGain > bestInfoGain):
            #计算最好的信息增益 开始
            bestInfoGain = infoGain
            bestFeature = i
            #计算最好的信息增益 结束
    return bestFeature

print('myData : \n{}'.format(myData))
print('the best feature to split : \n{}'.format(chooseBestFeatureToSplit(myData)))
print('myData : \n{}'.format(myData))

执行结果如下:

2.数据要求

(1)数据必须由一种列表元素组成的列表,所有列表元素都要有相同的数据长度;

(2)数据的最后一列或每个实例的最后一个元素是当前实例的类别标签。

数据集一旦满足上述要求,就可以在函数第一行判定当前数据集包含多少特征属性。

3.结果分析

第0个特征是最好的用于划分数据集的特征。

按照第一个特征值分类:特征值为1的海洋生物,两个属于鱼类,一个属于非鱼类;特征值为0的海洋生物全部属于非鱼类。

按照第二个特征值分类:特征值为1的海洋生物,两个属于鱼类,两个属于非鱼类;特征值为0的海洋生物只有一个非鱼类。

7、递归构建决策树

1.决策树子功能模块工作原理

1.得到原始数据集,基于最好的属性值划分数据集(可能存在大于两个分支的数据集的划分)。

2.第一次划分后,数据被向下传递到树分支的下一个节点,在该节点上再次划分数据(即采用递归的原则处理数据集)。

3.递归结束的条件:(1)程序遍历完所有划分数据集的属性;(2)每个分支下的所有实例都具有相同的分类。

图3-2 划分数据集时的数据路径

2.测试算法

1.Python代码

def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] +=1
    sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True)
    return sortedClassCount[0][0]

#创建树
def createTree(dataSet,labels):#两个参数:数据集和标签列表(包含数据集中所有特征的标签)
    classList = [example[-1] for example in dataSet] #列表变量,包含数据集的所有类标签
    if classList.count(classList[0]) == len(classList): #所有的类标签完全相同,递归函数停止
        return classList[0]  # 直接返回该类标签0
    if len(dataSet[0]) == 1: # 使用完了所有特征,仍不能将数据集划分成仅包含唯一类别的分组
        return majorityCnt(classList) # 返回出现次数最多者(无法简单返回唯一类标签)
    bestFeat = chooseBestFeatureToSplit(dataSet) #最优数据集特征
    bestFeatLabel = labels[bestFeat] # 最优特征标签
    myTree = {bestFeatLabel:{}} #定义字典myTree变量存储树的所有信息

    #得到列表包含的所有属性值 开始
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    #得到列表包含的所有属性值 结束

    #遍历当前选择特征包含的所有属性值
    for value in uniqueVals:
        subLabels = labels[:] #复制了类标签,并将其存储在新列表变量中(目的:为保证调用createTree函数)
        #在每个数据集划分上递归调用函数creatTree(),得到的返回值插入到字典变量中。
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)

    return myTree

2.执行结果

三、使用Matplotlib注解绘制树形图

1、决策树

1.优点

直观,易于理解。

2.范例

2、Matplotlib注解

1.annotation注解工具

1.作用:在数据图形上添加文本注释。

2.示例:

对(0.2,0.1)位置的点的描述信息放在(0.35,0.3),并用箭头指向数据点(0.2,0.1)。

#示例代码
createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

 

 2.Python代码:使用文本注解绘制树节点

为了使用前面的分类器,需要先将图像格式化处理为一个向量。即将32×32的二进制图像转换为1×1024的向量。

1.第一个版本Python代码

import matplotlib.pyplot as plt

# 定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8") #dict()用于创建一个字典, boxstyle="sawtooth"表示注解框的边缘是波浪线,fc=”0.8” 是颜色深度
leafNode = dict(boxstyle="round4", fc="0.8") # 同上
arrow_args = dict(arrowstyle="<-") # 设置箭头的样式

def plotNode(nodeTxt, centerPt, parentPt, nodeType): # 绘制带箭头的注解,centerPt为节点中心坐标,parentPt 为起点坐标
    createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

#createPlot版本一
def createPlot():
    fig = plt.figure(1, facecolor='white') # 设置背景色
    fig.clf() # 清空画布
    #axprops = dict(xticks=[], yticks=[])
    createPlot.axl = plt.subplot(111, frameon=False) #表示图中有1行1列,绘图放在第几列, 有无边框, subplot(111)和subplot(1,1,1)是相同的
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode) # 第一个坐标是注解的坐标 第二个坐标是点的坐标 #
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

2.运行代码

createPlot()

3.执行结果

函数plotNode示例

3、构造注解树

1.获取叶节点数目及树的层数

注:这段代码有与原书不一样之处,原因在于Python版本不同。主要是以下两个方面:

1.firstStr 的创建不同

具体问题请点击:(firstStr创建问题

2.if判断语句不同

具体问题请点击:(if判断语句不同

#获取叶节点个数
def getNumLeafs(myTree):
    numLeafs = 0
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]#找到输入的第一个元素
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]) == dict:
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs +=1
    return numLeafs

#获取树的层数
def getTreeDepth(myTree):
    maxDepth = 0
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]#找到输入的第一个元素
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]) == dict:
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

2.输出预先存储的树信息

1.Python代码

#输出预先存储的树信息,避免每次测试代码都从数据中创建树的麻烦
def retrieveTree(i):
    listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
                   {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}
                   ]
    return listOfTrees[i]

print('retrieveTree(0) : \n{}'.format(retrieveTree(0)))
print('retrieveTree(1) : \n{}'.format(retrieveTree(1)))

myTree = retrieveTree(0)
print('树的叶子结点个数为:\n{}'.format(getNumLeafs(myTree)))
print('树的深度为: \n{}'.format(getTreeDepth(myTree)))

2.运行结果

3.构造注解树

1.Python代码

#在父子节点间填充文本信息
def plotMidText(cntrPt,parentPt,txtString):
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.axl.text(xMid,yMid,txtString)

#画一棵树
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs = getNumLeafs(myTree) #计算树的宽
    depth = getTreeDepth(myTree) #计算树的高
    firstStr = list(myTree.keys())[0]
    plotTree.totalW = float(getNumLeafs(myTree))  #存储树的宽度
    plotTree.totalD = float(getTreeDepth(myTree)) #存储树的深度
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    #cntrPt = (plotTree.xOff + (0.5/plotTree.totalW + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt) #标记子节点属性值
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]) == dict:
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

#createPlot 版本二
def createPlot(inTree):
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    axpropps = dict(xticks = [],yticks = [])
    createPlot.axl = plt.subplot(111, frameon = False, **axpropps)
    plotTree.totalW = float(getNumLeafs(inTree))  #存储树的宽度
    plotTree.totalD = float(getTreeDepth(inTree)) #存储树的深度
    plotTree.xOff = -0.5/plotTree.totalW  #xOff 与 yOff追踪已经绘制的节点位置以及下一个节点的恰当位置。
    plotTree.yOff = 1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

2.运行结果

注:我在执行过程中发现,图像如下图一所示,无法完全展示,所以我点击设置调整了图形大小及位置,调正后如下图二。

图一:修改前
图二:修改后

3.变更字典

#在父子节点间填充文本信息
def plotMidText(cntrPt,parentPt,txtString):
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.axl.text(xMid,yMid,txtString)

#画一棵树
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs = getNumLeafs(myTree) #计算树的宽
    depth = getTreeDepth(myTree) #计算树的高
    firstStr = list(myTree.keys())[0]
    plotTree.totalW = float(getNumLeafs(myTree))  #存储树的宽度
    plotTree.totalD = float(getTreeDepth(myTree)) #存储树的深度
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    #cntrPt = (plotTree.xOff + (0.5/plotTree.totalW + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt) #标记子节点属性值
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]) == dict:
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

#createPlot 版本二
def createPlot(inTree):
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    axpropps = dict(xticks = [],yticks = [])
    createPlot.axl = plt.subplot(111, frameon = False, **axpropps)
    plotTree.totalW = float(getNumLeafs(inTree))  #存储树的宽度
    plotTree.totalD = float(getTreeDepth(inTree)) #存储树的深度
    plotTree.xOff = -0.5/plotTree.totalW  #xOff 与 yOff追踪已经绘制的节点位置以及下一个节点的恰当位置。
    plotTree.yOff = 1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

2.代码运行

myTree = retrieveTree(0)
myTree['no surfacing'][3] = 'maybe'
print('myTree : \n{}'.format(myTree))
createPlot(myTree)

3.运行结果

myTree : 
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}

 

 

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值