《机器学习实战》— 决策树

目录

一、决策树相关概念介绍

1、什么是决策树/判定是?

2、决策树优缺点

3、熵概念

二、决策树归纳算法

1、举个栗子

2、算法

3、树剪枝叶(避免overfitting)

三、代码实现

1、创建数据集

2、计算给定数据集的香农熵

3、按照给定特征划分数据集

4、选择最好的数据集划分方式

5、测试

6、多数表决

7、创建决策树

8、判定数据属于哪个分类

9、测试

10、使用pickle模块存储/读取决策树

四、绘制决策树

1、需引入matplotlib

2、获取叶子节点数目和树的层数

3、定义一些样式

4、决定树的绘制(逻辑绘制)

5、实际绘制树

6、创建数据集

7、测试


一、决策树相关概念介绍

1、什么是决策树/判定是?

     判定树是一个类似于流程图的树结构:其中,每个内部结点表示在一个属性上的测试,每个分支代表一个属性输出,而每个树叶结点代表类或类分布。树的最顶层是根结点。

2、决策树优缺点

决策树的优点:直观,便于理解,小规模数据集有效     

决策树的缺点: 处理连续变量不好; 类别较多时,错误增加的比较快;可规模性一般

3、熵概念

   一条信息的信息量大小和它的不确定性有直接的关系,要搞清楚一件非常非常不确定的事情,或者是我们一无所知的事情,需要了解大量信息==>信息量的度量就等于不确定性的多少

二、决策树归纳算法

          选择属性判断结点

          信息获取量(Information Gain):Gain(A) = Info(D) - Infor_A(D)

          通过A来作为节点分类获取了多少信息

1、举个栗子

 

2、算法

3、树剪枝叶(避免overfitting)

(1)先剪枝:比如比例达到多少后就不考虑分支了

(2)后剪枝:树建立完成后如果太大了再考虑剪枝

三、代码实现

1、创建数据集

#创建数据集
def createDataSet():
    dataSet = [[1,1,'yes'],
               [1,1,'yes'],
               [1,0,'no'],
               [0,1,'no'],
               [0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataSet,labels

2、计算给定数据集的香农熵

#计算给定数据集的香农熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)  #计算数据集中实例总数
    labelCounts = {}  #存储标签
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): #如果当前键不存在,就扩展字典并将当前键值加入字典
            labelCounts[currentLabel] = 0
            labelCounts[currentLabel] += 1  #记录出现的总次数--频率
    shannonEnt = 0.0  #熵值初始化为0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)  #以2为底求对数
    return shannonEnt

3、按照给定特征划分数据集

#按照给定特征划分数据集
#dataSet:待划分的数据集
#axis:划分数据集的特征
#value:特征的返回值
def splitDataSet(dataSet,axis,value):
    #注意:python语言在函数中传递的是列表的引用,在函数内部对列表对象进行修改,将会影响该列表对象的整个生存周期
    #为了消除这个不良的影响,我们需要在函数开始声明一个新列表对象
    retDataSet=[]
    for featVec in dataSet:
        #当我们按照某个特征划分数据集时,就需要将所有符合要求的元素抽取出来
        if featVec[axis]==value: #将符合特征的数据抽取出来
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

4、选择最好的数据集划分方式

#选择最好的数据集划分方式
#该函数实现选取特征,划分数据集,计算得出最好的划分数据集的特征
def chooseBestFeatureToSpit(dataSet):
    numFeatures = len(dataSet[0]) - 1  #表示特征值的列数
    baseEntropy = calcShannonEnt(dataSet)   #计算整个数据集的原始香农熵
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet] #创建唯一的分类标签列表,将数据集中蓑鲉第i个特征值或者所有可能存在的值写入这个新的list
        uniqueVals = set(featList)#集合和列表的区别是可以最快得到列表中唯一元素
        newEntropy = 0.0
        #计算每种划分方式的信息熵
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        #计算最好的信息增益
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    #返回最好特征划分的索引值,从0开始
    return bestFeature

5、测试

import tree
import treePlotter
#创建数据集
dataSet,labels=tree.createDataSet()
print(dataSet)
#熵的计算函数
shannonEnt=tree.calcShannonEnt(dataSet)
print(shannonEnt)
#划分数据集
retDataSet=tree.splitDataSet(dataSet, 0, 1)  #表示选择第一列值为1的
print(retDataSet)
#retDataSet=tree1.splitDataSet(dataSet,0,0)  #表示选择第一列值为0的
#print(retDataSet)

6、多数表决

#多数表决
#如果数据集已经处理了所有属性,但是类标签依然不是唯一的此时我们需要决定如何定义该叶子结点,在这种情况下,我们通常会采用多数表决的方式决定该叶子结点的分类
#classList:标签列表
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote]+=1
    sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]

7、创建决策树

#创建决策树
#dataSet:数据集
#labels:标签列表-包含数据集中所有特征的标签
def createTree(dataSet,labels):
    classList=[example[-1] for example in dataSet] #classList包含数据集的所有类标签
    if classList.count(classList[0]) == len(classList):#递归函数停止的第一个条件:所有类标签完全相同,则直接返回该类标签
        return classList[0]
    if len(dataSet[0]) == 1 :#递归函数停止的第二个条件:使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组
        return majorityCnt(classList)  #使用多数表决方法,挑选出现次数最多的类别作为返回值
    bestFeat=chooseBestFeatureToSpit(dataSet)#选择最好特征划分的索引值
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}#创建树,字典类型存储树信息
    del(labels[bestFeat]) #此处,标签列表要随着子集变化而变化
    featValues=[example[bestFeat] for example in dataSet]#找出决策节点后,继续深入分析特征值
    uniqueVals=set(featValues)#找出最佳分类的特征值的所有可能取值存储
    for value in uniqueVals:
        subLabels=labels[:]#复制类标签,此时已分类好的特征值被删除了
        #此时数据集和类标签都已更新,再递归调用
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree

8、判定数据属于哪个分类

#使用决策树的分类函数
#具体判断数据属于哪个分类
#inputTree:构造好的树--{'flippers': {0: 'no', 1: {'no surfacing': {0: 'no', 1: 'yes'}}}}
#featLabels:分类标签
#testVec:需要测试的数据,看它属于哪个分类
def classify(inputTree,featLabels,testVec):
    firstStr = list(inputTree.keys())[0] #flippers  表示第一个分类的标签
    secondDict = inputTree[firstStr]#{0: 'no', 1: {'no surfacing': {0: 'no', 1: 'yes'}}}  表示第一个分类标签对应的列表
    featIndex = featLabels.index(firstStr)#1 表示第一个分类的标签flippers在featLabels中的索引
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

9、测试

myTree=tree.createTree(dataSet, labels)
print(myTree)
#使用决策树的分类函数
classLabel=tree.classify(myTree,labels,[1,0])
print(classLabel)

10、使用pickle模块存储/读取决策树

#使用pickle模块存储决策树
def storeTree(inputTree,filename):
    import pickle
    fw = open(filename, "wb")
    pickle.dump(inputTree, fw)
    fw.close()

#使用pickle模块读取决策树
def grabTree(filename):
    import pickle
    fr=open(filename,'rb')
    return pickle.load(fr)
myStoreTree=tree.grabTree('classifierStorage.txt')
print(myStoreTree)

四、绘制决策树

1、需引入matplotlib

#绘制树
import matplotlib as mpl
mpl.use('TkAgg')
import matplotlib.pyplot as plt

2、获取叶子节点数目和树的层数

#获取叶节点的数目
#myTree格式:{'flippers': {0: 'no', 1: {'no surfacing': {0: 'no', 1: 'yes'}}}}
def getNumLeafs(myTree):
    numLeafs=0
    #注意:你将结果传递somedict.keys()给函数。在Python 3中,dict.keys不返回一个列表,而是一个表示库键和视图(类似于set)的类集对象,不支持索引。
    #要解决该问题,请使用收集密钥并使用list(somedict.keys())密钥。
    firstStr=list(myTree.keys())[0] #第一个关键字是第一次划分数据集的类标签:flippers
    secondDict=myTree[firstStr]#根据第一个关键字获取下层目录
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#判断子结点是否为字典类型
            numLeafs+=getNumLeafs(secondDict[key])#如果是字典类型,则该节点也是一个判断节点,需递归调用
        else:#如果不是字典类型,则说明是叶子节点
            numLeafs+=1
    return numLeafs

#获取数的层数
def getTreeDepth(myTree):
    maxDepth=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else:
            thisDepth=1
        if thisDepth>maxDepth:
            maxDepth=thisDepth
    return maxDepth

3、定义一些样式

#这个是用来一注释形式绘制节点和箭头线,可以不用管
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

#这个是用来绘制线上的标注,简单
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

4、决定树的绘制(逻辑绘制)

# 重点,递归,决定整个树图的绘制,难(自己认为)
def plotTree(myTree, parentPt, nodeTxt):  # if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  # this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]  # the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)

    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # test to see if the nodes are dictonaires, if not they are leaf nodes
            plotTree(secondDict[key], cntrPt, str(key))  # recursion
        else:  # it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

5、实际绘制树

# 这个是真正的绘制,上边是逻辑的绘制
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False)  # no ticks
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW;
    plotTree.yOff = 1.0;
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

6、创建数据集

#这个是用来创建数据集即决策树
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}},3:'maybe'}}
                  ]
    return listOfTrees[i]

7、测试

import treePlotter
treePlotter.createPlot(treePlotter.retrieveTree(2))

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值