ID3 算法案例

1. 决策树的基本认识

 

   决策树是一种依托决策而建立起来的一种树。在机器学习中,决策树是一种预测模型,代表的是一种对

   象属性与对象值之间的一种映射关系,每一个节点代表某个对象,树中的每一个分叉路径代表某个可能

   的属性值,而每一个叶子节点则对应从根节点到该叶子节点所经历的路径所表示的对象的值。决策树仅

   有单一输出,如果有多个输出,可以分别建立独立的决策树以处理不同的输出。接下来讲解ID3算法。

 

 

2. ID3算法介绍

 

   ID3算法是决策树的一种,它是基于奥卡姆剃刀原理的,即用尽量用较少的东西做更多的事。ID3算法

   即Iterative Dichotomiser 3迭代二叉树3代,是Ross Quinlan发明的一种决策树算法,这个

   算法的基础就是上面提到的奥卡姆剃刀原理,越是小型的决策树越优于大的决策树,尽管如此,也不总

   是生成最小的树型结构,而是一个启发式算法。

 

   在信息论中,期望信息越小,那么信息增益就越大,从而纯度就越高。ID3算法的核心思想就是以信息

   增益来度量属性的选择,选择分裂后信息增益最大的属性进行分裂。该算法采用自顶向下的贪婪搜索遍

   历可能的决策空间。

 

 

3. 信息熵与信息增益

 

   在信息增益中,重要性的衡量标准就是看特征能够为分类系统带来多少信息,带来的信息越多,该特征越

   重要。在认识信息增益之前,先来看看信息熵的定义

 

   这个概念最早起源于物理学,在物理学中是用来度量一个热力学系统的无序程度,而在信息学里面,熵

   是对不确定性的度量。在1948年,香农引入了信息熵,将其定义为离散随机事件出现的概率,一个系统越

   是有序,信息熵就越低,反之一个系统越是混乱,它的信息熵就越高。所以信息熵可以被认为是系统有序

   化程度的一个度量。

 

   假如一个随机变量的取值为,每一种取到的概率分别是,那么

   的熵定义为

 

             

 

   意思是一个变量的变化情况可能越多,那么它携带的信息量就越大。

 

   对于分类系统来说,类别是变量,它的取值是,而每一个类别出现的概率分别是

 

             

 

   而这里的就是类别的总数,此时分类系统的熵就可以表示为

 

             

 

   以上就是信息熵的定义,接下来介绍信息增益

 

   信息增益是针对一个一个特征而言的,就是看一个特征,系统有它和没有它时的信息量各是多少,两者

   的差值就是这个特征给系统带来的信息量,即信息增益

案例:

本文使用的Python库包括

  • numpy
  • pandas
  • math
  • operator
  • matplotlib

本文所用的数据如下:
                  Idx色泽根蒂敲声纹理脐部触感               密度           含糖率              label
1青绿蜷缩浊响清晰凹陷硬滑0.6970.461
2乌黑蜷缩沉闷清晰凹陷硬滑0.7740.3761
3乌黑蜷缩浊响清晰凹陷硬滑0.6340.2641
4青绿蜷缩沉闷清晰凹陷硬滑0.6080.3181
5浅白蜷缩浊响清晰凹陷硬滑0.5560.2151
6青绿稍蜷浊响清晰稍凹软粘0.4030.2371
7乌黑稍蜷浊响稍糊稍凹软粘0.4810.1491
8乌黑稍蜷浊响清晰稍凹硬滑0.4370.2111
9乌黑稍蜷沉闷稍糊稍凹硬滑0.6660.0910
10青绿硬挺清脆清晰平坦软粘0.2430.2670
11浅白硬挺清脆模糊平坦硬滑0.2450.0570
12浅白蜷缩浊响模糊平坦软粘0.3430.0990
13青绿稍蜷浊响稍糊凹陷硬滑0.6390.1610
14浅白稍蜷沉闷稍糊凹陷硬滑0.6570.1980
15乌黑稍蜷浊响清晰稍凹软粘0.360.370
16浅白蜷缩浊响模糊平坦硬滑0.5930.0420
17青绿蜷缩沉闷稍糊稍凹硬滑0.7190.1030

由于我没搞定matplotlib的中文输出,因此将中文字符全换成了英文,如下:
Idxcolorrootknockstexturenaveltouchdensitysugar_ratiolabel
1dark_greencurl_uplittle_heavilydistinctsinkinghard_smooth0.6970.461
2blackcurl_upheavilydistinctsinkinghard_smooth0.7740.3761
3blackcurl_uplittle_heavilydistinctsinkinghard_smooth0.6340.2641
4dark_greencurl_upheavilydistinctsinkinghard_smooth0.6080.3181
5light_whitecurl_uplittle_heavilydistinctsinkinghard_smooth0.5560.2151
6dark_greenlittle_curl_uplittle_heavilydistinctlittle_sinkingsoft_stick0.4030.2371
7blacklittle_curl_uplittle_heavilylittle_blurlittle_sinkingsoft_stick0.4810.1491
8blacklittle_curl_uplittle_heavilydistinctlittle_sinkinghard_smooth0.4370.2111
9blacklittle_curl_upheavilylittle_blurlittle_sinkinghard_smooth0.6660.0910
10dark_greenstiffcleardistinctevensoft_stick0.2430.2670
11light_whitestiffclearblurevenhard_smooth0.2450.0570
12light_whitecurl_uplittle_heavilyblurevensoft_stick0.3430.0990
13dark_greenlittle_curl_uplittle_heavilylittle_blursinkinghard_smooth0.6390.1610
14light_whitelittle_curl_upheavilylittle_blursinkinghard_smooth0.6570.1980
15blacklittle_curl_uplittle_heavilydistinctlittle_sinkingsoft_stick0.360.370
16light_whitecurl_uplittle_heavilyblurevenhard_smooth0.5930.0420
17dark_greencurl_upheavilylittle_blurlittle_sinkinghard_smooth0.7190.1030

字符的含义可自行对照上下两表

决策树生成的代码参照 机器学习实战 第三章的代码,但是书上第三章是针对离散特征的,下面程序中对其进行了修改,使其能用于同时包含离散与连续特征的数据集。

决策树生成代码如下:

  1. # -*- coding: utf-8 -*-  
  2.   
  3. from numpy import *  
  4. import numpy as np  
  5. import pandas as pd  
  6. from math import log  
  7. import operator  
  8.   
  9.   
  10.   
  11. #计算数据集的香农熵  
  12. def calcShannonEnt(dataSet):  
  13.     numEntries=len(dataSet)  
  14.     labelCounts={}  
  15.     #给所有可能分类创建字典  
  16.     for featVec in dataSet:  
  17.         currentLabel=featVec[-1]  
  18.         if currentLabel not in labelCounts.keys():  
  19.             labelCounts[currentLabel]=0  
  20.         labelCounts[currentLabel]+=1  
  21.     shannonEnt=0.0  
  22.     #以2为底数计算香农熵  
  23.     for key in labelCounts:  
  24.         prob = float(labelCounts[key])/numEntries  
  25.         shannonEnt-=prob*log(prob,2)  
  26.     return shannonEnt  
  27.   
  28.   
  29. #对离散变量划分数据集,取出该特征取值为value的所有样本  
  30. def splitDataSet(dataSet,axis,value):  
  31.     retDataSet=[]  
  32.     for featVec in dataSet:  
  33.         if featVec[axis]==value:  
  34.             reducedFeatVec=featVec[:axis]  
  35.             reducedFeatVec.extend(featVec[axis+1:])  
  36.             retDataSet.append(reducedFeatVec)  
  37.     return retDataSet  
  38.   
  39. #对连续变量划分数据集,direction规定划分的方向,  
  40. #决定是划分出小于value的数据样本还是大于value的数据样本集  
  41. def splitContinuousDataSet(dataSet,axis,value,direction):  
  42.     retDataSet=[]  
  43.     for featVec in dataSet:  
  44.         if direction==0:  
  45.             if featVec[axis]>value:  
  46.                 reducedFeatVec=featVec[:axis]  
  47.                 reducedFeatVec.extend(featVec[axis+1:])  
  48.                 retDataSet.append(reducedFeatVec)  
  49.         else:  
  50.             if featVec[axis]<=value:  
  51.                 reducedFeatVec=featVec[:axis]  
  52.                 reducedFeatVec.extend(featVec[axis+1:])  
  53.                 retDataSet.append(reducedFeatVec)  
  54.     return retDataSet  
  55.   
  56. #选择最好的数据集划分方式  
  57. def chooseBestFeatureToSplit(dataSet,labels):  
  58.     numFeatures=len(dataSet[0])-1  
  59.     baseEntropy=calcShannonEnt(dataSet)  
  60.     bestInfoGain=0.0  
  61.     bestFeature=-1  
  62.     bestSplitDict={}  
  63.     for i in range(numFeatures):  
  64.         featList=[example[i] for example in dataSet]  
  65.         #对连续型特征进行处理  
  66.         if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':  
  67.             #产生n-1个候选划分点  
  68.             sortfeatList=sorted(featList)  
  69.             splitList=[]  
  70.             for j in range(len(sortfeatList)-1):  
  71.                 splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0)  
  72.               
  73.             bestSplitEntropy=10000  
  74.             slen=len(splitList)  
  75.             #求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点  
  76.             for j in range(slen):  
  77.                 value=splitList[j]  
  78.                 newEntropy=0.0  
  79.                 subDataSet0=splitContinuousDataSet(dataSet,i,value,0)  
  80.                 subDataSet1=splitContinuousDataSet(dataSet,i,value,1)  
  81.                 prob0=len(subDataSet0)/float(len(dataSet))  
  82.                 newEntropy+=prob0*calcShannonEnt(subDataSet0)  
  83.                 prob1=len(subDataSet1)/float(len(dataSet))  
  84.                 newEntropy+=prob1*calcShannonEnt(subDataSet1)  
  85.                 if newEntropy<bestSplitEntropy:  
  86.                     bestSplitEntropy=newEntropy  
  87.                     bestSplit=j  
  88.             #用字典记录当前特征的最佳划分点  
  89.             bestSplitDict[labels[i]]=splitList[bestSplit]  
  90.             infoGain=baseEntropy-bestSplitEntropy  
  91.         #对离散型特征进行处理  
  92.         else:  
  93.             uniqueVals=set(featList)  
  94.             newEntropy=0.0  
  95.             #计算该特征下每种划分的信息熵  
  96.             for value in uniqueVals:  
  97.                 subDataSet=splitDataSet(dataSet,i,value)  
  98.                 prob=len(subDataSet)/float(len(dataSet))  
  99.                 newEntropy+=prob*calcShannonEnt(subDataSet)  
  100.             infoGain=baseEntropy-newEntropy  
  101.         if infoGain>bestInfoGain:  
  102.             bestInfoGain=infoGain  
  103.             bestFeature=i  
  104.     #若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理  
  105.     #即是否小于等于bestSplitValue  
  106.     if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':        
  107.         bestSplitValue=bestSplitDict[labels[bestFeature]]          
  108.         labels[bestFeature]=labels[bestFeature]+'<='+str(bestSplitValue)  
  109.         for i in range(shape(dataSet)[0]):  
  110.             if dataSet[i][bestFeature]<=bestSplitValue:  
  111.                 dataSet[i][bestFeature]=1  
  112.             else:  
  113.                 dataSet[i][bestFeature]=0  
  114.     return bestFeature  
  115.   
  116. #特征若已经划分完,节点下的样本还没有统一取值,则需要进行投票  
  117. def majorityCnt(classList):  
  118.     classCount={}  
  119.     for vote in classList:  
  120.         if vote not in classCount.keys():  
  121.             classCount[vote]=0  
  122.         classCount[vote]+=1  
  123.     return max(classCount)  
  124.   
  125. #主程序,递归产生决策树  
  126. def createTree(dataSet,labels,data_full,labels_full):  
  127.     classList=[example[-1for example in dataSet]  
  128.     if classList.count(classList[0])==len(classList):  
  129.         return classList[0]  
  130.     if len(dataSet[0])==1:  
  131.         return majorityCnt(classList)  
  132.     bestFeat=chooseBestFeatureToSplit(dataSet,labels)  
  133.     bestFeatLabel=labels[bestFeat]  
  134.     myTree={bestFeatLabel:{}}  
  135.     featValues=[example[bestFeat] for example in dataSet]  
  136.     uniqueVals=set(featValues)  
  137.     if type(dataSet[0][bestFeat]).__name__=='str':  
  138.         currentlabel=labels_full.index(labels[bestFeat])  
  139.         featValuesFull=[example[currentlabel] for example in data_full]  
  140.         uniqueValsFull=set(featValuesFull)  
  141.     del(labels[bestFeat])  
  142.     #针对bestFeat的每个取值,划分出一个子树。  
  143.     for value in uniqueVals:  
  144.         subLabels=labels[:]  
  145.         if type(dataSet[0][bestFeat]).__name__=='str':  
  146.             uniqueValsFull.remove(value)  
  147.         myTree[bestFeatLabel][value]=createTree(splitDataSet\  
  148.          (dataSet,bestFeat,value),subLabels,data_full,labels_full)  
  149.     if type(dataSet[0][bestFeat]).__name__=='str':  
  150.         for value in uniqueValsFull:  
  151.             myTree[bestFeatLabel][value]=majorityCnt(classList)  
  152.     return myTree  
# -*- coding: utf-8 -*-

from numpy import *
import numpy as np
import pandas as pd
from math import log
import operator



#计算数据集的香农熵
def calcShannonEnt(dataSet):
    numEntries=len(dataSet)
    labelCounts={}
    #给所有可能分类创建字典
    for featVec in dataSet:
        currentLabel=featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    shannonEnt=0.0
    #以2为底数计算香农熵
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt-=prob*log(prob,2)
    return shannonEnt


#对离散变量划分数据集,取出该特征取值为value的所有样本
def splitDataSet(dataSet,axis,value):
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis]==value:
            reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

#对连续变量划分数据集,direction规定划分的方向,
#决定是划分出小于value的数据样本还是大于value的数据样本集
def splitContinuousDataSet(dataSet,axis,value,direction):
    retDataSet=[]
    for featVec in dataSet:
        if direction==0:
            if featVec[axis]>value:
                reducedFeatVec=featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
        else:
            if featVec[axis]<=value:
                reducedFeatVec=featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
    return retDataSet

#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet,labels):
    numFeatures=len(dataSet[0])-1
    baseEntropy=calcShannonEnt(dataSet)
    bestInfoGain=0.0
    bestFeature=-1
    bestSplitDict={}
    for i in range(numFeatures):
        featList=[example[i] for example in dataSet]
        #对连续型特征进行处理
        if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':
            #产生n-1个候选划分点
            sortfeatList=sorted(featList)
            splitList=[]
            for j in range(len(sortfeatList)-1):
                splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0)
            
            bestSplitEntropy=10000
            slen=len(splitList)
            #求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点
            for j in range(slen):
                value=splitList[j]
                newEntropy=0.0
                subDataSet0=splitContinuousDataSet(dataSet,i,value,0)
                subDataSet1=splitContinuousDataSet(dataSet,i,value,1)
                prob0=len(subDataSet0)/float(len(dataSet))
                newEntropy+=prob0*calcShannonEnt(subDataSet0)
                prob1=len(subDataSet1)/float(len(dataSet))
                newEntropy+=prob1*calcShannonEnt(subDataSet1)
                if newEntropy<bestSplitEntropy:
                    bestSplitEntropy=newEntropy
                    bestSplit=j
            #用字典记录当前特征的最佳划分点
            bestSplitDict[labels[i]]=splitList[bestSplit]
            infoGain=baseEntropy-bestSplitEntropy
        #对离散型特征进行处理
        else:
            uniqueVals=set(featList)
            newEntropy=0.0
            #计算该特征下每种划分的信息熵
            for value in uniqueVals:
                subDataSet=splitDataSet(dataSet,i,value)
                prob=len(subDataSet)/float(len(dataSet))
                newEntropy+=prob*calcShannonEnt(subDataSet)
            infoGain=baseEntropy-newEntropy
        if infoGain>bestInfoGain:
            bestInfoGain=infoGain
            bestFeature=i
    #若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理
    #即是否小于等于bestSplitValue
    if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':      
        bestSplitValue=bestSplitDict[labels[bestFeature]]        
        labels[bestFeature]=labels[bestFeature]+'<='+str(bestSplitValue)
        for i in range(shape(dataSet)[0]):
            if dataSet[i][bestFeature]<=bestSplitValue:
                dataSet[i][bestFeature]=1
            else:
                dataSet[i][bestFeature]=0
    return bestFeature

#特征若已经划分完,节点下的样本还没有统一取值,则需要进行投票
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote]+=1
    return max(classCount)

#主程序,递归产生决策树
def createTree(dataSet,labels,data_full,labels_full):
    classList=[example[-1] for example in dataSet]
    if classList.count(classList[0])==len(classList):
        return classList[0]
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    bestFeat=chooseBestFeatureToSplit(dataSet,labels)
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    if type(dataSet[0][bestFeat]).__name__=='str':
        currentlabel=labels_full.index(labels[bestFeat])
        featValuesFull=[example[currentlabel] for example in data_full]
        uniqueValsFull=set(featValuesFull)
    del(labels[bestFeat])
    #针对bestFeat的每个取值,划分出一个子树。
    for value in uniqueVals:
        subLabels=labels[:]
        if type(dataSet[0][bestFeat]).__name__=='str':
            uniqueValsFull.remove(value)
        myTree[bestFeatLabel][value]=createTree(splitDataSet\
         (dataSet,bestFeat,value),subLabels,data_full,labels_full)
    if type(dataSet[0][bestFeat]).__name__=='str':
        for value in uniqueValsFull:
            myTree[bestFeatLabel][value]=majorityCnt(classList)
    return myTree

通过以下语句进行调用:

  1. df=pd.read_csv('watermelon_4_3.csv')  
  2. data=df.values[:,1:].tolist()  
  3. data_full=data[:]  
  4. labels=df.columns.values[1:-1].tolist()  
  5. labels_full=labels[:]  
  6. myTree=createTree(data,labels,data_full,labels_full)  
df=pd.read_csv('watermelon_4_3.csv')
data=df.values[:,1:].tolist()
data_full=data[:]
labels=df.columns.values[1:-1].tolist()
labels_full=labels[:]
myTree=createTree(data,labels,data_full,labels_full)

可以得到以下结果
>>> myTree
{'texture': {'distinct': {'density<=0.3815': {0: 1L, 1: 0L}}, 'little_blur': {'touch': {'hard_smooth': 0L, 'soft_stick': 1L}}, 'blur': 0L}}

以下为画图代码:
  1. import matplotlib.pyplot as plt  
  2. decisionNode=dict(boxstyle="sawtooth",fc="0.8")  
  3. leafNode=dict(boxstyle="round4",fc="0.8")  
  4. arrow_args=dict(arrowstyle="<-")  
  5.   
  6.   
  7. #计算树的叶子节点数量  
  8. def getNumLeafs(myTree):  
  9.     numLeafs=0  
  10.     firstStr=myTree.keys()[0]  
  11.     secondDict=myTree[firstStr]  
  12.     for key in secondDict.keys():  
  13.         if type(secondDict[key]).__name__=='dict':  
  14.             numLeafs+=getNumLeafs(secondDict[key])  
  15.         else: numLeafs+=1  
  16.     return numLeafs  
  17.   
  18. #计算树的最大深度  
  19. def getTreeDepth(myTree):  
  20.     maxDepth=0  
  21.     firstStr=myTree.keys()[0]  
  22.     secondDict=myTree[firstStr]  
  23.     for key in secondDict.keys():  
  24.         if type(secondDict[key]).__name__=='dict':  
  25.             thisDepth=1+getTreeDepth(secondDict[key])  
  26.         else: thisDepth=1  
  27.         if thisDepth>maxDepth:  
  28.             maxDepth=thisDepth  
  29.     return maxDepth  
  30.   
  31. #画节点  
  32. def plotNode(nodeTxt,centerPt,parentPt,nodeType):  
  33.     createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\  
  34.     xytext=centerPt,textcoords='axes fraction',va="center", ha="center",\  
  35.     bbox=nodeType,arrowprops=arrow_args)  
  36.   
  37. #画箭头上的文字  
  38. def plotMidText(cntrPt,parentPt,txtString):  
  39.     lens=len(txtString)  
  40.     xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002  
  41.     yMid=(parentPt[1]+cntrPt[1])/2.0  
  42.     createPlot.ax1.text(xMid,yMid,txtString)  
  43.       
  44. def plotTree(myTree,parentPt,nodeTxt):  
  45.     numLeafs=getNumLeafs(myTree)  
  46.     depth=getTreeDepth(myTree)  
  47.     firstStr=myTree.keys()[0]  
  48.     cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)  
  49.     plotMidText(cntrPt,parentPt,nodeTxt)  
  50.     plotNode(firstStr,cntrPt,parentPt,decisionNode)  
  51.     secondDict=myTree[firstStr]  
  52.     plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD  
  53.     for key in secondDict.keys():  
  54.         if type(secondDict[key]).__name__=='dict':  
  55.             plotTree(secondDict[key],cntrPt,str(key))  
  56.         else:  
  57.             plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW  
  58.             plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)  
  59.             plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))  
  60.     plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD  
  61.   
  62. def createPlot(inTree):  
  63.     fig=plt.figure(1,facecolor='white')  
  64.     fig.clf()  
  65.     axprops=dict(xticks=[],yticks=[])  
  66.     createPlot.ax1=plt.subplot(111,frameon=False,**axprops)  
  67.     plotTree.totalW=float(getNumLeafs(inTree))  
  68.     plotTree.totalD=float(getTreeDepth(inTree))  
  69.     plotTree.x0ff=-0.5/plotTree.totalW  
  70.     plotTree.y0ff=1.0  
  71.     plotTree(inTree,(0.5,1.0),'')  
  72.     plt.show()  
import matplotlib.pyplot as plt
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")


#计算树的叶子节点数量
def getNumLeafs(myTree):
    numLeafs=0
    firstStr=myTree.keys()[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs+=getNumLeafs(secondDict[key])
        else: numLeafs+=1
    return numLeafs

#计算树的最大深度
def getTreeDepth(myTree):
    maxDepth=0
    firstStr=myTree.keys()[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else: thisDepth=1
        if thisDepth>maxDepth:
            maxDepth=thisDepth
    return maxDepth

#画节点
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
    xytext=centerPt,textcoords='axes fraction',va="center", ha="center",\
    bbox=nodeType,arrowprops=arrow_args)

#画箭头上的文字
def plotMidText(cntrPt,parentPt,txtString):
    lens=len(txtString)
    xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002
    yMid=(parentPt[1]+cntrPt[1])/2.0
    createPlot.ax1.text(xMid,yMid,txtString)
    
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=myTree.keys()[0]
    cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
            plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
    plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD

def createPlot(inTree):
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    axprops=dict(xticks=[],yticks=[])
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.x0ff=-0.5/plotTree.totalW
    plotTree.y0ff=1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

调用方式为
  1. createPlot(myTree)  
createPlot(myTree)

以上的决策树计算代码以及画图代码可以放在不同的文件中进行调用,也可以直接放在一个py文件中。

得到的决策树如下图所示:


与 机器学习 教材P85页的图一致。

若文中或代码中有错误之处,烦请指正,不甚感激。

更新:

2016.4.3对决策树生成的createTree函数进行了更新(上文代码已经是更新后的代码)。

原来的代码为:
  1. #主程序,递归产生决策树    
  2. def createTree(dataSet,labels):    
  3.     classList=[example[-1for example in dataSet]    
  4.     if classList.count(classList[0])==len(classList):    
  5.         return classList[0]    
  6.     if len(dataSet[0])==1:    
  7.         return majorityCnt(classList)    
  8.     bestFeat=chooseBestFeatureToSplit(dataSet,labels)    
  9.     bestFeatLabel=labels[bestFeat]    
  10.     myTree={bestFeatLabel:{}}    
  11.     del(labels[bestFeat])    
  12.     featValues=[example[bestFeat] for example in dataSet]    
  13.     uniqueVals=set(featValues)    
  14.     #针对bestFeat的每个取值,划分出一个子树。    
  15.     for value in uniqueVals:    
  16.         subLabels=labels[:]    
  17.         myTree[bestFeatLabel][value]=createTree(splitDataSet\    
  18.          (dataSet,bestFeat,value),subLabels)    
  19.     return myTree   
#主程序,递归产生决策树  
def createTree(dataSet,labels):  
    classList=[example[-1] for example in dataSet]  
    if classList.count(classList[0])==len(classList):  
        return classList[0]  
    if len(dataSet[0])==1:  
        return majorityCnt(classList)  
    bestFeat=chooseBestFeatureToSplit(dataSet,labels)  
    bestFeatLabel=labels[bestFeat]  
    myTree={bestFeatLabel:{}}  
    del(labels[bestFeat])  
    featValues=[example[bestFeat] for example in dataSet]  
    uniqueVals=set(featValues)  
    #针对bestFeat的每个取值,划分出一个子树。  
    for value in uniqueVals:  
        subLabels=labels[:]  
        myTree[bestFeatLabel][value]=createTree(splitDataSet\  
         (dataSet,bestFeat,value),subLabels)  
    return myTree 

比如颜色有 dark_green, black, light_white 三种, 纹理有 distinct,little_blur, blur 这几种。若先按照纹理进行划分,则划分出distinct的子样本集中的颜色就没有light_white这个取值了。这使得得到的决策树在遇到新数据时可能无法进行决策(比如一个 texture:distinct; color:light_white的西瓜)。因此在递归的时候需要传递完整的训练数据集。从而产生完整的决策树。(缺失取值的类别划分选择当前数据集的多数类别(投票法))

如使用书上的表4.2(就是前面表格去掉密度和含糖量这两行)。使用之前代码得到的图为


以下为修改后的结果图,与书上P78的图4.4一致
可以看出,修改后左侧colo特征的划分是完整的
  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值