《机器学习实战》3.决策树算法

目录

1.决策树概述:

2.决策树构造

1.1 信息增益

2.2划分数据集

2.3 递归构建决策树

3.在python中使用matplotlib注解绘制树形图

3.1 Matplotlib注解

 3.2构造注解树

4.测试和存储分类器

4.1测试算法:使用决策树执行分类

4.2 使用算法:决策树的存储

4.4示例:使用决策树预测隐形眼镜类型


本博客涉及相关代码和数据

提取码:wdyb

1.决策树概述:

关于决策树进行分类的算法代码的应用

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据

缺点:可能产生过于匹配的问题

适用数据类型:数值型和标称型

解决问题:当前数据集上哪个特征在划分数据分类时起决定作用

流程:

1.收集数据  

2.准备数据  

3.分析数据  

4.训练算法:构造树的数据结构  

5.测试算法:使用经验树计算错误率  

6.使用算法  

2.决策树构造

1.1 信息增益

划分数据集的最大原则就是:将无序的数据变得更加有序,在划分数据集之前之后信息发生的变化称为信息增益,知道如何计算信息增益,我们就可以计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。

熵定义为信息的期望值,在明晰这个概念之前,我们必须知道信息的定义。如果待分类的事务可能划分在多个分类之中,则符号的信息定义为

 其中p(xi)是选择该分类的概率。

为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到:

其中n是分类的数目。

下面我们将计算给定数据的信息熵:

#使用ID3算法划分数据集
from math import log

# 计算给定数据集的香农熵->信息期望值   
def calcShannonEnt(dataSet):
    numEntries=len(dataSet)
    labelCounts={}
    for featVec in dataSet:
        # 拿到分类标签
        currentLabel=featVec[-1]
        # 查看当前分类标签是否存在,如果不存在则创建一个新的标签,
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        # 如果存在则直接加1
        labelCounts[currentLabel]+=1
    # 预设香农熵为0
    ShannonEnt=0.0
    for key in labelCounts:
        # 拿到概率值
        prob=float(labelCounts[key])/numEntries
        # 计算香农熵
        ShannonEnt-=prob*log(prob,2)
    return ShannonEnt

创建createDataSet()函数得到简单鉴定数据集

def createDataSet():
    dataSet=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    labels=['no surfacing','flipper']
    return dataSet,labels

调用上面的函数并计算信息熵得到:

myDat,Labels=createDataSet()
calcShannonEnt(myDat)

得到的结果为:

 熵越高,则混合的数据也越多,在数据集中添加更多的分类,看熵是如何变化,增加第三个名为maybe的分类,测试熵的变化:

myDat[0][-1]='maybe'
calcShannonEnt(myDat)

得到的结果为:

 得到熵之后,就可以按照获取最大信息增益的方法划分数据集。

2.2划分数据集

分类算法除了需要测量信息熵,还需要划分数据集,度量花费数据集的熵,以便判断当前是否正确的划分了数据集。对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分呢方式。

按照给定特征划分数据集:

# 按照给定特征划分数据集
#      待划分的数据集  划分数据集的特征  特征的返回值
def splitDataSet(dataSet,axis,value):
    # 创建一个新的列表对象
    retDataSet=[]

    # 将符合特征的数据抽取出来
    for featVec in dataSet:
        if featVec[axis]==value:
            # feature[a:b]含a不含b,下面两句代码相当于去掉axis处的数据
            reducedFeatVec=featVec[:axis]
            # 将符合特征的数据以数据项的形式加上去
            reducedFeatVec.extend(featVec[axis+1:])
            # 将符合特征的数据在后面以列表的形式加上去
            retDataSet.append(reducedFeatVec)
    return retDataSet

关于append函数和extend函数的用法区分:

append函数是直接在后面加上一个完整的列表

extend函数是直接在后面把所有的元素加上

测试函数splitDataSet()

splitDataSet(myDat,0,1)
splitDataSet(myDat,0,0)

分别得到结果:

接下来,遍历整个数据集,循环计算香农熵和splitDataSet()函数,找到最好的特征划分方式。熵计算将会告诉我们如何划分数据集是最好的数据组织方式

# 选择最好的数据集划分方式
# 选取特征,划分数据集
def chooseBestFeatureToSplit(dataSet):
    # 获得信息的数量   不加标签
    numFeatures=len(dataSet[0])-1
    # 计算当前数据的香农熵
    baseEntropy=calcShannonEnt(dataSet)
    bestInfoGain=0.0
    bestFeature=-1
    # 遍历每一组信息
    for i in range(numFeatures):
        # 创建唯一的分类标签列表
        # 按列遍历:example[i] for example in dataSet
        # 将dataSet中的数据先按行依次放入example中,然后取得example中的example[i]元素,放入列表featList中
        featList=[example[i] for example in dataSet]
        # 除去重复元素
        uniqueVals=set(featList)
        newEntropy=0.0
        # 计算每种划分方式的信息熵
        for value in uniqueVals:
            subDataSet=splitDataSet(dataSet,i ,value)
            # 得到概率
            prob=len(subDataSet)/float(len(dataSet))
            # 计算一起的期望
            newEntropy+=prob*calcShannonEnt(subDataSet)
        infoGain=baseEntropy-newEntropy
        # 计算最好的信息增益
        if(infoGain>bestInfoGain):
            bestInfoGain=infoGain
            bestFeature=i
    return bestFeature

 调用函数,得到最好划分方式的结果:

temp=chooseBestFeatureToSplit(myDat)
temp

输出结果为:

0

代码运训告诉我们,第0个特征是最好的用于划分数据集的特征

2.3 递归构建决策树

使用分类名称的列表,然后创建键值为classList中唯一值的数据字典,字典对象存储了classList中每个类标签出现的频率,最后利用operator操作键值排序字典,并返回出现次数最多的分类名称

import operator

def majorityCnt(classList):
    classCount={}
    for vote in classList:
        # .keys()函数返回字典中的所有键所组成的一个可迭代序列,即所有表头序列
        if vote not in classCount.keys():
            # 如果该表头不存在,则创建一个表头
            classCount[vote]=0
        # 如果存在,直接加1即可
        classCount[vote]+=1
        # 然后将字典里的所有项从大到小排列
    sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    # 返回出现次数最多的表头
    return sortedClassCount[0][0]

 循环递归创建决策树

# 创建树
def createTree(dataSet,labels):
    # 得到数据的倒数第一列,即标签
    classList=[example[-1] for example in dataSet]
    
    # 类别完全相同则停止继续划分
    # 如果只有一个值?
    if classList.count(classList[0])==len(classList):
        return classList[0]
    # 遍历所有特征时,返回次数最多的
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    # 得到最好的划分数据集的方式
    bestFeat=chooseBestFeatureToSplit(dataSet)
    # 拿到最好方式的标签
    bestFeatLable=labels[bestFeat]
    myTree={bestFeatLable:{}}

    # 得到列表包含的所有属性值
    # 先删除原来的标签列表
    del(labels[bestFeat])
    # 拿到最好划分方式的那一列数据
    featValues=[example[bestFeat] for example in dataSet]
    # 去除重复项
    uniqueVals=set(featValues)
    # 遍历这些特征值
    for value in uniqueVals:
        subLabels=labels[:]
        # 递归调用
        myTree[bestFeatLable][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree

测试生成的树:

myDat,labels=createDataSet()
myDat
myTree=createTree(myDat,labels)
myTree

得到测试的结果:

 生成树之后,我们使用Matplotlib包将生成的树结构画出来

3.在python中使用matplotlib注解绘制树形图

对于树来说,字典的表示非常不易于理解,而且直观绘制图形也比较困难。  

因此我们使用Matplotlib库创建树形图。决策树的优势就是直观易于理解,如果不能将其直观的显示出来,就无法发挥其优势  

3.1 Matplotlib注解

Matplotlib提供了一个注解工具annotations,非常有用,它可以在数据图形上添加文本注释

import matplotlib.pyplot as plt

decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")

def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    # axl.annotate获得对绘制箭头的访问
    #                      标注文本  点坐标                                 目标坐标                                                            点类型         箭头形状
    createPlot.axl.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',va='center',ha='center',bbox=nodeType,arrowprops=arrow_args)

def createPlot():
    # 创建绘画框
    fig=plt.figure(1,facecolor='white')
    # 清除figure坐标轴
    fig.clf()
    # 设置一个多图展示,但是设置多图只有一个
    createPlot.axl=plt.subplot(111,frameon=False)
    #          标签值           目标坐标    点坐标     决定节点
    plotNode('a decision node',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('a leaf node',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

调用函数

createPlot()

得到输出的结果为:

 3.2构造注解树

为了能够在坐标纸上画出一个结构可变的树,应该首先确定树的宽和高

# 获取叶节点的数目
def getNumLeafs(myTree):
    numLeafs=0
    # 获得myTree的第一个键值,即第一个特征
    firstStr=list(myTree.keys())[0]
    # 得到第一个特征分类的结果
    secondDict=myTree[firstStr]
    # 向下递归一级,遍历其他特征
    for key in secondDict.keys():
        # 如果还没有到达最后一级
        if type(secondDict[key]).__name__=='dict':
            # 递归得到下一级的
            numLeafs+=getNumLeafs(secondDict[key])
        else:
            numLeafs+=1
    return numLeafs

# 获取树的层数
def getTreeDepth(myTree):
    maxDepth=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else:
            thisDepth=1
        if thisDepth>maxDepth:
            maxDepth=thisDepth
    return maxDepth

然后根绝得到的树的宽和高,绘制树形图

# 在父子节点之间填充文本信息
#          子节点坐标   父节点坐标   文本信息
def plotMidText(cntrPt,parentPt,txtString):
    # 拿到父子节点中间的坐标,并添加文本信息
    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
    # 添加文本信息
    createPlot.axl.text(xMid,yMid,txtString)

# 全局变量plotTree.totalW存储树的宽度
# 全局变量plotTree.totalD存储树的深度
#利用这两个变量计算节点的摆放位置,使树放在中心位置
# 使用两个全局变量plotTree.xOff和plotTree.yOff追踪已经绘制的节点位置
# 绘制树形图
def plotTree(myTree,parentPt,nodeTxt):
    
    # 拿到叶子节点的数目和树的深度来计算宽与高
    # 得到叶子节点的数目
    numLeafs=getNumLeafs(myTree)
    # 得到树的深度:没啥用?※※※※※※※※※
    depth=getTreeDepth(myTree)
    # 得到第一个表头序列
    firstStr=list(myTree.keys())[0]
    # 得到子节点坐标
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    # 标记子节点属性值
    plotMidText(cntrPt,parentPt,nodeTxt)
    # 画点
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    # 得到第二个表头序列的值
    secondDict=myTree[firstStr]
    # 减少y偏移
    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
    for key in secondDict.keys():
        # 如果不是叶子节点,继续向下递归
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        # 如果是叶子节点,则直接画图
        else:
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
            # 画图
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            # 写标注
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD


def createPlot(inTree):
    # 创建画框
    fig=plt.figure(1,facecolor='white')
    # 清空绘图区
    # fig.clf()
    # 定义坐标轴数字为空
    axprops=dict(xticks=[],yticks=[])
    # 定义树的宽度
    plotTree.totalW=float(getNumLeafs(myTree))
    # 定义树的高度
    plotTree.totalD=float(getTreeDepth(myTree))
    # x的偏移
    plotTree.xOff=-0.5/plotTree.totalW
    # y的偏移
    plotTree.yOff=1.0
    # subplot()函数解析
    # 第一个参数:三个独立的整数来描述子图的位置信息。行数、列数和索引值,子图将分布在行列的索引位置上。索引从1开始,从右上角增加到右下角。
    # 可以选择子图的类型,比如选择polar,就是一个极点图。默认是none就是一个线形图。
    # 第二个参数frameon如果选择true,就是一个极点图
    # 第三个参数,设置轴的标签
    createPlot.axl=plt.subplot(111,frameon=False,**axprops)
    
    # 画树
    plotTree(inTree,(0.5,1.0),'')
    # 展示图像
    plt.show()

调用画树的函数:

createPlot(myTree)

得到画出来的树的结果:

 

4.测试和存储分类器

使用决策树构建分类器、以及实际应用中如何存储

4.1测试算法:使用决策树执行分类

依靠训练数据构造了决策树之后,我们可以将它用于实际数据的分类。在执行数据分类时,需要决策树以及用于构造树的标签向量。然后,程序比较测试数据与决策树上的数值,递归执行该过程直到进入叶子节点,最后将测试数据定义为叶子节点所属的类

使用决策树的分类函数

# 使用决策树的分类函数
def classify(inputTree,featLabels,testVec):
    # 得到树的第一个键值
    firstStr=list(inputTree.keys())[0]
    secondDict=inputTree[firstStr]
    print(secondDict)
    featIndex=featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex]==key:
            # 如果不是叶子节点,继续向下搜索
            if type(secondDict[key]).__name__=='dict':
                classLabel=classify(secondDict[key],featLabels,testVec)
            # 如果是叶子节点,返回得到的标签
            else:
                classLabel=secondDict[key]
    return classLabel

测试分类结果如下:

mtDat,labels=createDataSet()
labels
print(classify(myTree,labels,[1,1]))
print(classify(myTree,labels,[1,0]))

得到输出的结果为:

4.2 使用算法:决策树的存储

构造决策树是很耗时的任务,即使处理很小的数据集,如前面的样本数据,也要花费几秒的时间,如果数据量很大,将会耗费很多计算时间。然而使用创建好的决策树解决分类问题,则可以很快完成。因此,为了节省计算时间,最好能够在每次执行分类时调用已经够造好的决策树    

为了解决这个问题,需要使用python模块pickle序列化对象,序列化对象可以在磁盘上保存对象,并在需要的时候读取出来。任何对象都可以进行序列化操作,字典对象也不例外

# 使用pickle模块存储决策树
def storeTree(inputTree,filename):
    import pickle
    fw=open(filename,'wb')
    # 将inputTree对象放到文件fw中去
    pickle.dump(inputTree,fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr=open(filename,'rb')
    return pickle.load(fr)

调用函数:

storeTree(myTree,'classifierStorage.txt')
grabTree('classifierStorage.txt')

得到结果为:

4.4示例:使用决策树预测隐形眼镜类型

本节通过一个例子讲解决策树如何预测患者需要佩戴的隐形眼睛类型。使用小数据集,我们就可以利用决策树学到很多知识:眼科医生是如何判断患者需要佩戴的镜片类型;一旦理解了决策树的工作原理,我们甚至也可以帮助人们判断需要佩戴的镜片类型

1.收集数据:提供的文本文件  

2.准备数据:解析tab键分割的数据行  

3.分析数据:快速检查数据,确保正确地解析内容,使用createPlot函数绘制最终的树形图    

4.训练算法:使用createTree()函数  

5.测试算法:编写测试函数验证决策树可以正确分类给定的数据实例  

6.使用算法:存储树的数据结构,以便下次使用时无需重新构造树  

fr=open('lenses.txt')
lenses=[inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels=['age','prescript','astigmatic','tearRate']
lensesTree=createTree(lenses,lensesLabels)
lensesTree

得到的输出结果为:

 将得到的树的结构画出来:

createPlot(lensesTree)

画出树的结构如下:

(这里有个小bug,根节点多出来了一根箭头,没找到解决方法,就暂且放着了)

 如上图所示的决策树非常好的匹配了实验数据,然而可能匹配的选项太多了。将这种问题称为过度匹配,为了减少过度匹配的问题,可以进行裁剪决策树,但是在这一章先不讨论。

  • 3
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

豆豆豆豆芽

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值