如果觉得本篇文章对您的学习起到帮助作用,请 点赞 + 关注 + 评论 ,留下您的足迹💪💪💪
本篇文章为我对机器学习实战-决策树的理解与我在学习时所做笔记,一是为了日后查找方便并加深对代码的理解,二是希望能帮助到使用这本书遇到困难的人。
代码可在python3.7跑通
因此代码相对原书做了一些修改,增加了可读性,同时也解决了一些问题。
代码及详细注释如下:
from math import log
import operator
import pickle
def calcShannonEnt(dataSet):
'''
:param dataSet:
:return:
'''
numDatas = len(dataSet)
labelsCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelsCounts.keys():
labelsCounts[currentLabel] = 0
# 书中这一行缩进错误
labelsCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelsCounts:
prob = float(labelsCounts[key]) / numDatas
shannonEnt -= prob * log(prob, 2)
return shannonEnt
def splitDataSet(dataSet, axis, value):
'''
# 按照某个特征划分数据集时,就是把这个特征的全部元素提取出来
:param dataSet: 带划分数据集
:param axis: 带划分数据集的特征
:param value: 特征的返回值
:return:
'''
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
def chooseBestFeatureToSplit(dataSet):
'''
# 该函数实现选取特征,划分数据集,并得到最好的特征,即熵最小的特征
:param dataSet: 1、数据集必须由列表元素组成的列表,列表元素具有相同的数据长度;
2、数据的最后一列或每个实例的最后一个元素是当前实例的类标签
:return: 返回最好的分类特征
'''
# dataSet数据熵
baseEntropy = calcShannonEnt(dataSet)
# 计算数据特征数目
numFeatures = len(dataSet[0]) - 1
# 定义信息增益
bestInfoGain = 0.0
# 定义最好的特征索引
bestFeature = -1
for i in range(numFeatures):
# # 等价写法
# featList = []
# for example in dataSet:
# featList.append(example[i])
# 提取特征上的所有取值
featList = [example[i] for example in dataSet]
# 集合set使特征取值唯一
uniqueVals = set(featList)
newEntropy = 0.0
# 遍历特征上所有取值,计算其条件熵
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
# 信息增益
InfoGain = baseEntropy - newEntropy
if (InfoGain > bestInfoGain):
bestInfoGain = InfoGain
bestFeature = i
return bestFeature
def majorityCnt(classList):
'''
# 如果数据集处理了所有属性特征,但是仍然无法正确分类,则通过多数表决的方法定该叶子节点的类别
:param classList:
:return:
'''
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet, label):
'''
# 构造决策树
:param dataSet:
:param labels:
:return:
'''
# 相当于复制一个列表,防止后面操作删除列表中内容,影响程序运行
labels =label[:]
classList = [example[-1] for example in dataSet]
# 递归函数第一个停止条件是所有类标签完全相同
if classList.count(classList[0]) == len(dataSet): # count() 方法用于统计某个元素在列表中出现的次数
return classList[0]
# 递归函数第二个停止条件是使用了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组
if len(dataSet[0]) == 1:
return majorityCnt(classList)
# 选择信息增益最大的特征
bestFeature = chooseBestFeatureToSplit(dataSet)
# 信息增益最大特征的特征类别名称
bestFeatureLabel = labels[bestFeature]
# 构建树节点,树采取嵌套字典表示
myTree = {bestFeatureLabel:{}}
# 删除已经使用过的特征
del(labels[bestFeature])
# 得到所取特征的所有属性值
featValues = [example[bestFeature] for example in dataSet]
# 使属性值唯一化
uniqueValues = set(featValues)
for value in uniqueValues:
subLabels = labels[:]
myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet, bestFeature, value), subLabels)
return myTree
def classify(inputTree,featLabels,testVec):
'''
:param inputTree: 决策树的字典模型
:param featLabels: 数据标签列表
:param testVec: 新数据的特征
:return:
'''
# 当前树的根节点特征名称
firstStr = list(inputTree.keys())[0]
# 根节点下的所有子节点
secondDict = inputTree[firstStr]
# index() 函数用于从列表中找出某个值第一个匹配项的索引位置
# 根节点特征对应的索引下标
featIndex = featLabels.index(firstStr)
# 待测试数据集特征值
key = testVec[featIndex]
valueOfFeat = secondDict[key]
# 判断valueOfFeat是字典类型,还是数值;若非字典类型,则说明该节点是叶子节点
if isinstance(valueOfFeat, dict):
classLabel = classify(valueOfFeat, featLabels, testVec)
else:
classLabel = valueOfFeat
return classLabel
# 此代码和上面完成的效果一样,上面更容易读懂
# def classify(inputTree,featLabels,testVec):
# # 当前树的根节点特征名称
# firstStr = list(inputTree.keys())[0]
# # 根节点下的所有子节点
# secondDict = inputTree[firstStr]
# # index() 函数用于从列表中找出某个值第一个匹配项的索引位置
# # 根节点特征对应的索引下标
#
# featIndex = featLabels.index(firstStr)
# for key in secondDict.keys():
# if testVec[featIndex] == key:
# if type(secondDict[key]).__name__ == 'dict':
# classLabel = classify(secondDict[key], featLabels, testVec)
# else:
# classLabel = secondDict[key]
#
# return classLabel
def storeTree(inputTree, filename):
'''
# pickle提供了一个简单的持久化功能。可以将对象以文件的形式存放在磁盘上。
# pickle模块只能在python中使用,python中几乎所有的数据类型(列表,字典,集合,类等)都可以用pickle来序列化,
# pickle序列化后的数据,可读性差,人一般无法识别。
:param inputTree: 决策树的树模型,数据类型为字典
:param filename: 保存的文件及其路径
:return:
'''
# wb 必须加才不会报错
fw = open(filename, 'wb')
pickle.dump(inputTree, fw)
fw.close()
def grabTree(filename):
'''
# 加载 pickle 保存的模型
# 添加 fr.close() 目的是将打开的文件关闭,防止内存溢出
:param filename:
:return:
'''
# rb 必须加才不会报错
fr = open(filename, 'rb')
treeData = pickle.load(fr)
fr.close()
return treeData
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfaces', 'flippers']
return dataSet, labels
def main():
# 使用 with 打开文件,可以在使用完毕后,python自动关闭文件,防止内存溢出
with open('dataset//lenses.txt') as fr:
# 此处为 '\t' 如果没有 '\t' 文件按照空格划分,将会多分出一个类别
lenses = [line.strip().split('\t') for line in fr.readlines()]
# fr = open('dataset//lenses.txt')
# lenses = [line.strip().split('\t') for line in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
print(lensesTree)
if __name__ == '__main__':
main()
希望文章内容可以帮助到你,快来动手敲代码吧!!