1.决策树的基本原理与伪代码
决策树算法,是一种监督学习的分类算法,可细分为ID3、C4.5、CART等三种算法,前两种适用于标称型数据,后一种适用于数值型数据。
1.1决策树的基本原理:
所谓决策树,即根据样本数据集的不同特征不断对数据集进行划分,划分的最终结果构成一棵树。
该算法的难点在于:在众多特征中,最先选择哪一个特征对数据集进行划分?
ID3算法采用信息增益;C4.5算法采用信息增益率;CART算法采用基尼系数
本文主要介绍ID3算法,即以数据划分前后的信息增益为指标进行特征选择。
香农熵的计算公式:
信息增益=划分前的香农熵—划分后的条件熵
1.2决策树算法的伪代码:
createbranch():
if 所有样本数据的标签均一致
返回该标签
else 寻找划分数据集的最好特征
划分数据集
创建分支节点
for 每个划分后的子集
调用函数createbranch并增加返回结果到分支节点中
return 分支节点
2.决策树算法的优缺点
优点:形象直观,易于理解,复杂度不高,可将分类器存储在硬盘上,不用每次都重新学习
缺点:容易过拟合,需要剪枝(本文先不讨论)
3.该算法的PYTHON语言实现
3.1决策树分类算法的主体编程架构
一个完整的决策树分类算法主要由以下几块构成:
- 构建决策树,重点在于选择划分数据集的最好特征
- 存储决策树并使用该决策树对测试数据进行分类
- 使用matplotlib绘制决策树
3.2决策树-ID3算法的PYTHON代码
3.2.1构建决策树
这一块主要由以下几个子函数构成:
- 计算香农熵
- 根据特征值划分数据集
- 选择划分数据集的最好特征
- (主体函数)构建树
首先,来计算香农熵:
#计算香农熵
def calcShannonEnt(dataSet):
# 数据集的数据个数
numEntries = len(dataSet)
# 建立一个标签字典,键值是每一行的标签,对应值是该标签出现的次数
labelCounts = {}
for featVec in dataSet: # 遍历数据集中的每一行
currentLabel = featVec[-1] # currentlable存放的是当前一行的标签
# 如果当前标签不在之前的标签字典里,将标签字典里当前标签对应的值赋0
if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 #统计当前标签出现的次数
# 计算当前数据集的香农熵并返回
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries #prob为当前标签在所有标签里出现的概率
shannonEnt -= prob * math.log(prob,2)
return shannonEnt
其次,构建一个子函数,其功能是根据给定的特征值来划分数据集:
# ==============================================
# 输入:
# dataSet: 训练集文件名(含路径)
# axis: 选定的特征所在列数
# value: 选定的特征值
# 输出:
# retDataSet: 划分后的子列表
# ==============================================
#函数功能:找出所有行中第axis个元素值为value的行,去掉该元素,返回对应行矩阵
def splitDataSet(dataSet, axis, value):
retDataSet = [] # 存放划分后的子列表
for featVec in dataSet: # 逐行遍历数据集
if featVec[axis] == value: # 如果当前行第axis列的特征值等于value
reducedFeatVec = featVec[:axis] # 抽取掉数据集中的目标特征值列
reducedFeatVec.extend(featVec[axis+1:])
# 将抽取后的数据加入到划分结果列表中
retDataSet.append(reducedFeatVec)
return retDataSet
注意:
- a=b[:axis],该语句提取出b列表的前0到(axis-1)列赋值给a,a.extend(b[axis+1:]),该语句在a列表后面加上b列表的(axis+1)到最后列。因此该两句联合起来就是删除b列表中axis列的值。
- append与extend的区别:
extend是将两个列表相连,append是将新列表整体作为一个对象一个元素添加到旧列表中
list_extend = ['a', 'b', 'c'],list_extend.extend(['d', 'e', 'f']),print("list_extend:%s" %list_extend)
# 输出结果:list_extend:['a', 'b', 'c', 'd', 'e', 'f']
list_append = ['a', 'b', 'c'],list_append.append(