机器学习实战 -----决策树代码学习笔记(三)

ID3 算法是通过计算信息增益来进行类别的划分。
信息增益g(D,A)=H(D)-H(D|A),熵与经验条件熵的差。D是数据集,A是特征

信息增益的理解:
对于待划分的数据集D,其 entroy(前)是一定的,但是划分之后的熵 entroy(后)是不定的,entroy(后)越小说明使用此特征划分得到的子集的不确定性越小(也就是纯度越高),因此 entroy(前) - entroy(后)差异越大,说明使用当前特征划分数据集D的话,其纯度上升的更快。而我们在构建最优的决策树的时候总希望能更快速到达纯度更高的集合,这一点可以参考优化算法中的梯度下降算法,每一步沿着负梯度方法最小化损失函数的原因就是负梯度方向是函数值减小最快的方向。
同理:在决策树构建的过程中我们总是希望集合往最快到达纯度更高的子集合方向发展,因此我们总是选择使得信息增益最大的特征来划分当前数据集D。
导入所需的包

from math import log
import operator
import plotTree as treeplot    #这一个是用于绘制树的文件,命名为plotTree.py

一、计算给定数据集的香农熵

def calcShangnonEnt(dataSet):
    #计算数据集的长度
    lenData=len(dataSet)
    #定义空的字典,方便以后记性填充
    labelCounts={}
    #遍历数据集,featVec[-1]找到数据集中最后一列的分类结果,主要作用于测试数据集中
    for featVec in dataSet:
        currentLabel=featVec[-1]
        #将分类标签存入字典,加入字典之后将标签的数量+1,其总数用于后边概率的计算
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
 #进行熵的计算,若样本总数是m,表示a类的样本数为l,表示b类的样本数是k
 #则p(a)=l/m,  p(b)=k/m,熵的计算公式为h=-p(x)log(p(x)),一般是以2为底数,此处因为
 #样本数较多,所以进行累加计算(针对离散数据)       
    shangnonEnt=0.0
    for key in labelCounts:
        prob=float(labelCounts[key])/lenData
        shangnonEnt-=prob*log(prob,2)
        pass
    return shangnonEnt

在计算结果中,熵越高,说明混合的数据越多。

二、按照给定特征划分数据集

创建数据函数

def creatData():
    dataSet=[[1,1,'yes'],
          [0,1,'no'],
        [1,1,'yes'],
        [1,0,'no'],
        [0,1,'no']]
    labels=['no surfacing','flippers']
    
    return dataSet,labels

数据集中的各个元素也是列表,遍历每一个元素,发现符合要求的值,就将其添加到新建的列表中,,在if语句中,程序将符合特征的数据抽取出来

#dataSet 待划分的数据集,axis划分数据集的特征如0,value特征的返回值0,该方法主要是进行种类的划分,为后期计算信息增益做准备
def splitDataSet(dataSet,axis,value):
    retDataSet=[]
     #例如 dataSet=[[1,1,'yes'],
     #     [0,1,'no'],
     #  [1,1,'yes'],
    #    [1,0,'no'],
     #   [0,1,'no']];
     #遍历dataSet,例如取到第一行,featVec =【1,1,‘yes’】若axis=0,则featVec[0]=1,
     #若取到dataSet的第二行则,featVec[0]=0.;将对应的值与参数value对比,如果相等,则进行下一步的操作
    for featVec in dataSet: 
        if featVec[axis]==value:
        #取featVec数组中从0开始到axis之间的数,注意labels[:0]=[]
            reducedFeatVec=featVec[:axis]
            #extend()和append(),a=[1,2,3],b=[4,5,6],a.append(b)=[1,2,3,[4,5,6]]
            #a.extend(b)=[1,2,3,4,5,6]
            #python 切片a=[1,2,3,4,5,6]
            #print(a[3:])---->[4, 5, 6]
            #print(a[1:3])---->[2, 3]
			#print(a[:3])---->[1, 2, 3]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
            pass
        pass
    #print(retDataSet)
    return retDataSet

3 选择最好的数据集划分方式

在函数调用中的数据(dataSet)需要满足一定的要求,第一个要求时数据必须是由列表元素组成的列表,而且所有列表元素都要有相同的数据长度;第二数据的最后一列或者每一个实例的最后一个元素是当前实例的类别标签,数据集一旦满足上述要求,就可以在函数第一行判定当前数据集包含多少特征属性,而无需限定list的数据模型,既可以是数字也可以是字符串,都不会影响计算。

#计算信息熵,返回信息增益最大的类别对应的标签
def chooseBestFeature(dataSet):
    numFeature=len(dataSet[0])-1
    #print(numFeature)#numFeature=2
    baseEntropy=calcShangnonEnt(dataSet)
    bestInfoGain=0.0
    bestFeature=-1
    #计算各个特征的信息熵
    for i in range(numFeature):
        
        #创建唯一的分类标签列表
        featList=[example[i] for example in dataSet]
        #print(featList)
        #featList是i=0  [1, 0, 1, 1, 0],i=1    [1, 1, 1, 0, 1]
        uniqueVals=set(featList) #去掉重复值  uniqueVals={0,1},{0,1}
        #print(uniqueVals)
        newEntropy=0.0
        #计算每一种划分方式的信息熵
        for value in uniqueVals:
            subDataSet=splitDataSet(dataSet,i,value)
            #i=0时value=0,subDataSet=[[1, 'no'], [1, 'no']]这种形式
            #1=0,value=1,subDataSet=[[1, 'yes'], [1, 'yes'], [0, 'no']]
            #i=1,value=0,[[1, 'no']]
            #i=1,value=1,[[1, 'yes'], [0, 'no'], [1, 'yes'], [0, 'no']]
            prob=len(subDataSet)/float(len(dataSet))#i=0,prob=0.4,0.6,  i=1 prob=0.2,0.8
            #calcShangnonEnt(subDataSet)分别计算subDataSet=[[1, 'no'], [1, 'no']],subDataSet=[[1, 'yes'], [1, 'yes'], [0, 'no']]等
            #时的信息熵
            newEntropy+=prob*calcShangnonEnt(subDataSet)
            #print(newEntropy)
            pass
        infoGain=baseEntropy-newEntropy
        #print(infoGain)
        #找出最好的信息增益
        if(infoGain>bestInfoGain):
            bestInfoGain=infoGain
            bestFeature=i
            pass
        pass
    return bestFeature

4、返回出现次数最多的分类名称

#类标签不是唯一的,为定义叶子节点,采用多数表决的方式进行叶子节点的分类
#classList 类似['yes', 'no', 'yes', 'no', 'no']的形式
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote]+=1    
        pass
    #     a = [1,36,9]
    # b = [[1,2,3],[4,5,6],[7,8,9]]
    # get_21 = operator.itemgetter(2,1)
    # print(get_21(a))-->(9, 36)
    # print(get_21(b))-->([7, 8, 9], [4, 5, 6])
    #classCount---->{'yes': 2, 'no': 3, 'ha': 1},operator.itemgetter(1)按数字大小排序
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    #sortedClassCount=[('no', 3), ('yes', 2), ('ha', 1)]
    return sortedClassCount[0][0]

5、递归创建决策树

输入两个参数:数据集合标签列表,标签列表包含了数据集中所有特征的标签。classList包含了所有特征的标签。递归停止的第一个条件是所有类别标签完全相同,直接返回该标签。
停止的第二个条件是:用完了所有的特征,仍然不能将数据集划分成包含唯一类别的分组。

def createTree(dataSet,lables):
    classList=[example[-1] for example in dataSet]
#     print(classList,"aa")
    #只有一个类别,则停止继续划分,返回类标签。classList=['yes', 'no', 'yes', 'no', 'no','ha'],
    #classList.count(classList[0])是yes,计算yes的数量
    if classList.count(classList[0])==len(classList):
        return classList[0]
  
    #遍历完所有特征,返回出现次数最多的
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    #,使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组,则选用
    #出现次数最多的的类别作为返回值,bestFeat中将得到数据集中选择的最好的特征。
    bestFeat=chooseBestFeature(dataSet) 
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    del(labels[bestFeat])
#     print("+++++")
#     print(labels)
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    #print(uniqueVals)
    #最 后 代 码 遍 历 当 前 选 择 特 征 包 含 的 所 有 属 性 值 ,在 每 个 数 据 集 划 分 上 递 归 调 用 函 数
#createTree ( ) ,得到的返回值将被插人到字典变量0 ^ ^ 6 中,因此函数终止执行时,宇典中将
#会嵌套很多代表叶子节点信息的字典数据。
    for value in uniqueVals:
        subLabels=labels[:]
        #print(subLabels)
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
        pass
    return myTree

6、进行预测

#进行预测
def classify(inputTree,featLabels,testVec):
    #得到输入数据第一个key值
    firstStr =list(inputTree.keys())[0]
    #得到key值对应的value,如果是字典可以进一步划分
    secondDict=inputTree[firstStr]
    #firstStr是key值,也就是划分的特征值,如['no surfacing','flippers'],index(firstStr)得到对应的位置如
    #no surfacing的索引值为0,flippers的索引值是1
    featIndex=featLabels.index(firstStr)
    # featLabels=['no surfacing','flippers']遍历字典,如果输入的testVec
    #secondDict===={0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}
    #secondDict.keys==[0,1]
    for key in secondDict.keys():
        if testVec[featIndex]==key:
            if type(secondDict[key]).__name__=="dict":
                classLabel=classify(secondDict[key],featLabels,testVec)
                pass
            else:
                classLabel=secondDict[key]
                pass
            pass
        pass
    return classLabel

7、运行

dataSet,labels=creatData()
mytree=treeplot.retrieveTree(0)
treeplot.createPlot(mytree)#画图的文件,如果不进行绘图,则可以不要
classify(mytree,labels,[1,0])

8、结果

在这里插入图片描述
9、附------画图函数代码

#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import matplotlib.pyplot as plt
decisionNode =dict(boxstyle='sawtooth',fc='0.8')
arrow_args=dict(arrowstyle="<-")
leafNode=dict(boxstyle="round4",fc="0.8")
#以文本坐标(-2,-2)
#ha="center"  在水平方向上,方框的中心在为(-2,0)
#va="center"  在垂直方向上,方框的中心在为(0,-2)
# bbox={}  代表对方框的设置
#         { 
#             boxstyle= '' 代表边框的类型
#                     round 圆形方框
#                     rarrow箭头
#             fc  背景颜色   英文首字母 w -whiite r-red
#             ec 边框线的透明度  数字或颜色的首字母
#             alpha 字体的透明度
#             lw 线的粗细
#             rotation  角度

# xy=(横坐标,纵坐标)  箭头尖端
#     xytext=(横坐标,纵坐标) 文字的坐标,指的是最左边的坐标
#     arrowprops= {
#         facecolor= '颜色',
#         shrink = '数字' <1  收缩箭头
#     }


 #得到叶子节点的数目   
def getNumLeafs(myTree):
    numLeafs=0
    #python3中放弃了 dict.keys.index的用法,而是进行强制类型转换,可以将其变为list,然后获取字典的第一个key值
    firstStr =list(myTree.keys())[0]
    #根绝key值,找到key对应的value
#   myTree---->  {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
#   myTree.keys()---->  dict_keys(['no surfacing'])
#  list(myTree.keys())[0]----->   no surfacing
#  myTree[firstStr]------->{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    secondDict=myTree[firstStr]
    print(secondDict)
    print(secondDict.keys())
    #递归调用本身,如果第一个key对应的value中还有字典则继续进行遍历
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            #当一个key对用的value是字典的时候就进行递归调用,不进行节点的计算,因为字典还有子节点
            #并不是叶子节点,叶子节点没有后续的分支
            numLeafs+=getNumLeafs(secondDict[key])
        else:
            numLeafs+=1
            pass
    return numLeafs
    pass
#得到树的深度
def getTreeDepth(myTree):
    maxDepth=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
            pass
        else:
            thisDepth=1
            if thisDepth>maxDepth:
                maxDepth=thisDepth
                pass
            pass
        pass
    return maxDepth
# string:图形内容的注释文本, xy:被注释图形内容的位置坐标,xytext:注释文本的位置坐标
#weight:注释文本的字体粗细风格,color:注释文本的字体颜色,arrowprops:指示被注释内容的箭头的属性字典
#https://blog.csdn.net/qq_30638831/article/details/79938967

def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',
    va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)
    pass
def createPlot():
    #figsize=(16,6) 窗口大小
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    #subplot(111)一行一列一个图
     # createPlot.ax1为全局变量,绘制图像句柄,理解为一行一列的第一个图
    #frameon=False窗口无框(所绘图形对象大小等于最终图片对象的大小)
    createPlot.ax1=plt.subplot(111,frameon=False)
    plotNode("a decision node",(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode("a leaf node",(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

def retrieveTree(i):
    listOfTrees=[{"no surfacing":{0:'no',1:{"flippers":{0:"no",1:"yes"}}}},
                {'no surfacing':{0:"no",1:{"flippers":{0:{'head':{0:"no",1:"yes"}},1:"no"}}}}]
    return listOfTrees[i]

#显示文字函数
def plotMidText(cntrPt,parentPt,txtString):
    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
    #text(xMid,yMid,txtString),在指定位置显示文字
    createPlot.ax1.text(xMid,yMid,txtString)
    pass
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=list(myTree.keys())[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=="dict":
            plotTree(secondDict[key],cntrPt,str(key))
            pass
        else:
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
            pass
        pass
    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
    pass
def createPlot(inTree):
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    axprops=dict(xticks=[],yticks=[])# 定义横纵坐标轴,无内容  
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)# 绘制图像,无边框,无坐标轴  
    #树宽
    plotTree.totalW=float(getNumLeafs(inTree)) #全局变量宽度 = 叶子数
    #树深
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()
    


    
    
myTree=retrieveTree(1)
createPlot(myTree)
getNumLeafs(myTree)


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值