机器学习——CART算法

本文介绍了如何利用CART(ClassificationandRegressionTrees)算法对天气数据进行预测,包括计算Shannon熵、数据集划分、选择最佳特征分割数据集以及创建决策树的过程。
摘要由CSDN通过智能技术生成

第1关:利用CART算法预测天气

from math import log
import operator


def calcShannonEnt(dataSet):
    numEntries = len(dataSet)  # 计算数据集中的实例总数
    labelCounts = {}
    # 统计类别出现的次数
    # 放到一个数组中 key表示标签,val表示个数
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def splitDataSet(dataSet, axis, value):
    """
    输入:数据集,选择维度,选择值
    输出:划分数据集
    描述:按照给定特征划分数据集;去除选择维度中等于选择值的项
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reduceFeatVec = featVec[:axis]
            reduceFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reduceFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    """
    输入:数据集
    输出:最好的划分维度
    描述:选择最好的数据集划分维度
    """
    numFeatures = len(dataSet[0]) - 1  # 特征个数
    bestGini = 999999.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # 统计第i个特征有几种情况
        uniqueVals = set(featList)
        gini = 0.0
        #### 请补充完整代码 ####
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            subProb = len(splitDataSet(subDataSet, -1, 'N')) / float(len(subDataSet))
            gini += prob * (1.0 - pow(subProb, 2) - pow(1 - subProb, 2))
        if gini < bestGini:
            bestGini = gini
            bestFeature = i
        #######################
    return bestFeature


def majorityCnt(classList):
    """
    输入:分类类别列表
    输出:子结点的分类
    描述:数据集已经处理了所有属性,但是类标签依然不是唯一的,
          采用多数判决的方法决定该子结点的分类
    """
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reversed=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """
    输入:数据集,特征标签
    输出:决策树
    描述:递归构建决策树,利用上述的函数
    """
    # 递归停止条件
    classList = [example[-1] for example in dataSet]  # 取所有的类别
    if classList.count(classList[0]) == len(classList):
        # 只有一个类别时,停止划分
        return classList[0]
    if len(dataSet[0]) == 1:
        # 遍历完所有特征时返回出现次数最多的
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # 返回最好的特征
    bestFeatLabel = labels[bestFeat]  # 提取特征名称
    myTree = {bestFeatLabel: {}}  # 特征对应的字典
    del (labels[bestFeat])
    # 得到列表包括结点所有的属性值
    featValues = [example[bestFeat] for example in dataSet]  # 取最优特征的种类
    uniqueVals = set(featValues)
    #### 请补充完整代码 ####
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    #######################
    return myTree


def classify(inputTree, featLabels, testVec):
    """
    输入:决策树,分类标签,测试数据
    输出:决策结果
    描述:跑决策树
    """
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabel = 'N'
    #### 请补充完整代码 ####
    for i in secondDict.keys():
        if testVec[featIndex] == i:
            if type(secondDict[i]).__name__ == 'dict':
                classLabel = classify(secondDict[i], featLabels, testVec)
            else:
                classLabel = secondDict[i]
    #######################
    return classLabel


def classifyAll(inputTree, featLabels, testDataSet):
    """
    输入:决策树,分类标签,测试数据集
    输出:决策结果
    描述:跑决策树
    """
    classLabelAll = []
    for testVec in testDataSet:
        classLabelAll.append(classify(inputTree, featLabels, testVec))
    return classLabelAll


def storeTree(inputTree, filename):
    """
    输入:决策树,保存文件路径
    输出:
    描述:保存决策树到文件
    """

    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()


def grabTree(filename):
    """
    输入:文件路径名
    输出:决策树
    描述:从文件读取决策树
    """

    fr = open(filename, 'rb')
    return pickle.load(fr)


def createDataSet():
    """
    outlook->  0: sunny | 1: overcast | 2: rain
    temperature-> 0: hot | 1: mild | 2: cool
    humidity-> 0: high | 1: normal
    windy-> 0: false | 1: true
    """
    dataSet = [[0, 0, 0, 0, 'N'],
               [0, 0, 0, 1, 'N'],
               [1, 0, 0, 0, 'Y'],
               [2, 1, 0, 0, 'Y'],
               [2, 2, 1, 0, 'Y'],
               [2, 2, 1, 1, 'N'],
               [1, 2, 1, 1, 'Y']]
    labels = ['outlook', 'temperature', 'humidity', 'windy']
    return dataSet, labels


def createTestSet():
    """
    outlook->  0: sunny | 1: overcast | 2: rain
    temperature-> 0: hot | 1: mild | 2: cool
    humidity-> 0: high | 1: normal
    windy-> 0: false | 1: true
    """
    testSet = [[0, 1, 0, 0],
               [0, 2, 1, 0],
               [2, 1, 1, 0],
               [0, 1, 1, 1],
               [1, 1, 0, 1],
               [1, 0, 1, 0],
               [2, 1, 0, 1]]
    return testSet


def main():
    dataSet, labels = createDataSet()
    labels_tmp = labels[:]  # 拷贝,createTree会改变labels
    desicionTree = createTree(dataSet, labels_tmp)
    # storeTree(desicionTree, 'classifierStorage.txt')
    # desicionTree = grabTree('classifierStorage.txt')
    testSet = createTestSet()
    print(classifyAll(desicionTree, labels, testSet))  # 注意这里


if __name__ == '__main__':
    main()
  • 6
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

FindYou.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值