机器学习算法——决策树算法(ID3算法划分数据集,基于香农熵的python底层实现)

决策树算法是一种非参数的决策算法,它根据数据的不同特征进行多层次的分类和判断,最终决策出所需要预测的结果。它既可以解决分类算法,也可以解决回归问题,具有很好的解释能力。

部分图片源自网络,侵删
在这里插入图片描述
决策树就如上图所示,决策树算法能够读取数据集合,构建类似于上图的决策树。
决策树的一个重要任务是为了厘清数据中所蕴含的知识信息,因此决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,在这些机器根据数据集创建规则时,就是机器学习的学习过程。
传统的专家系统中经常使用决策树,而且决策树给出的结果往往可以匹敌在当前领域具有几十年工作经验的人类专家。

1、决策树

  • 优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据
  • 缺点:可能会产生过度匹配问题
  • 适用数据类型:数值型和标称型

在构建决策树时,我们需要解决的第一个问题就是,当前数据集上哪个特征在划分数据分类时起决定性作用。为了找到决定性的特征,划分出最好的结果,我们必须评估每个特征。完成测试之后,原始数据就被分类为几个数据子集。这些子集会分布在第一个决策点的所有分支上。如果某个分支下的数据属于同一种类型,则当前无需考虑的数据已经正确地划分数据分类,无需进一步对数据集进行分割。如果数据子集内的数据不属于同一类,则需要重新划分数据子集的过程。划分数据子集的算法和划分原始数据集的方法相同,直到所有具有相同类型的数据均在一个数据子集内。
创建分支的伪代码如下所示:

def createBranch():
	if 检测数据集中的每个子项属于同一分类:
		return 类标签
	else:
		寻找划分数据集的最好特征
		划分数据集
		创建分支节点
			for 每个划分的子集
				调用createBranch()并返回结果到分支节点中
		return 分支节点

2、决策树的一般流程

  1. 收集数据:可以使用任何合适的方法
  2. 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化
  3. 分析数据:可以使用任何合适方法,构造树完成之后,我们应该检查图形是否符合预期
  4. 训练数据:构造树的数据结构
  5. 测试数据:使用经验树计算错误率
  6. 使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好理解数据的内在含义

3、信息增益

划分数据集的最大原则是将无序的数据变得有序。我们可以使用多种方法划分数据集,此处我们选择熵来衡量数据的无序程度。
组织杂乱无章数据的一种方法就是使用信息论度量信息,信息论是量化处理信息的分支学科,我们可以在划分数据前后使用信息论量化度量信息的内容。
在划分数据集前后信息发生的变化称为信息增益,获得信息增益最高的特征就是最好的选择。
集合信息的度量方式称为香农熵或者简称为熵,这个名字来源于信息论之父克劳德.香农。
熵是用来度量混乱程度的。比如下图,单词“Entropy”(熵的英文单词)可见的时候,熵最小,级有序程度最高,最有秩序。当被打乱的时候,由于熵增原理,熵开始增加,直到完全混乱:
在这里插入图片描述

熵在信息论中代表随机变量不确定度的度量

熵定义为信息的期望值,先来看看什么是信息。
如果待分类的事物可能划分在多个分类中,则符号Xi的信息定义为
L(Xi) = -log2p(Xi)

为了计算信息熵,我们需要计算所有类别可能值包含的信息期望值(即中学学的数学期望),下面是公式:
p(Xi)是选择该分类的概率,后面可以结合代码理解。

在这里插入图片描述

from math import log
def create_data_set():
    data_set = [[1, 1, 'yes'],
                [1, 1, 'yes'],
                [1, 0, 'no'],
                [0, 1, 'no'],
                [0, 1, 'no']]
    # dataSet最后一列指该生物是否属于鱼类
    # no surfacing指不浮出水面是否可以生存
    # flippers指是否有脚蹼
    labels = ['no surfacing', 'flippers']
    # change to discrete values
    return data_set, labels


def compute_shannon_etropy(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:  # the the number of unique elements and their occurance
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannon_entropy = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannon_entropy -= prob * log(prob, 2)  # log base 2
    return shannon_entropy

4、划分数据集

分类算法除了需要测量信息熵,还需要划分数据集,度量划分数据集的熵,以便判断当前是否正确地分类了数据集。我们需要对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式。

划分数据集的函数

def split_data_set(dataset, axis, value):
    retDataSet = []
    for featVec in dataset:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # chop out axis used for splitting
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

选择最好的数据集划分方式

def choose_best_feature_to_split(dataset):
    numFeatures = len(dataset[0]) - 1  # 数据集最后一列作为标签
    baseEntropy = compute_shannon_etropy(dataset)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):  # 遍历所有的特征
        featList = [example[i] for example in dataset]  # create a list of all the examples of this feature
        uniqueVals = set(featList) # 利用集合函数得到不重复的值
        newEntropy = 0.0
        for value in uniqueVals:
            sub_data_set = split_data_set(dataset, i, value)
            prob = len(sub_data_set) / float(len(dataset))
            newEntropy += prob * compute_shannon_etropy(sub_data_set)
        infoGain = baseEntropy - newEntropy  # 计算信息增益
        if (infoGain > bestInfoGain):  # 比较目前最好的信息增益
            bestInfoGain = infoGain  # 更新目前最好的信息增益
            bestFeature = i
    return bestFeature  # 返回在当前子集中可用于划分的最好的特征序号

5、递归构造决策树

根据前面所述的伪代码,我们知道,递归结束的条件是:程序遍历完成所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分支。如果所有实例具有相同的分类,则得到一个叶子节点或者或者终止块。任何到达叶子节点的数据必须属于叶子节点的分类。

伪代码:
def createBranch():
	if 检测数据集中的每个子项属于同一分类:
		return 类标签
	else:
		寻找划分数据集的最好特征
		划分数据集
		创建分支节点
			for 每个划分的子集
				调用createBranch()并返回结果到分支节点中
		return 分支节点

算法开始前会计算所有可用特征的数目,如果数据集已经处理了所有的属性,但是类标签依然不是唯一的,此时我们需要决定如何定义该叶子节点,在这种情况下,我们通常会采用多数表决的方法决定该叶子节点的分类。
表决操作代码

import operator
# 用于操作键值排序字典
def majority_vote_class(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

创建树的函数代码
递归函数的第一个停止条件是所有的类标签完全相同,则直接返回该标签。递归函数的第二个停止条件是使用完了所有特征,仍不能将数据集划分为仅包含唯一类别的分组。由于第二个条件无法简单地返回唯一的类标签,这里使用程序投票函数挑选出次数最多的类别作为返回值。

def create_tree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1:  # stop splitting when there are no more features in dataSet
        return majority_vote_class(classList)
    bestFeat = choose_best_feature_to_split(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del (labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy all of labels, so trees don't mess up existing labels
        myTree[bestFeatLabel][value] = create_tree(split_data_set(dataSet, bestFeat, value), subLabels)
    return myTree

此处使用的是python的字典类型存储树的信息。字典变量myTree存储了树的所有信息,这对于其后绘制图形树很重要。

简单测试一下

myDat, labels = create_data_set()
myTree = create_tree(myDat, labels)
print(myTree)

"""
结果如下:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
"""

6、利用matplotlib注解功能绘制树形图

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def get_num_leaves(myTree):
    numLeafs = 0
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += get_num_leaves(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def get_tree_depth(myTree):
    maxDepth = 0
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + get_tree_depth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth


def plot_nodes(nodeTxt, centerPt, parentPt, nodeType):
    create_plots.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                              xytext=centerPt, textcoords='axes fraction',
                              va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def plot_middle_text(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    create_plots.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plot_tree(myTree, parentPt, nodeTxt):
    numLeafs = get_num_leaves(myTree)
    depth = get_tree_depth(myTree)
    firstStr = list(myTree)[0]
    cntrPt = (plot_tree.xOff + (1.0 + float(numLeafs)) / 2.0 / plot_tree.totalW, plot_tree.yOff)
    plot_middle_text(cntrPt, parentPt, nodeTxt)
    plot_nodes(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plot_tree.yOff = plot_tree.yOff - 1.0 / plot_tree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plot_tree(secondDict[key], cntrPt, str(key))
        else:
            plot_tree.xOff = plot_tree.xOff + 1.0 / plot_tree.totalW
            plot_nodes(secondDict[key], (plot_tree.xOff, plot_tree.yOff), cntrPt, leafNode)
            plot_middle_text((plot_tree.xOff, plot_tree.yOff), cntrPt, str(key))
    plot_tree.yOff = plot_tree.yOff + 1.0 / plot_tree.totalD


def create_plots(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    create_plots.ax1 = plt.subplot(111, frameon=False, **axprops)
    plot_tree.totalW = float(get_num_leaves(inTree))
    plot_tree.totalD = float(get_tree_depth(inTree))
    plot_tree.xOff = -0.5 / plot_tree.totalW;
    plot_tree.yOff = 1.0
    plot_tree(inTree, (0.5, 1.0), '')
    plt.show()


def return_tree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]

7、绘制完整的决策树

我们利用从UCI上下载的佩戴隐形眼镜的条件。
将下面的文本复制,保存为glasses.txt即可

young	myope	no	reduced	no lenses
young	myope	no	normal	soft
young	myope	yes	reduced	no lenses
young	myope	yes	normal	hard
young	hyper	no	reduced	no lenses
young	hyper	no	normal	soft
young	hyper	yes	reduced	no lenses
young	hyper	yes	normal	hard
pre	myope	no	reduced	no lenses
pre	myope	no	normal	soft
pre	myope	yes	reduced	no lenses
pre	myope	yes	normal	hard
pre	hyper	no	reduced	no lenses
pre	hyper	no	normal	soft
pre	hyper	yes	reduced	no lenses
pre	hyper	yes	normal	no lenses
presbyopic	myope	no	reduced	no lenses
presbyopic	myope	no	normal	no lenses
presbyopic	myope	yes	reduced	no lenses
presbyopic	myope	yes	normal	hard
presbyopic	hyper	no	reduced	no lenses
presbyopic	hyper	no	normal	soft
presbyopic	hyper	yes	reduced	no lenses
presbyopic	hyper	yes	normal	no lenses

测试代码:

import getTree
import plotTree

fr = open("glasses.txt")
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = getTree.create_tree(lenses, lensesLabels)
plotTree.create_plots(lensesTree)

绘制出如下图所示的决策树:
在这里插入图片描述

其中getTree.py完整代码如下:

from math import log2
import operator


def create_data_set():
    data_set = [[1, 1, 'yes'],
                [1, 1, 'yes'],
                [1, 0, 'no'],
                [0, 1, 'no'],
                [0, 1, 'no']]
    # dataSet最后一列指该生物是否属于鱼类
    # no surfacing指不浮出水面是否可以生存
    # flippers指是否有脚蹼
    labels = ['no surfacing', 'flippers']
    # change to discrete values
    return data_set, labels


def compute_shannon_etropy(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:  # the the number of unique elements and their occurance
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannon_entropy = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannon_entropy -= prob * log(prob, 2)  # log base 2
    return shannon_entropy


def split_data_set(dataset, axis, value):
    retDataSet = []
    for featVec in dataset:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # chop out axis used for splitting
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def choose_best_feature_to_split(dataset):
    numFeatures = len(dataset[0]) - 1  # 数据集最后一列作为标签
    baseEntropy = compute_shannon_etropy(dataset)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):  # 遍历所有的特征
        featList = [example[i] for example in dataset]  # create a list of all the examples of this feature
        uniqueVals = set(featList) # 利用集合函数得到不重复的值
        newEntropy = 0.0
        for value in uniqueVals:
            sub_data_set = split_data_set(dataset, i, value)
            prob = len(sub_data_set) / float(len(dataset))
            newEntropy += prob * compute_shannon_etropy(sub_data_set)
        infoGain = baseEntropy - newEntropy  # 计算信息增益
        if (infoGain > bestInfoGain):  # 比较目前最好的信息增益
            bestInfoGain = infoGain  # 更新目前最好的信息增益
            bestFeature = i
    return bestFeature  # 返回在当前子集中可用于划分的最好的特征序号


def majority_vote_class(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def create_tree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1:  # stop splitting when there are no more features in dataSet
        return majority_vote_class(classList)
    bestFeat = choose_best_feature_to_split(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del (labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy all of labels, so trees don't mess up existing labels
        myTree[bestFeatLabel][value] = create_tree(split_data_set(dataSet, bestFeat, value), subLabels)
    return myTree


def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree)[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:
        classLabel = valueOfFeat
    return classLabel


def stored_tree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()


def get_tree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)

myDat, labels = create_data_set()
myTree = create_tree(myDat, labels)
print(myTree)

plotTree.py完整代码如下:

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def get_num_leaves(myTree):
    numLeafs = 0
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += get_num_leaves(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def get_tree_depth(myTree):
    maxDepth = 0
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + get_tree_depth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth


def plot_nodes(nodeTxt, centerPt, parentPt, nodeType):
    create_plots.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                              xytext=centerPt, textcoords='axes fraction',
                              va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def plot_middle_text(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    create_plots.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plot_tree(myTree, parentPt, nodeTxt):
    numLeafs = get_num_leaves(myTree)
    depth = get_tree_depth(myTree)
    firstStr = list(myTree)[0]
    cntrPt = (plot_tree.xOff + (1.0 + float(numLeafs)) / 2.0 / plot_tree.totalW, plot_tree.yOff)
    plot_middle_text(cntrPt, parentPt, nodeTxt)
    plot_nodes(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plot_tree.yOff = plot_tree.yOff - 1.0 / plot_tree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plot_tree(secondDict[key], cntrPt, str(key))
        else:
            plot_tree.xOff = plot_tree.xOff + 1.0 / plot_tree.totalW
            plot_nodes(secondDict[key], (plot_tree.xOff, plot_tree.yOff), cntrPt, leafNode)
            plot_middle_text((plot_tree.xOff, plot_tree.yOff), cntrPt, str(key))
    plot_tree.yOff = plot_tree.yOff + 1.0 / plot_tree.totalD


def create_plots(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    create_plots.ax1 = plt.subplot(111, frameon=False, **axprops)
    plot_tree.totalW = float(get_num_leaves(inTree))
    plot_tree.totalD = float(get_tree_depth(inTree))
    plot_tree.xOff = -0.5 / plot_tree.totalW;
    plot_tree.yOff = 1.0
    plot_tree(inTree, (0.5, 1.0), '')
    plt.show()


def return_tree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]

小结

至此,我们就利用ID3算法,实现了决策树,完成了对佩戴隐形眼镜的要求的机器学习应用。
ID3算法运行很快且便于理解,但是对于数值型数据将有些束手无策,尽管我们可以通过数值型数据转化为标称型数据,但如果存在太多的特征划分,ID3算法依然会面临大量的问题。
除了ID3算法外,还有C4.5 和 CART算法,后面博主还会继续补充相关的实现方法。
后续还会更新其他的机器学习算法,学习永无止境。

  • 5
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值