Machine Learning (2): Decision Trees

1. Decision Trees

2. Advantages and Disadvantages of Decision Trees

Advantages: low computational cost, output that is easy for humans to interpret, insensitivity to missing intermediate values, and the ability to handle data with irrelevant features.
Disadvantages: prone to overfitting.
Applicable data types: numeric and nominal.
3. Formulas
  1. Shannon entropy: E(S) = -Σᵢ pᵢ · log₂(pᵢ), where pᵢ is the proportion of samples belonging to class i

  2. Information gain: Gain(S, A) = E(S) − E(S, A) (the overall entropy minus the weighted entropy of the subsets produced by splitting on attribute A)

  3. Gain ratio: GainRatio(S, A) = Gain(S, A) / SplitInfo(S, A), where SplitInfo(S, A) = -Σᵥ (|Sᵥ|/|S|) · log₂(|Sᵥ|/|S|)

  4. Gini index: Gini(S) = 1 − Σᵢ pᵢ²
    (Formulas 3 and 4 are not used in the code below; the sketch after this list shows worked examples of formulas 1 and 4.)
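
As a quick illustration, here is a minimal standalone sketch (separate from the main code in section 5; the helper names entropy and gini and the example labels are mine, not the article's) that evaluates formulas 1 and 4 on a small list of class labels:

from collections import Counter
from math import log

def entropy(labels):
    # E(S) = -sum(p_i * log2(p_i)) over the class proportions p_i
    n = len(labels)
    return -sum((c / n) * log(c / n, 2) for c in Counter(labels).values())

def gini(labels):
    # Gini(S) = 1 - sum(p_i ** 2) over the class proportions p_i
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

# Two 'yes' and three 'no' labels, matching the toy dataset used later
print(entropy(['yes', 'yes', 'no', 'no', 'no']))  # ~0.971
print(gini(['yes', 'yes', 'no', 'no', 'no']))     # 0.48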

4. Construction Process
  1. Compute the Shannon entropy of the given dataset.
  2. Split the dataset.
  3. Choose the best split of the dataset.
  4. Recursively build the decision tree.
  5. Classify with the decision tree.
  6. Persist the decision tree.
5. Code
from math import log
import operator
import pickle


def createDataSet():
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no']
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def calcShannonEnt(dataSet):
    """
    Compute the Shannon entropy of a dataset.
    :param dataSet: list of feature vectors, each ending with a class label
    :return: the entropy as a float
    """
    numEntries = len(dataSet)
    # Count the occurrences of each class label
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1

    # Accumulate the entropy: E(S) = -sum(p_i * log2(p_i))
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
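
# Quick sanity check (interactive session, assuming the toy dataset above):
#   >>> myDat, labels = createDataSet()
#   >>> calcShannonEnt(myDat)
#   0.9709505944546686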


def splitDataSet(dataSet, axis, value):
    """
    划分数据集,去除当前 axis位置的 value 值的数据集
    :param dataSet:
    :param axis:
    :param value:
    :return:
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reduceFeatVec = featVec[:axis]
            reduceFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reduceFeatVec)
    return retDataSet
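
# Example: keep only the rows whose first feature equals 1, dropping that
# column (values shown for the toy dataset above):
#   >>> splitDataSet(myDat, 0, 1)
#   [[1, 'yes'], [1, 'yes'], [0, 'no']]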


def chooseBestFeatureToSplit(dataSet):
    """
    Find the feature with the highest information gain.
    :param dataSet: list of feature vectors, each ending with a class label
    :return: index of the best feature
    """
    # Number of features (the last column is the class label)
    numFeatures = len(dataSet[0]) - 1
    # Entropy of the whole dataset
    baseEntropy = calcShannonEnt(dataSet)
    # Best information gain seen so far
    bestInfoGain = 0.0
    # Index of the best feature
    bestFeature = -1

    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]

        uniqueVals = set(featList)
        # Weighted (conditional) entropy after splitting on feature i
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        # Keep the feature with the highest information gain
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i

    return bestFeature
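
# For the toy dataset, feature 0 ('no surfacing') yields the larger
# information gain (about 0.42 vs. 0.17 for 'flippers'):
#   >>> chooseBestFeatureToSplit(myDat)
#   0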


def majorityCnt(classList):
    """
    Return the class label that occurs most often in classList
    (used when the features are exhausted but the labels still disagree).
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)

    return sortedClassCount[0][0]
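
# Example: 'no' wins the majority vote:
#   >>> majorityCnt(['yes', 'no', 'no'])
#   'no'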


def createTree(dataSet, labels):
    """
    创建决策树
    :param dataSet:
    :param labels:
    :return:
    """
    classList = [example[-1] for example in dataSet]
    # Stop splitting when all samples share the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    # All features consumed: return the majority class label
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # Pick the feature with the highest information gain
    bestFeat = chooseBestFeatureToSplit(dataSet)

    bestFeatLabel = labels[bestFeat]

    myTree = {bestFeatLabel: {}}

    # Remove the used feature name (note: this mutates the caller's list)
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)

    for value in uniqueVals:
        # Recurse on each subset; copy labels so sibling branches are unaffected
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = \
            createTree(splitDataSet(dataSet, bestFeat, value), subLabels)

    return myTree
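
# For the toy dataset the recursion produces this nested dict
# (pass a copy of labels, since createTree consumes the list):
#   >>> createTree(myDat, labels[:])
#   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}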


def classify(inputTree, featlabels, testVec):
    """
    分类测试
    :param inputTree:
    :param featlabels:
    :param testVec:
    :return:
    """
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featlabels.index(firstStr)

    classLabel = None
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                # Internal node: keep descending
                classLabel = classify(secondDict[key], featlabels, testVec)
            else:
                # Leaf node: the value is the class label
                classLabel = secondDict[key]

    return classLabel


def storeTree(inputTree, filename):
    """
    Serialize the decision tree to disk.
    :param inputTree: nested dict produced by createTree
    :param filename: destination file path
    :return: None
    """
    # pickle writes bytes, so the file must be opened in binary mode
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)


def grabTree(filename):
    """
    Load a pickled decision tree from disk.
    :param filename: path of a file written by storeTree
    :return: the decision tree
    """
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
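
# Round-trip example (the filename is illustrative only):
#   >>> storeTree(myTree, 'classifierStorage.txt')
#   >>> grabTree('classifierStorage.txt')
#   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}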


if __name__ == '__main__':
    myData, labels = createDataSet()
    # Keep a copy: createTree consumes entries of the labels list
    label = labels[:]
    print(labels)
    myTree = createTree(myData, labels)
    print(myTree)
    print(classify(myTree, label, [1, 0]))
    print(label)

Summary:

There are many decision tree algorithms, for example CLS, ID3, C4.5, SLIQ, and CART.
The code in this article uses information gain as its splitting criterion, so it implements the ID3 algorithm (C4.5 instead splits on the gain ratio from formula 3 above).
