《机器学习实战》决策树(ID3算法)

参考博文(1)(2)


本章内容(使用Python3.6实现)

  • 决策树简介
  • 在数据集中度量一致性
  • 使用递归构造决策树
  • 使用matplotlib绘制树形图

关于决策树,我们首先讨论构造决策树的方法,以及如何编写构造树的Python代码;接着提出一些度量算法成功率的方法;最后使用递归建立分类器,并且使用Matplotlib绘制决策树图。构造完成决策树分类器之后,我们将输入一些隐形眼睛的处方数据,并由决策树分类器预测需要的镜片类型。 

3.1 决策树的构造

决策树

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。

缺点:可能会产生过度匹配问题。

适用数据类型:数值型和标称型

在构造决策树时,我们需要解决的第一个问题就是,当前数据集上哪个特征在划分数据分类时起决定性作用。为了找到决定性的特征,划分出最好的结果,我们必须评估每个特征。

创建分支的伪代码函数createBranch() 如下所示:

检测数据集中的每个子项是否属于同一分类:
    if so return 类标签;
    else:
        寻找划分数据集的最好特征
        划分数据集
        创建分支节点
            for 每个划分的子集
                调用函数creatBranch并增加返回结果到分支节点中
        return 分支节点

决策树的一般流程

1)收集数据:可以使用任何方法

2)准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。

3)分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。

4)训练算法:构造树的数据结构。

5)测试算法:使用经验树计算错误率。

6)使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。

       本书采用ID3算法划分数据集,该算法处理如何划分数据集,何时停止划分数据集。每次划分数据集时我们只选取一个特征属性,如果训练集中存在20个特征,第一次选择哪个特征作为划分的参考属性?回答这个问题我们必须采用量化的方法判断如何划分数据。以下表3-1 数据为例。

表3-1 海洋生物数据
              不浮出水面是否可以生存    是否有脚蹼   属于鱼类
1                               是           是         是
2                               是           是         是
3                               是           否         否
4                               否           是         否
5                               否           是         否

3.1.1 信息增益

       划分数据集的大原则是:将无序的数据变得更加有序。组织杂乱无章数据的一种方法就是使用信息论度量信息。在划分数据集之前之后信息发生的变化称为信息增益,知道如何计算信息增益,我们就可以计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。

                                                         l\left ( x_{i} \right )= -log_{2}p\left ( x_{i} \right )   ,   其中p(x_{i})是选择该分类的概率。

                                                           H = -\sum_{i=1}^{n}p(x_{i})log_{2}p(x_{i}),其中n是分类的数目。

# trees.py
# 使用Python计算信息熵
from math import log
def calcShannonEnt(dataset):
    numEntries = len(dataset)
    labelCounts = {}
    for featVec in dataset:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2)
    return shannonEnt

# 创建一个数据字典
def createDataSet():
    dataSet = [[1,1,'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no'],]
    labels = ['no surfacing','flippers']
    return dataSet,labels

熵越高,则混合的数据也越多。得到熵之后,我们就可以按照获取最大信息增益的方法划分数据集。

3.1.2 划分数据集

     我们学习了如何度量数据集的无序程度,分类算法除了需要测量信息熵,还需要划分数据集,度量划分数据集的熵,以便判断当前是否正确地划分了数据集。我们将对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式。

# 划分数据集,其中axis为划分数据集的特征,value为需要返回的特征的值
def splitDataSet(dataSet,axis,value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

 接下来我们将遍历整个数据集,循环计算香农熵和splitDataSet()函数,找到最好的特征划分方式。

# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0])-1   #数据集包含特征的总数
    baseEntropy = calcShannonEnt(dataSet)   #原始数据集的信息熵
    bestInfoGain = 0.0; bestFeature = -1     # 最佳信息增益,最佳划分特征
    for i in range(numFeatures):  # iterate over all the features
        featList = [example[i] for example in dataSet]
                        # create a list of all the examples of this feature
        uniqueVals = set(featList)  # 集合内元素不重合,这表示某特征的特征值所组成的集合
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # calculate the info gain; ie reduction in entropy
        if (infoGain > bestInfoGain):  # compare this to the best gain so far
            bestInfoGain = infoGain  # if better than current best, set to best
            bestFeature = i
    return bestFeature  # returns an integer

信息增益是熵的减少或者是数据无序度的减少,大家肯定对于将熵用于度量数据无序度的减少更容易理解。

3.1.3 递归构建决策树

第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。因此我们可以采用递归的原则处理数据集。递归结束的条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。然而,如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时,我们需要决定如何定义该叶子节点,在这种情况下,我们通常会采用多数表决的方法决定该叶子节点的分类。

# 引入operator,多数表决决定该叶子节点的分类
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(),
                              key=operator.itemgetter(1),reverse=True)
    # sortedClassCount = sorted(classCount,key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]

# 创建树的函数代码,采用递归的原则处理数据集
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # 类别完全相同则停止继续划分
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)   #遍历完所有特征时返回出现次数最多的类别
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLable = labels[bestFeat]
    myTree = {bestFeatLable:{}}
    del (labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLable][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree

3.2 在Python中使用Matplotlib注解绘制树形图

3.2.1 Matplotlib注解

# treePlotter.py
# 使用文本注解绘制树节点
import matplotlib.pyplot as plt
#定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")
# 绘制带箭头的注解
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
                            xytext=centerPt,textcoords='axes fraction',
                            va='center',ha='center',bbox=nodeType,
                            arrowprops=arrow_args)
#创建节点
def createPlot():
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111,frameon=False)
    plotNode('决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('叶节点',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

运行结果如下图所示:

函数plotNode的例子
函数plotNode的例子

 

3.2.2 构造注解树

 构造一个完整的树需要掌握一些技巧。虽然我们有x,y坐标,但是如何放置所有的树节点却是个问题。我们必须知道有多少个叶节点,以便正确确定x轴的长度;知道树有多少层,以便可以正确确定y轴的高度。这里我们定义两个新函数getNumLeafs()和getTreeDepth(),来获取叶结点的数目和树的层数。

完整代码如下:

# 使用文本注解绘制树节点
import matplotlib.pyplot as plt
#定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")

# 获取叶结点的数目和树的层数
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            #test to see if the nodes are dictonaires, if not they are leaf nodes
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

# 绘制带箭头的注解
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
                            xytext=centerPt,textcoords='axes fraction',
                            va='center',ha='center',bbox=nodeType,
                            arrowprops=arrow_args)

# 在父子节点间填充文本信息
def plotMidText(cntrPt,parentPt,txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)
    # createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center",
    #                     rotation=30)
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]     #计算宽与高
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW;
    plotTree.yOff = 1.0;
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
# #创建节点
# def createPlot():
#     fig = plt.figure(1,facecolor='white')
#     fig.clf()
#     createPlot.ax1 = plt.subplot(111,frameon=False)
#     plotNode('决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
#     plotNode('叶节点',(0.8,0.1),(0.3,0.8),leafNode)
#     plt.show()

# 输出预先存储的树信息,避免了每次测试代码时都要从数据中创建树的麻烦
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1:
                      {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]

if __name__ == '__main__':
    # print(createPlot())
    myTree = retrieveTree(0)
    print(getNumLeafs(myTree))
    print(getTreeDepth(myTree))
    print(createPlot(myTree))
    myTree['no surfacing'][3] = 'maybe'
    print(myTree)
    print(createPlot(myTree))

如图所示:

简单数据集绘制的树形图
简单数据集绘制的树形图

 3.3 测试和存储分类器

 3.3.1 测试算法:使用决策树执行分类

# 使用决策树的分类函数
def classify(inputTree,featLabels,testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in list(secondDict.keys()):
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__=='dict':
                classLabel = classify(secondDict[key],featLabels,testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

 3.3.2 存储决策树

# 使用pickle模块存储决策树
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')  # 用pickle序列化后的是二进制,所以此处用wb,
                                # 以二进制写入,不然默认以字符串写入会出错
    fw.write(pickle.dumps(inputTree))
    fw.close()
def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')  # 用'rb'
    data = pickle.loads(fr.read())
    fr.close()
    return data


# 也可以使用json模块
def storeTree(inputTree, filename):
    import json
    fw = open(filename, 'w')  # 只需要'w'
    fw.write(json.dumps(inputTree))
    fw.close()
def grabTree(filename):
    import json
    fr = open(filename, 'r')  # 只需要'r'
    data = json.loads(fr.read())
    fr.close()
    return data

 示例一:使用决策树预测隐形眼镜类型

'''
@Project -> File   :ML_in_action -> testLenses
@IDE    :PyCharm
@Author :NatW
@Date   :2019/11/1 21:16'''

import trees
import treePlotter
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels=['age','prescript','astigmatic','tearRate']
lensesTree = trees.createTree(lenses,lensesLabels)
print(lensesTree)
print(treePlotter.createPlot(lensesTree))

 结果图如下:

由ID3算法产生的决策树
由ID3算法产生的决策树

倘若匹配选项过多,称之为“过度匹配”。为了减少过度匹配问题,我们可以裁剪决策树,去掉一些不必要的叶子节点。如果叶子节点只能增加少许信息,则可以删除该节点,将它并入到其他叶子节点中。 


小结:本章使用的算法成为ID3,它是一个很好的算法但并不完美。ID3算法无法直接处理数值型数据,尽管我们可以通过量化的方法将数值型数据转化为标称型数值,但是如果存在太多的特征划分,ID3算法仍然会面临其他问题。还有其他的决策树的构造算法,最流行的有C4.5和CART,在讨论回归问题时会介绍CART算法。

       决策树分类器就像带有终止块的流程图,终止块表示分类结果。开始处理数据集时,我们首先需要测量集合中数据的不一致性,也就是熵,然后寻找最优方案划分数据集,直到数据集中的所有数据属于同一分类。ID3算法可以用于划分标称型数据集。

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值