一文搞懂决策树的ID3算法

决策树是机器学习中一个比较重要的算法,和其他机器学习算法不一样的是,你不懂过多的数学理论知识,也能理解这个算法的原理。前几天我看到一篇文章,大概内容是利用决策树来预测世界杯最终的冠军是谁,最终预测结果是巴西(巴西好像已经凉凉了),不过这并不能影响你学习今天的算法。好了,废话不多说,进入今天的主题。

一听决策树,就知道这个算法和树形数据结构有关。确实如此,决策树本质上是一个树形结构,但和我们熟知的二叉树又不太一样,它是一个多叉树。这是我在百度上找的一张样图

在这里插入图片描述

首先要说明一下决策树的两大用途:分类和回归。其实很多的机器学习算法都可以做分类与回归,比如支持向量机(SVM),它有支持向量分类机和支持向量回归机,关于支持向量机这里不做深究,以后会专门写一篇关于支持向量机的文章。本文只用决策树来解决分类问题,而不探讨回归(其实是我还没学)。

为了接下来更好的描述决策树的整个算法流程,我找了一个具体的例子,一步一步分析。
在这里插入图片描述

这张表是拥有15个训练样本的贷款申请数据集。四个特征分别是年龄、是否有工作、是否有自己的房子、信贷情况。最后一列表示是否同意此人贷款申请。接下来我们要做的就是根据这个样本数据集,做一个合理的决策,对于未知的申请人,是否同意申请贷款。

就目前来说,我们已知的只有样本数据集,然后利用某种算法,希望通过样本数据集来学习出一组分类规则,并且这个分类规则要与样本数据集的矛盾尽可能小,这样才算是一组合理的分类规则,决策树就是一个不错的选择。

决策树中除了叶节点以外的节点,都是代表的特征,通过特征来不断划分数据集,使得最终叶节点上的数据集是“纯的”(数据集标签要么全为是,要么全为否),如果能选出这样的特征,那么我们就能轻易的判断未知样本到底该属于哪一类。比如下面这种特征选择。

在这里插入图片描述

不过难就难在我们应该如何选取最佳特征来划分数据集。这个就要从信息论的知识说起了。

信息论中,有一个很重要的概念就是香农熵,下面直接给出随机变量的香农熵计算公式:

在这里插入图片描述

当香农熵越大时,不确定性就越大,我们所能获取的信息量就越少。为了更加直观的理解这一句话,我针对随机变量只有两种取值的时候进行讨论。

当n = 2时,表达式如下,画出函数图像:

在这里插入图片描述

可以看出,当概率为0.5的时候,熵值最大,也就代表这个时候所能获得的信息量就越少。那这是为什么呢?直观上可以这样理解,当有人告诉你X = 0和X = 1的概率都是0.5时,你觉得他说的是不是废话,你既然说他们都是等可能的,那我根本不能判断谁的取值可能性更大,也就是说这个取值在概率上讲,是没有偏向性的,那我得到的信息量自然就是最少的,也就是香农熵最大。本身熵这个概念就有混乱的含义,熵越大,就越混乱,我们就很难提取出有用的信息。

下面要介绍另一个概念,条件熵。顾名思义,跟条件概率有关,定义为:X给定条件下Y的条件概率分布的熵对X的数学期望,公式如下:
在这里插入图片描述
然后今天的主角就登场了——信息增益

特征A对训练数据集D的信息增益g(D,A),定义为集合D的经验熵H(D)与特征A给定条件下D的条件熵H(D,A)之差,即

在这里插入图片描述
这个定义简直不是人话,翻译成人话就是说:我们先计算出训练集D的熵,然后计算出在给定特征A的条件下,计算出条件熵H(D,A),由于熵能表示信息的混乱度,两个熵值之差那不就代表信息混乱减少的程度么,换句话说,也就是在给定特征A的条件下,我们得到的信息增加的程度,那就是信息增益。决策树中选定哪个特征划分数据集,也就是看哪个特征的信息增益最大,信息增益最大对应的特征就是最佳特征

现在知道了怎么选取最佳特征,剩下的就是递归创建决策树了,这一部分需要较强的数据结构知识,我不做过多的探究,直接贴上代码:

#计算香农熵
def calcShannonEnt(dataSet):

    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1

    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob*math.log(prob, 2)

    return shannonEnt

#划分数据集
def splitDataSet(dataSet, axis, value):

    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

#计算最佳特征
def chooseBestFeatureToSplit(dataSet):


    # 每个样例的特征数目
    numFeatures = len(dataSet[0]) - 1
    # 计算香农熵
    baseEntropy = calcShannonEnt(dataSet)
    # 初始化最佳信息增益
    bestInfoGain = 0.0
    # 初始化用来划分特征空间的最佳特征
    bestFeatures = -1


    for i in range(numFeatures):   # 遍历所有特征
        featList = [example[i] for example in dataSet]  # 把数据集中第i个特征存入列表
        uniqueVals = set(featList)  # 去除重复的特征取值
        newEntropy = 0.0  # 初始化条件熵
        for value in uniqueVals:  # 遍历第i个特征的所有可能取值
            subDataSet = splitDataSet(dataSet, i, value)  # 划分数据集
            prob = len(subDataSet)/float(len(dataSet))  # 计算该特征取值的频率
            newEntropy += prob*calcShannonEnt(subDataSet)  # 计算条件熵
        infoGain = baseEntropy - newEntropy  # 计算信息增益
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain  # 找出最佳信息增益
            bestFeatures = i  # 找出划分特征空间的最佳特征
    return bestFeatures

#返回递归终止条件
def majorityCnt(classList):

    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]

#递归创建树
def createTree(dataSet, labels):

    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeat:{}}
    del (labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree

对于决策树的ID3算法还有几点补充说明:

(1)递归的终止条件还有一种情况,当所有特征都已经用完的时候,也是递归的终止条件,这个时候的叶节点标签可能“不纯”,需要通过多数表决的方法决定返回值,这个类似于knn算法的思想。

(2)ID3算法有一个致命的缺点,它过分的要求决策树去匹配数据集,使得模型在很多时候会过于复杂,就产生了机器学习中一个致命的点——overffiting(过拟合),这个时候就需要降低模型复杂度,采用的技巧是——剪枝,这就涉及到后面的C4.5和CART算法,就更加复杂了,还涉及到动态规划的思想。

这篇文章主要是为了梳理最近学习决策树算法整个流程,其中肯定会有一些不恰当的描述,欢迎大家指正,一起交流,一起学习。

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值