小白机器学习-2 决策树学习

小白机器学习-2 决策树学习-代码解析

2.1 决策树介绍

决策树类似小时候玩的猜谜游戏,给个范围之后,一人猜,一人回答对与错,进而使得范围越来越小,最后猜中结果
在上述游戏中,每给出一回答后都会缩小范围,在实际处理数据集时,数据可能会有许多中特征,这些特征不同程度的影响着数据的类别。例如《机器学习实战》判断生物是否为鱼类的例子:

序号不浮出水面是否可以生存是否有脚蹼属于鱼类
1
2
3
4
5

在上表中,生物有两个特征:不浮出水面是否可以生存,是否有脚蹼。为了构建决策树,我们需要知道这两个特征哪一个对是否为鱼类影响力更加大一些,因此涉及到一个描述信息重要程度的概念:信息熵

2.2 信息熵

A. 其中P(x)是事件x发生的概率,H(x)就是信息熵,也就是事件所包含的信息量
关于信息熵公式大佬的详细解释:信息熵
B. 在上表判断生物是否为鱼类的例子中,每个特征值划分数据的能力不同,也就是信息增益不同。什么是信息增益,以及如何计算:信息增益
C. 为了更好的划分数据集,我们需要优先选择信息增益大的特征值来划分。举个栗子:在划分学生信息的时候,往往是:学校—年级—班级—性别。学校的划分能力是较之于其他的特征是最强的,所以学校排老大。
于是在划分数据集之前,我们首先得选出,信息增益最优的特征值,因此我们构建决策树的步骤分为:

  1. 计算数据原始的信息熵(在计算不同特征值划分的信息增益时需要使用)
  2. 计算根据不同特征值 划分数据所得的信息熵,并选择信息增益最大的作为划分方式
  3. 构建决策树

2.3 计算信息熵

"""计算给定数据集的香农熵 """
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)     # 记录数据集中的实例个数 dataSet 链表形式的数据集
    labelCounts = {}              # 创建空字典,key存储类别信息,值为该类别对应的是数量个数
    for featVec in dataSet:
        currentLabel = featVec[-1]  # 取出最后一列,例如:(1,1,"yes")实战中判断是否是鱼的例子
        if currentLabel not in labelCounts.keys():  # 如果发现类别不在字典labelCounts的键中
            labelCounts[currentLabel] = 0           # 将该类别存储于字典中,且初始化为0
        labelCounts[currentLabel] += 1              # 该类别的数量加1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries     # prob 为该类别出现的概率
        shannonEnt -= prob * log(prob,2)                # shannonEnt 利用信息熵公式计算
    return shannonEnt

数据集dataSet,labelCounts存储类别以及对应的个数,遍历数据集,利用currentLabel来临时存储,在labelCounts中查找,做不在已有的labelCounts则添加新类别,初始化数量为0
for循环中,prob计算该类别出现的概率,函数返回shannonEnt为类别所对应的信息熵

2.4划分数据集

"""按照给定的特征划分数据集"""
def splitDataSet(dataSet, axis, value):    # dataSet指定的数据集,划分数据采集的特征,特征的返回值
    retDataSet = []                        # 创建新列表用于存储
    for featVec in dataSet:                # 遍历数据集
        if featVec[axis] == value:         # 若特征值与划分依据的特征值相同
            reducedFeatVec = featVec[:axis] # 截取特征值之前的元素
            reducedFeatVec.extend(featVec[axis + 1:]) # 添加特征值之后的元素
            retDataSet.append(reducedFeatVec)         # retDataSet接收划分之后的数据集
    return retDataSet

2.5选择最优划分数据方式

"""循环计算香农熵和splitDataSet函数 ----找到最好的划分方式"""
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1                     # 求特征值的数量即第一行有多少列,因为最后一行是结果值,所以-1
    baseEntropy = calcShannonEnt(dataSet)                 # 保存一开始数据集的信息熵
    bestInfoGain = 0.0; bestFeature = -1                  # 初始化最优信息增益值,最优的特征值
    # print("初始信息熵:%f" % baseEntropy)
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]    # 获取第i个所有特征值,即第i列所有元素 featList = [1,1,1,0,0],[1,1,0,1,1]
        uniqueVals = set(featList)                        # 利用set(),{0,1}
        newEntropy = 0.0
        for value in uniqueVals:                          # 遍历uniqueVals中的所有取值,{0,1}
            subDataSet = splitDataSet(dataSet, i, value)  # 对于第i个特征值,不同取值下的数据划分:(dataSet,0,0) --> [1,'no'],[1,'no']
            prob = len(subDataSet) / float(len(dataSet))  # 计算在这种特征值取值下的数据占比
            newEntropy += prob * calcShannonEnt(subDataSet) # 计算该特征值对应的信息熵
        # print("第%d" % i + "个特征值得信息熵是:%f" % newEntropy)
        infoGain = baseEntropy - newEntropy                 # 计算信息增益
        if(infoGain > bestInfoGain):                        #若当前信息增益大于最优增益,则更新最优增益,且记录最优增益特征值的位置
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature                                      #返回最优增益的位置

2.6 运用递归来构建决策树

"""创建树的代码"""
def createTree(dataSet,labels):  # 参数(数据集,标签)
    classList = [example[-1] for example in dataSet]       # 取数据集中最后一列,(1,1,'yes')classlist记录类别,classlist=[y,y,n,n,n]
    if classList.count(classList[0])==len(classList):      # list.count(obj)统计list中obj的个数,若classlist中全是相同元素
        return classList[0]
    if len(dataSet[0]) == 1:                               # 若数据集的特征值个数为1
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)           # 调用函数得到最优增益位置 例bestFeat = 0
    bestFeatLabel = labels[bestFeat]                       # 得到划分标签 例beatFeatLabel=labels[0]=no surfacing
    myTree = {bestFeatLabel:{}}                            # 利用最佳标签来构建myTree字典 myTree = {no surfacing:{}}
    del (labels[bestFeat])                                 # 删除标签中的最佳标签 labels=['flippers']
    featValues = [example[bestFeat] for example in dataSet]# featValues = [1,1,1,0,0]
    uniqueVals = set(featValues)                           # 去重 featValues = [0,1]
    for value in uniqueVals:
        subLabels = labels[:]                              # subLabels接收剩余标签 subLabels = ['flipper']
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
        """myTree['no surfacing'][0] = createTree(splitDataSet(dataSet,0,0),'flipper')
              splitDataSet(dataSet,0,0) = {[1,'no'],[1,'no']}
           myTree['no surfacing'][1] = createTree(splitDataSet(dataSet,0,1),'flipper')
              splitDataSet(dataSet,0,0) = {[1,'yes'],[1,'yes'],[0,'no']}
        """
    return myTree

2.7 决策树使用:分类

"""决策树分类函数"""
def classify(inputTree,featLabels,testVec):
    """
    参数:决策树,类别,测试数据
    决策树:{'no surfacing':{0:'no', 1:{'flipper':{0:'no', 1:'yes'}
                                                                   }
                                                                    }
                                                                     }
    类别:Labels = {'no surfacing','flipper'},测试数据 [1,0]
    类似树的遍历
    """
    firstStr = list(inputTree.keys())[0]           # 去除决策树字典中第一个字符 'no surfacing'
    secondDict = inputTree[firstStr]                         # secondDict = {0:'no', 1:{'flipper':{0:'no', 1:'yes'}}}
    featIndex = featLabels.index(firstStr)         # featIndex = 0
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key])==dict:        #若secondDict[key][类型仍是字典则继续递归判断
                classLabel = classify(secondDict[key],featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

2.8 绘制决策树代码

将构造的决策树,图形化

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")


# 使用文本注释绘制树节点
def plotNode(nodeTxt, centerPt, parentPt,nodeType):
    createPlot.axl.annotate(nodeTxt, xy=parentPt,xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[],yticks=[])
    createPlot.axl = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5,1.0),'')
    # plotNode('a decision node',(0.5,0.1), (0.1, 0.5), decisionNode)
    # plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

# 获取叶节点的数目和层数
def getNumLeafs(myTree):
    numLeafs = 0
    # a = list(myTree.keys())
    # print(a)
    keys = list(myTree.keys())
    firstStr = keys[0]                                 # 取第一个结点元素
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key])==dict:
            numLeafs += getNumLeafs((secondDict[key]))
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key])==dict:
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def retrieveTree(i):
    listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},{'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
    return listOfTrees[i]

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key])==dict:
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

2.9 存储决策树以及测试集

"""存储决策树"""
def storeTree(Tree,filename):
    import pickle
    with open(filename,'wb') as f:
        pickle.dump(Tree,f)
def grabTree(filename):
    import pickle
    with open(filename,'rb') as f:
        return pickle.load(f)

"""测试集"""
print("开始测试")
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]         # 注意这里是readlines(),开始忘记加s,导致程序错误
lensesLabels = ['age','prescript','astigmatic','tearRate']
print(lenses)
print(lensesLabels)
lensesTree = createTree(lenses,lensesLabels)
print(lensesTree)
treePlotter.createPlot(lensesTree)

总结

在学习决策树的过程中,理解信息熵,信息增益很重要,是最优划分数据集的基础。在构建决策树时,利用递归,字典嵌套来构建,亲自逐步语句运行可加深对构造决策树的理解。
在学习过程中,犯了一个低级错误,在读取文件时手误使用了readline(),而应该使用readlines(),这个错误花了很久的事件才解决。
在数据可视化方面比较欠缺,因此没有写代码注释。。。等习得之后补上

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值