《机器学习实战》读书笔记4:决策树源码分析

本文对《机器学习实战》第三章——决策树的源码进行了全面的分析和解释。由于个人觉得作者的代码变量命名具有一定的迷惑性,使读者容易混淆,所以部分代码可能作了修改。

本文只包含了构建决策树、用决策树分类、序列化决策树的代码。不包括画图的代码


程序清单3-0:创建简单的数据集

这部分是书上的python交互命令创建数据集的代码,写成函数,不用每次都输入一长传命令:

def createDataSet():
    dataSet = [[1, 1, 'yes'], # 创建特征数量为2的数据集
               [1, 1, 'yes'], 
               [1, 0, 'no'], 
               [0, 1, 'no'], 
               [0, 1, 'no']]
    featureNames = ['no surfacing', 'flippers'] # 特征名列表
    return dataSet, featureNames

程序清单3-1:计算香浓熵

关于香浓熵,可以参考我的另外一片文章:《机器学习实战》读书笔记3:信息熵和信息增益

def calcShannonEnt(dataSet): # 参数:数据集
    numEntries = len(dataSet) # 获得数据集中样本的数量
    labelCounts = {} # 用于保存各个分类标签出现的次数
    for featureVec in dataSet: # 遍历每个样本
        currentLabel = featureVec[-1] # 获得样本的分类标签
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1 # 对应标签数量加一
    shannonEnt = 0.0 # 初始化香浓熵为0
    for key in labelCounts: # 遍历每个类别标签
        prob = float(labelCounts[key]) / numEntries # 计算该类别在所有样本中的比例(即出现的概率)
        shannonEnt -= prob * log(prob, 2) # 累加每个类别的信息量,计算平均信息量,即香浓熵
    return shannonEnt

程序清单3-2:按照给定特征划分数据集

def splitDataSet(dataSet, fIndex, fValue): # 参数:待划分数据集合、特征下标、特征值
    retDataSet = [] # 保存划分出的数据子集
    for featureVec in dataSet: # 遍历数据集中的每个样本
        if featureVec[fIndex] == fValue: # 如果特征值符合要求,则添加到子集中
            reducedFeatureVec = featureVec[:fIndex] # 保存第0到第fIndex-1个特征
            reducedFeatureVec.extend(featureVec[fIndex+1:]) # 保存第fIndex+1个到最后一个特征
            retDataSet.append(reducedFeatureVec) # 添加符合要求的样本到划分子集中
    return retDataSet # 返回划分好(特征fIndex的值=fValue)的子集

注意上面的倒数第四行和倒数第三行代码,根据特征 fIndex 划分后,必须删掉样本中的这个特征。因为 ID3 算法在创建决策树的过程中要消耗特征用于创建判断节点。


程序清单3-3:选择最好的数据集划分方式

关于信息增益,可以参考我的另外一篇文章:《机器学习实战》读书笔记3:信息熵和信息增益

def chooseBestFeature(dataSet): # 参数:数据集
    numEntries = len(dataSet) # 数据集样本数量
    numFeatures = len(dataSet[0]) - 1 # 数据集特征个数
    baseEntropy = calcShannonEnt(dataSet) # 计算划分前的熵
    bestInfoGain = 0.0 # 信息增益
    bestFeature = -1 # 用于划分数据集的最佳特征
    for i in range(numFeatures): # 遍历每个特征
        featureValsList = [sample[i] for sample in dataSet] # 获得特征i的所有值
        uniqueVals = set(featureValsList) # 去除特征i中重复的值
        newEntropy = 0.0 # 保存划分后的信息熵
        for value in uniqueVals: # 遍历特征i的每个值
            subDataSet = splitDataSet(dataSet, i, value) # 根据特征i的值value划分数据集
            prob = len(subDataSet) / float(numEntries) # 特征i的值为value的样本所占总样本的比例(概率)
            newEntropy += prob * calcShannonEnt(subDataSet) # 累加计算根据特征i划分后的熵
        infoGain = baseEntropy - newEntropy # 信息增益
        if(infoGain > bestInfoGain): # 选择使信息增益最大的特征作为最佳特征
            bestInfoGain = infoGain # 更新最大信息增益
            bestFeature = i # 更新最佳特征的下标
    return bestFeature # 返回最佳特征的下标

注意: =


程序清单3-3.5:多数表决函数

3-3.5是我自己取的。:)。

def majorityCnt(classList): 参数:数据集对应的类别列表
    classCount = {} # 类别数量统计
    for c in classList: # 遍历类别列表
        classCount[c] = classCount.get(c, 0) + 1 # 计数
    # 排序,sorted默认升序,所以要反转一下顺序
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return classCount[0][0] # 返回出现次数最多的类别

程序清单3-4:创建树的函数代码

def createTree(dataSet, fNames): # 参数:数据集、数据集特征名列表
    classList = [sample[-1] for sample in dataSet] # 类别标签列表
    if classList.count(classList[0]) == len(classList): # 如果只剩同一类,返回该类标签,停止继续划分
        return classList[0]
    if len(dataSet[0]) == 1: # 消耗完所有特征时,返回数据集中出现次数最多的类别标签
        return majorityCnt(classList)
    bestFeature = chooseBestFeature(dataSet) # 获得最佳特征的下标
    bestFeatureName = fNames[bestFeature] # 获得最佳特征的名字
    myTree = {bestFeatureName:{}} # 以最佳特征为根节点创建子树
    del(fNames[bestFeature]) # 在特征名列表中删除最佳特征(创建节点需要消耗特征)
    featureValsList = [sample[bestFeature] for sample in dataSet] # 获得最佳特征的所有可能值
    uniqueVals = set(featureValsList) # 去除重复的值
    for value in uniqueVals: # 遍历每个可能的值
        subFNames = fNames[:] # 深拷贝一份特征名列表
        # 递归创建决策树
        myTree[bestFeatureName][value] = createTree(splitDataSet(dataSet, bestFeature, value), subFNames)
    return myTree # 返回决策树

注意:代码倒数第三行进行深拷贝的原因是 python 列表是按照引用传递的。


程序清单3-8:决策树的分类函数

这个函数最好对应一棵决策树来理解。

def classify(inputTree, fNames, testVec):
    featureName = inputTree.keys()[0] # 获得根节点名(特征名)
    secondDict = inputTree[featureName] # 获得根节点(代表的特征)的所有特征值
    fIndex = fNames.index(featureName) # 获得根节点名(特征名)在数据集中的下标
    for fValue in secondDict.keys(): # 遍历根节点(代表的特征)的所有可能值
        if testVec[fIndex] == fValue: # 如果测试样本对应根节点特征的值等于fValue,进入根节点的值为fValue的分支
            if type(secondDict[fValue]).__name__ == 'dict': # 如果分支节点是一个字典,说明不是一个叶子节点
                # 进入决策树的下一层,参数:分支子树、特征名列表、测试向量
                classLabel = classify(secondDict[fValue], fNames, testVec)
            else: # 否则为叶子节点
                classLabel = secondDict[fValue] # 保存分类标签
            break # 此处增添break,避免不必要的遍历
    return classLabel # 返回测试向量对应的分类标签

程序清单3-9:使用pickle模块序列化存储决策树

def storeTree(inputTree, filename):
    import pickle # 导入pickle模块
    fw = open(filename, 'w') # 按写模式打开文件filename
    pickle.dump(inputTree, fw) # 序列化决策树到文件
    fw.close() # 关闭文件

def grabTree(filename):
    import pickle # 导入pickle模块
    fr = open(filename) # 打开文件filename,默认模式为读
    return pickle.load(fr) # 从文件加载决策树

完整代码

from math import log
import operator

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featureVec in dataSet:
        currentLabel = featureVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def createDataSet():
    dataSet = [[1, 1, 'yes'], 
               [1, 1, 'yes'], 
               [1, 0, 'no'], 
               [0, 1, 'no'], 
               [0, 1, 'no']]
    featureNames = ['no surfacing', 'flippers']
    return dataSet, featureNames


def splitDataSet(dataSet, fIndex, fValue):
    retDataSet = []
    for featureVec in dataSet:
        if featureVec[fIndex] == fValue:
            reducedFeatureVec = featureVec[:fIndex]
            reducedFeatureVec.extend(featureVec[fIndex+1:])
            retDataSet.append(reducedFeatureVec)
    return retDataSet


def chooseBestFeature(dataSet):
    numEntries = len(dataSet)
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featureValsList = [sample[i] for sample in dataSet]
        uniqueVals = set(featureValsList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(numEntries)
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    classCount = {}
    for c in classList:
        classCount[c] = classCount.get(c, 0) + 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return classCount[0][0]


def createTree(dataSet, fNames):
    classList = [sample[-1] for sample in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeature = chooseBestFeature(dataSet)
    bestFeatureName = fNames[bestFeature]
    myTree = {bestFeatureName:{}}
    del(fNames[bestFeature])
    featureValsList = [sample[bestFeature] for sample in dataSet]
    uniqueVals = set(featureValsList)
    for value in uniqueVals:
        subFNames = fNames[:]
        myTree[bestFeatureName][value] = createTree(splitDataSet(dataSet, bestFeature, value), subFNames)
    return myTree


def classify(inputTree, fNames, testVec):
    featureName = inputTree.keys()[0]
    secondDict = inputTree[featureName]
    fIndex = fNames.index(featureName)
    for fValue in secondDict.keys():
        if testVec[fIndex] == fValue:
            if type(secondDict[fValue]).__name__ == 'dict':
                classLabel = classify(secondDict[fValue], fNames, testVec)
            else:
                classLabel = secondDict[fValue]
            break
    return classLabel


def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'w')
    pickle.dump(inputTree, fw)
    fw.close()


def grabTree(filename):
    import pickle
    fr = open(filename)
    return pickle.load(fr)

希望能帮到大家。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值