结合源码分析第三章中实现的Demo
运行环境:Anaconda——Jupyter Notebook
Python版本为:3.6.2(原书代码实现为2.x 所以在一些代码上略有改动)
阅读本博文你将获取:
1.决策树的基本思想
2.信息增益和熵的概念——本文中使用信息增益作为划分数据集的标准
3.全部的代码实现,且包含了大部分注释,便于初学者者理解
4.在最后的总结部分对决策树的优缺点做了总结
参考资料:Apachecn 专注于优秀项目维护的开源组织
前言
决策树(Decesion Tree)是一种基本的分类算法。它可以认为是if-then规则的集合,也可以认为是定义在特征空间与类空间上的条件概率分布。其主要优点就是模型具有可读性,分类速度块。学习时,利用训练数据,根据损失函数最小化的原则建立决策树模型。旨在构建一个与训练数据拟合很好,并且复杂度小的决策树。因为从可能的决策树中直接选取最优决策树是NP问题,现实中采用启发方法学习次优的决策树。
决策树学习算法包括3部分:特征选择、树的生成和树的剪枝。常用算法有ID3(本文采用)、C4.5和CART算法。
决策树的生成:通常使用信息增益最大,信息增益比最大,或基尼指数最小作为特征选择的准则。决策树的生成往往通过计算信息增益或其他指标,从根节点开始,递归地产生决策树。这相当于用信息增益或其他准则不断地选取局部最优 的特征,或将训练集分割为能够基本正确分类的子集。
决策树的剪枝:由于生成的决策树存在过拟合的问题,需要对它进行剪枝(考虑全局最优)。决策树的剪枝,往往从已生成的树上剪掉一些叶结点或叶结点以上的树,并将其父结点或根结点作为新的叶结点。
3.1 决策树的构造
创建分支的伪代码如下:
决策树的一般开发流程:
3.1.1 信息增益(Information)和熵(Entropy)
严格来说这个概念是信息论里的概念。(笔者之前学过通信原理所以相对熟悉一点)。
首先,明白一点。我们日常生活中会接收到无数的消息,但是只有那些你关心在意(或对你有用)的才叫做信息。(张三告诉我李四今天一只袜子丢了,事实上我也不认识李四,对于我来说这仅仅是一则消息,因为这和我没有太多关系,我也不必去在意)。
另外,想一下如何度量这种信息的大小或多少呢,香农说了,可以使用信息量来度量。在日常生活中,极少发生 的事件一旦发生是容易引起人们关注的(新闻说发生空难了,那必然会引起人们很大的关注,但事实是发生空难的概率很小很小),而 司空见惯的事不会引起注意 ,也就是说,极少见的事件所带来的信息量多。如果用统计学的术语来描述,就是出现概率小的事件信息量多。因此,事件出现得概率越小,信息量愈大。即信息量的多少是与事件发生频繁(即概率大小)成反比 。
设X是一个取有限个值的离散随机变量,X=xi (i=1,2…,n)。则信息量定义为
为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值:<对数以2和e为底时,熵的单位分别叫做比特(bit)和奈特(nat)>
程序清单3-1:计算给定数据集的香农熵
from math import log
def calcShanonEnt(dataSet):
numEntries = len(dataSet) #计算实例总数
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1] #-1代表最后一列
# 为所有可能的分类创建字典,如果当前的键值不存在,则扩展字典并将当前键值加入字典。每个键值都记录了当前类别出现的次数。
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob,2) #base = 2 求对数
return shannonEnt
def createDataSet():
dataSet = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
labels = ['no surfacing','flippers']
return dataSet,labels
输入:
myDat,labels = createDataSet()
myDat
输出:
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
输入:
calcShanonEnt(myDat)
输出:
0.9709505944546686
按照书上内容,在数据集中增加更多的分类,观察熵的变化。
输入:
myDat[0][-1]='maybe'
myDat
输出:
[[1, 1, 'maybe'], [1,