机器学习决策树

本文详细介绍了决策树的原理、构造过程,包括信息增益的选择、数据集划分和使用Matplotlib实现的可视化。作者展示了如何使用Python构建决策树并利用Matplotlib注解绘制树形图。
摘要由CSDN通过智能技术生成

一、决策树的简述

1.1决策树的概念

        决策树是一种机器学习算法和数据挖掘技术,用于解决分类和回归问题。它模拟了人类在面对决策时的思考过程,通过一系列的判断条件和分支来构建一个树形结构。

        在决策树中,每个内部节点表示一个判断条件,用于将输入数据集划分为更小的子集。每个叶子节点表示一个类别(用于分类问题)或一个回归值(用于回归问题)。根据判断条件和输入数据的特征,决策树从根节点开始,沿着树的分支逐步向下遍历,直到达到一个叶子节点,然后将叶子节点所表示的类别或回归值作为预测结果。

1.2决策树的优缺点

        优点:计算复杂度不大,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关的特征数据。

        缺点:可能会产生过度匹配问题。

二、决策树的构造

2.1决策树的一般流程

(1)收集数据:可以以使用任何方法。

(2)准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。

(3)分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。

(4)训练算法:构造树的数据结构。

(5)测试算法:使用经验树计算错误率。

(6)使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在           含义。

2.2信息增益

        信息增益(Information Gain)是决策树算法中用于选择判断条件(特征)的一种度量方法。它衡量了在给定判断条件的情况下,对于数据集的分类带来的不确定性的减少程度。

        信息增益是基于信息论中的熵(Entropy)概念而来。熵可以衡量一个随机变量的不确定性。在决策树中,我们希望通过选择最优的特征来降低数据集的不确定性,从而使得分类更加准确。

计算熵的公式为:

# 计算给定数据集的熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)                          #获得数据集的行数
    labelCounts = {}                                   #用于保存每个标签出现次数的字典
    for featVec in dataSet:
        currentLabel = featVec[-1]                     #提取标签信息
        if currentLabel not in labelCounts.keys():     #如果标签未放入统计次数的字典,则添加进去
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1                 #标签计数
    shannonEnt = 0.0                                   #熵初始化
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries      #选择该标签的概率
        shannonEnt -= prob*log(prob, 2)                #以2为底求对数
    return shannonEnt                                  #返回经验熵

2.3划分数据集

        对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式。

# 按照给定特征划分数据集
def splitDataSet(dataSet,axis,value):                  #dataSet:带划分数据集   axis:划分数据集的特征   value:需要返回的特征的值
    retDataSet=[]                                      #创建新的list对象
    for featVec in dataSet:                            #遍历元素
        if featVec[axis] == value:                     #符合条件的,抽取出来
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1       #特征数量
    baseEntropy = calcShannonEnt(dataSet)   #计数数据集的香农熵
    bestInfoGain = 0.0                      #信息增益
    bestFeature = -1                        #最优特征的索引值
    for i in range(numFeatures):            #遍历数据集的所有特征
        featList = [example[i] for example in dataSet]          #获取dataSet的第i个所有特征
        uniqueVals = set(featList)                              #创建set集合{},元素不可重复
        newEntropy = 0.0                                        #信息熵
        for value in uniqueVals:                                #循环特征的值
            subDataSet = splitDataSet(dataSet, i, value)        #subDataSet划分后的子集
            prob = len(subDataSet) / float(len(dataSet))        #计算子集的概率
            newEntropy += prob * calcShannonEnt((subDataSet))
        infoGain = baseEntropy - newEntropy                     #计算信息增益
        print("第%d个特征的信息增益为%.3f" % (i, infoGain))      #打印每个特征的信息增益
        if (infoGain > bestInfoGain):                           #计算信息增益
            bestInfoGain = infoGain                             #更新信息增益,找到最大的信息增益
            bestFeature = i                                     #记录信息增益最大的特征的索引值
    return bestFeature                                          #返回信息增益最大特征的索引值

2.4递归构建决策树

采用多数表决的方法决定该叶子节点的分类。

# 统计出现次数最多的元素(类标签)
def majorityCnt(classList):
    classCount={}  #统计classList中每个类标签出现的次数
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)    #根据字典的值降序排列
    return sortedClassCount[0][0]  #返回出现次数最多的类标签

 创建决策树。

# 创建决策树
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]           #取分类标签
    if classList.count(classList[0]) == len(classList):        #如果类别完全相同,则停止继续划分
        return classList[0]
    if len(dataSet[0]) == 1:                                   #遍历完所有特征时返回出现次数最多的类标签
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)               #选择最优特征
    bestFeatLabel = labels[bestFeat]                           #最优特征的标签
    myTree = {bestFeatLabel:{}}                                #根据最优特征的标签生成树
    del(labels[bestFeat])                                      #删除已经使用的特征标签
    featValues = [example[bestFeat] for example in dataSet]    #得到训练集中所有最优特征的属性值
    uniqueVals = set(featValues)                               #去掉重复的属性值
    for value in uniqueVals:                                   #遍历特征,创建决策树
        subLabels = labels[:]                                  #复制所有标签,这样树就不会弄乱现有标签
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree

三、在Python中使用Matplotlib注解绘制树形图

3.1Matplotlib注解

使用文本注解绘制树节点。

# 使用文本注解绘制树节点
import matplotlib.pyplot as plt

# 绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt,
                            textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

# 测试生成节点和箭头
def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    font = {'family': 'MicroSoft YaHei'}
    matplotlib.rc("font", **font)
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('决策节点', (0.6, 0.1), (0.1, 0.5), decisionNode)
    plotNode('叶子节点', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

3.2构造注解树

获取叶节点的数目和树的层数。

# 获取叶节点的数目和树的层数
def getNumLeafs(myTree):
    numLeafs = 0  # 初始化
    firstStr = list(myTree.keys())[0]  # 获得第一个key值(根节点)
    secondDict = myTree[firstStr]  # 获得value值
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # 测试节点的数据类型是否为字典
            numLeafs += getNumLeafs(secondDict[key])  # 递归调用
        else:
            numLeafs += 1
    return numLeafs

# 获取树的深度
def getTreeDepth(myTree):
    maxDepth = 0  # 初始化
    firstStr = list(myTree.keys())[0]  # 获得第一个key值(根节点)
    secondDict = myTree[firstStr]  # 获得value值
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # 测试节点的数据类型是否为字典
            thisDepth = 1 + getTreeDepth(secondDict[key])  # 递归调用
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

函数retrieveTree输出预先存储的树信息。

# 决策树存储信息
def retrieveTree(i):
    listOfTrees = [{'买衣服': {
        0: {'应季衣服': {0: 'no', 1: {'面料': {0: {'余额': {'no': 'no', 'yes': 'yes'}}, 2: 'yes'}}}},
        1: {'反季衣服': {0: {'余额': {'no': 'no', 'yes': 'yes'}}, 1: 'yes', 2: 'yes'}}}}]
    return listOfTrees[i]

绘制决策树。

# 在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

# 画树
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # 获取树高
    depth = getTreeDepth(myTree)  # 获取树深度
    firstStr = list(myTree.keys())[0]  # 这个节点的文本标签
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)  # plotTree.totalW, plotTree.yOff全局变量,追踪已经绘制的节点,以及放置下一个节点的恰当位置
    plotMidText(cntrPt, parentPt, nodeTxt)  # 标记子节点属性
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # 减少y偏移
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

# 绘制决策树
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')  # 创建一个新图形
    fig.clf()  # 清空绘图区
    font = {'family': 'MicroSoft YaHei'}
    matplotlib.rc("font", **font)
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW;
    plotTree.yOff = 1.0;
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

测试结果为: 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值