机器学习实战(python3.7)-决策树

如果觉得本篇文章对您的学习起到帮助作用,请 点赞 + 关注 + 评论 ,留下您的足迹💪💪💪

本篇文章为我对机器学习实战-决策树的理解与我在学习时所做笔记,一是为了日后查找方便并加深对代码的理解,二是希望能帮助到使用这本书遇到困难的人。

代码可在python3.7跑通
因此代码相对原书做了一些修改,增加了可读性,同时也解决了一些问题。

代码及详细注释如下:

from math import log
import operator
import pickle

def calcShannonEnt(dataSet):
    '''

    :param dataSet:
    :return:
    '''

    numDatas = len(dataSet)
    labelsCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelsCounts.keys():
            labelsCounts[currentLabel] = 0
        # 书中这一行缩进错误
        labelsCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelsCounts:
        prob = float(labelsCounts[key]) / numDatas
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt



def splitDataSet(dataSet, axis, value):
    '''
    # 按照某个特征划分数据集时,就是把这个特征的全部元素提取出来
    :param dataSet: 带划分数据集
    :param axis: 带划分数据集的特征
    :param value: 特征的返回值
    :return:
    '''
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

def chooseBestFeatureToSplit(dataSet):
    '''
    # 该函数实现选取特征,划分数据集,并得到最好的特征,即熵最小的特征
    :param dataSet: 1、数据集必须由列表元素组成的列表,列表元素具有相同的数据长度;
                    2、数据的最后一列或每个实例的最后一个元素是当前实例的类标签
    :return: 返回最好的分类特征
    '''
    # dataSet数据熵
    baseEntropy = calcShannonEnt(dataSet)
    # 计算数据特征数目
    numFeatures = len(dataSet[0]) - 1
    # 定义信息增益
    bestInfoGain = 0.0
    # 定义最好的特征索引
    bestFeature = -1

    for i in range(numFeatures):
        # # 等价写法
        # featList = []
        # for example in dataSet:
        #     featList.append(example[i])

        # 提取特征上的所有取值
        featList = [example[i] for example in dataSet]
        # 集合set使特征取值唯一
        uniqueVals = set(featList)
        newEntropy = 0.0

        # 遍历特征上所有取值,计算其条件熵
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)

        # 信息增益
        InfoGain = baseEntropy - newEntropy

        if (InfoGain > bestInfoGain):
            bestInfoGain = InfoGain
            bestFeature = i

    return  bestFeature


def majorityCnt(classList):
    '''
    # 如果数据集处理了所有属性特征,但是仍然无法正确分类,则通过多数表决的方法定该叶子节点的类别
    :param classList:
    :return:
    '''
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, label):
    '''
    # 构造决策树
    :param dataSet:
    :param labels:
    :return:
    '''
    # 相当于复制一个列表,防止后面操作删除列表中内容,影响程序运行
    labels =label[:]
    classList = [example[-1] for example in dataSet]
    # 递归函数第一个停止条件是所有类标签完全相同
    if classList.count(classList[0]) == len(dataSet): # count() 方法用于统计某个元素在列表中出现的次数
        return classList[0]
    # 递归函数第二个停止条件是使用了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # 选择信息增益最大的特征
    bestFeature = chooseBestFeatureToSplit(dataSet)
    # 信息增益最大特征的特征类别名称
    bestFeatureLabel = labels[bestFeature]
    # 构建树节点,树采取嵌套字典表示
    myTree = {bestFeatureLabel:{}}
    # 删除已经使用过的特征
    del(labels[bestFeature])
    # 得到所取特征的所有属性值
    featValues = [example[bestFeature] for example in dataSet]
    # 使属性值唯一化
    uniqueValues = set(featValues)
    for value in uniqueValues:
        subLabels = labels[:]
        myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet, bestFeature, value), subLabels)
    return myTree



def classify(inputTree,featLabels,testVec):
    '''
    :param inputTree: 决策树的字典模型
    :param featLabels: 数据标签列表
    :param testVec: 新数据的特征
    :return:
    '''
    # 当前树的根节点特征名称
    firstStr = list(inputTree.keys())[0]
    # 根节点下的所有子节点
    secondDict = inputTree[firstStr]
    # index() 函数用于从列表中找出某个值第一个匹配项的索引位置
    # 根节点特征对应的索引下标
    featIndex = featLabels.index(firstStr)
    # 待测试数据集特征值
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    # 判断valueOfFeat是字典类型,还是数值;若非字典类型,则说明该节点是叶子节点
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:
        classLabel = valueOfFeat
    return classLabel

# 此代码和上面完成的效果一样,上面更容易读懂
# def classify(inputTree,featLabels,testVec):
#     # 当前树的根节点特征名称
#     firstStr = list(inputTree.keys())[0]
#     # 根节点下的所有子节点
#     secondDict = inputTree[firstStr]
#     # index() 函数用于从列表中找出某个值第一个匹配项的索引位置
#     # 根节点特征对应的索引下标
#
#     featIndex = featLabels.index(firstStr)
#     for key in secondDict.keys():
#         if testVec[featIndex] == key:
#             if type(secondDict[key]).__name__ == 'dict':
#                 classLabel = classify(secondDict[key], featLabels, testVec)
#             else:
#                 classLabel = secondDict[key]
#
#     return classLabel

def storeTree(inputTree, filename):
    '''
    # pickle提供了一个简单的持久化功能。可以将对象以文件的形式存放在磁盘上。
    # pickle模块只能在python中使用,python中几乎所有的数据类型(列表,字典,集合,类等)都可以用pickle来序列化,
    # pickle序列化后的数据,可读性差,人一般无法识别。
    :param inputTree: 决策树的树模型,数据类型为字典
    :param filename: 保存的文件及其路径
    :return:
    '''
    # wb 必须加才不会报错
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    '''
    # 加载 pickle 保存的模型
    # 添加 fr.close() 目的是将打开的文件关闭,防止内存溢出
    :param filename:
    :return:
    '''
    # rb 必须加才不会报错
    fr = open(filename, 'rb')
    treeData = pickle.load(fr)
    fr.close()
    return treeData



def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfaces', 'flippers']
    return dataSet, labels

def main():
    # 使用 with 打开文件,可以在使用完毕后,python自动关闭文件,防止内存溢出
    with open('dataset//lenses.txt') as fr:
        # 此处为 '\t' 如果没有 '\t' 文件按照空格划分,将会多分出一个类别
        lenses = [line.strip().split('\t') for line in fr.readlines()]
    # fr = open('dataset//lenses.txt')
    # lenses = [line.strip().split('\t') for line in fr.readlines()]

    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = createTree(lenses, lensesLabels)
    print(lensesTree)



if __name__ == '__main__':
    main()

希望文章内容可以帮助到你,快来动手敲代码吧!!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值