ID3 算法案例

1. 决策树的基本认识

 

   决策树是一种依托决策而建立起来的一种树。在机器学习中,决策树是一种预测模型,代表的是一种对

   象属性与对象值之间的一种映射关系,每一个节点代表某个对象,树中的每一个分叉路径代表某个可能

   的属性值,而每一个叶子节点则对应从根节点到该叶子节点所经历的路径所表示的对象的值。决策树仅

   有单一输出,如果有多个输出,可以分别建立独立的决策树以处理不同的输出。接下来讲解ID3算法。

 

 

2. ID3算法介绍

 

   ID3算法是决策树的一种,它是基于奥卡姆剃刀原理的,即用尽量用较少的东西做更多的事。ID3算法

   即Iterative Dichotomiser 3迭代二叉树3代,是Ross Quinlan发明的一种决策树算法,这个

   算法的基础就是上面提到的奥卡姆剃刀原理,越是小型的决策树越优于大的决策树,尽管如此,也不总

   是生成最小的树型结构,而是一个启发式算法。

 

   在信息论中,期望信息越小,那么信息增益就越大,从而纯度就越高。ID3算法的核心思想就是以信息

   增益来度量属性的选择,选择分裂后信息增益最大的属性进行分裂。该算法采用自顶向下的贪婪搜索遍

   历可能的决策空间。

 

 

3. 信息熵与信息增益

 

   在信息增益中,重要性的衡量标准就是看特征能够为分类系统带来多少信息,带来的信息越多,该特征越

   重要。在认识信息增益之前,先来看看信息熵的定义

 

   这个概念最早起源于物理学,在物理学中是用来度量一个热力学系统的无序程度,而在信息学里面,熵

   是对不确定性的度量。在1948年,香农引入了信息熵,将其定义为离散随机事件出现的概率,一个系统越

   是有序,信息熵就越低,反之一个系统越是混乱,它的信息熵就越高。所以信息熵可以被认为是系统有序

   化程度的一个度量。

 

   假如一个随机变量的取值为,每一种取到的概率分别是,那么

   的熵定义为

 

             

 

   意思是一个变量的变化情况可能越多,那么它携带的信息量就越大。

 

   对于分类系统来说,类别是变量,它的取值是,而每一个类别出现的概率分别是

 

             

 

   而这里的就是类别的总数,此时分类系统的熵就可以表示为

 

             

 

   以上就是信息熵的定义,接下来介绍信息增益

 

   信息增益是针对一个一个特征而言的,就是看一个特征,系统有它和没有它时的信息量各是多少,两者

   的差值就是这个特征给系统带来的信息量,即信息增益

案例:

本文使用的Python库包括

  • numpy
  • pandas
  • math
  • operator
  • matplotlib

本文所用的数据如下:
                  Idx 色泽 根蒂 敲声 纹理 脐部 触感                密度            含糖率               label
1 青绿 蜷缩 浊响 清晰 凹陷 硬滑 0.697 0.46 1
2 乌黑 蜷缩 沉闷 清晰 凹陷 硬滑 0.774 0.376 1
3 乌黑 蜷缩 浊响 清晰 凹陷 硬滑 0.634 0.264 1
4 青绿 蜷缩 沉闷 清晰 凹陷 硬滑 0.608 0.318 1
5 浅白 蜷缩 浊响 清晰 凹陷 硬滑 0.556 0.215 1
6 青绿 稍蜷 浊响 清晰 稍凹 软粘 0.403 0.237 1
7 乌黑 稍蜷 浊响 稍糊 稍凹 软粘 0.481 0.149 1
8 乌黑 稍蜷 浊响 清晰 稍凹 硬滑 0.437 0.211 1
9 乌黑 稍蜷 沉闷 稍糊 稍凹 硬滑 0.666 0.091 0
10 青绿 硬挺 清脆 清晰 平坦 软粘 0.243 0.267 0
11 浅白 硬挺 清脆 模糊 平坦 硬滑 0.245 0.057 0
12 浅白 蜷缩 浊响 模糊 平坦 软粘 0.343 0.099 0
13 青绿 稍蜷 浊响 稍糊 凹陷 硬滑 0.639 0.161 0
14 浅白 稍蜷 沉闷 稍糊 凹陷 硬滑 0.657 0.198 0
15 乌黑 稍蜷 浊响 清晰 稍凹 软粘 0.36 0.37 0
16 浅白 蜷缩 浊响 模糊 平坦 硬滑 0.593 0.042 0
17 青绿 蜷缩 沉闷 稍糊 稍凹 硬滑 0.719 0.103 0

由于我没搞定matplotlib的中文输出,因此将中文字符全换成了英文,如下:
Idx color root knocks texture navel touch density sugar_ratio label
1 dark_green curl_up little_heavily distinct sinking hard_smooth 0.697 0.46 1
2 black curl_up heavily distinct sinking hard_smooth 0.774 0.376 1
3 black curl_up little_heavily distinct sinking hard_smooth 0.634 0.264 1
4 dark_green curl_up heavily distinct sinking hard_smooth 0.608 0.318 1
5 light_white curl_up little_heavily distinct sinking hard_smooth 0.556 0.215 1
6 dark_green little_curl_up little_heavily distinct little_sinking soft_stick 0.403 0.237 1
7 black little_curl_up little_heavily little_blur little_sinking soft_stick 0.481 0.149 1
8 black little_curl_up little_heavily distinct little_sinking hard_smooth 0.437 0.211 1
9 black little_curl_up heavily little_blur little_sinking hard_smooth 0.666 0.091 0
10 dark_green stiff clear distinct even soft_stick 0.243 0.267 0
11 light_white stiff clear blur even hard_smooth 0.245 0.057 0
12 light_white curl_up little_heavily blur even soft_stick 0.343 0.099 0
13 dark_green little_curl_up little_heavily little_blur sinking hard_smooth 0.639 0.161 0
14 light_white little_curl_up heavily little_blur sinking hard_smooth 0.657 0.198 0
15 black little_curl_up little_heavily distinct little_sinking soft_stick 0.36 0.37 0
16 light_white curl_up little_heavily blur even hard_smooth 0.593 0.042 0
17 dark_green curl_up heavily little_blur little_sinking hard_smooth 0.719 0.103 0

字符的含义可自行对照上下两表

决策树生成的代码参照 机器学习实战 第三章的代码,但是书上第三章是针对离散特征的,下面程序中对其进行了修改,使其能用于同时包含离散与连续特征的数据集。

决策树生成代码如下:

  1. # -*- coding: utf-8 -*-  
  2.   
  3. from numpy import *  
  4. import numpy as np  
  5. import pandas as pd  
  6. from math import log  
  7. import operator  
  8.   
  9.   
  10.   
  11. #计算数据集的香农熵  
  12. def calcShannonEnt(dataSet):  
  13.     numEntries=len(dataSet)  
  14.     labelCounts={}  
  15.     #给所有可能分类创建字典  
  16.     for featVec in dataSet:  
  17.         currentLabel=featVec[-1]  
  18.         if currentLabel not in labelCounts.keys():  
  19.             labelCounts[currentLabel]=0  
  20.         labelCounts[currentLabel]+=1  
  21.     shannonEnt=0.0  
  22.     #以2为底数计算香农熵  
  23.     for key in labelCounts:  
  24.         prob = float(labelCounts[key])/numEntries  
  25.         shannonEnt-=prob*log(prob,2)  
  26.     return shannonEnt  
  27.   
  28.   
  29. #对离散变量划分数据集,取出该特征取值为value的所有样本  
  30. def splitDataSet(dataSet,axis,value):  
  31.     retDataSet=[]  
  32.     for featVec in dataSet:  
  33.         if featVec[axis]==value:  
  34.             reducedFeatVec=featVec[:axis]  
  35.             reducedFeatVec.extend(featVec[axis+1:])  
  36.             retDataSet.append(reducedFeatVec)  
  37.     return retDataSet  
  38.   
  39. #对连续变量划分数据集,direction规定划分的方向,  
  40. #决定是划分出小于value的数据样本还是大于value的数据样本集  
  41. def splitContinuousDataSet(dataSet,axis,value,direction):  
  42.     retDataSet=[]  
  43.     for featVec in dataSet:  
  44.         if direction==0:  
  45.             if featVec[axis]>value:  
  46.                 reducedFeatVec=featVec[:axis]  
  47.                 reducedFeatVec.extend(featVec[axis+1:])  
  48.                 retDataSet.append(reducedFeatVec)  
  49.         else:  
  50.             if featVec[axis]<=value:  
  51.                 reducedFeatVec=featVec[:axis]  
  52.                 reducedFeatVec.extend(featVec[axis+1:])  
  53.                 retDataSet.append(reducedFeatVec)  
  54.     return retDataSet  
  55.   
  56. #选择最好的数据集划分方式  
  57. def chooseBestFeatureToSplit(dataSet,labels):  
  58.     numFeatures=len(dataSet[0])-1  
  59.     baseEntropy=calcShannonEnt(dataSet)  
  60.     bestInfoGain=0.0  
  61.     bestFeature=-1  
  62.     bestSplitDict={}  
  63.     for i in range(numFeatures):  
  64.         featList=[example[i] for example in dataSet]  
  65.         #对连续型特征进行处理  
  66.         if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':  
  67.             #产生n-1个候选划分点  
  68.             sortfeatList=sorted(featList)  
  69.             splitList=[]  
  70.             for j in range(len(sortfeatList)-1):  
  71.                 splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0)  
  72.               
  73.             bestSplitEntropy=10000  
  74.             slen=len(splitList)  
  75.             #求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点  
  76.             for j in range(slen):  
  77.                 value=splitList[j]  
  78.                 newEntropy=0.0  
  79.                 subDataSet0=splitContinuousDataSet(dataSet,i,value,0)  
  80.                 subDataSet1=splitContinuousDataSet(dataSet,i,value,1)  
  81.                 prob0=len(subDataSet0)/float(len(dataSet))  
  82.                 newEntropy+=prob0*calcShannonEnt(subDataSet0)  
  83.                 prob1=len(subDataSet1)/float(len(dataSet))  
  84.                 newEntropy+=prob1*calcShannonEnt(subDataSet1)  
  85.                 if newEntropy<bestSplitEntropy:  
  86.                     bestSplitEntropy=newEntropy  
  87.                     bestSplit=j  
  88.             #用字典记录当前特征的最佳划分点  
  89.             bestSplitDict[labels[i]]=splitList[bestSplit]  
  90.             infoGain=baseEntropy-bestSplitEntropy  
  91.         #对离散型特征进行处理  
  92.         else:  
  93.             uniqueVals=set(featList)  
  94.             newEntropy=0.0  
  95.             #计算该特征下每种划分的信息熵  
  96.             for value in uniqueVals:  
  97.                 subDataSet=splitDataSet(dataSet,i,value)  
  98.                 prob=len(subDataSet)/float(len(dataSet))  
  99.                 newEntropy+=prob*calcShannonEnt(subDataSet)  
  100.             infoGain=baseEntropy-newEntropy  
  101.         if infoGain>bestInfoGain:  
  102.             bestInfoGain=infoGain  
  103.             bestFeature=i  
  104.     #若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理  
  105.     #即是否小于等于bestSplitValue  
  106.     if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':        
  107.         bestSplitValue=bestSplitDict[labels[bestFeature]]          
  108.         labels[bestFeature]=labels[bestFeature]+'<='+str(bestSplitValue)  
  109.         for i in range(shape(dataSet)[0]):  
  110.             if dataSet[i][bestFeature]<=bestSplitValue:  
  111.                 dataSet[i][bestFeature]=1  
  112.             else:  
  113.                 dataSet[i][bestFeature]=0  
  114.     return bestFeature  
  115.   
  116. #特征若已经划分完,节点下的样本还没有统一取值,则需要进行投票  
  117. def majorityCnt(classList):  
  118.     classCount={}  
  119.     for vote in classList:  
  120.         if vote not in classCount.keys():  
  121.             classCount[vote]=0  
  122.         classCount[vote]+=1  
  123.     return max(classCount)  
  124.   
  125. #主程序,递归产生决策树  
  126. def createTree(dataSet,labels,data_full,labels_full):  
  127.     classList=[example[-1for example in dataSet]  
  128.     if classList.count(classList[0])==len(classList):  
  129.         return classList[0]  
  130.     if len(dataSet[0])==1:  
  131.         return majorityCnt(classList)  
  132.     bestFeat=chooseBestFeatureToSplit(dataSet,labels)  
  133.     bestFeatLabel=labels[bestFeat]  
  134.     myTree={bestFeatLabel:{}}  
  135.     featValues=[example[bestFeat] for example in dataSet]  
  136.     uniqueVals=set(featValues)  
  137.     if type(dataSet[0][bestFeat]).__name__=='str':  
  138.         currentlabel=labels_full.index(labels[bestFeat])  
  139.         featValuesFull=[example[currentlabel] for example in data_full]  
  140.         uniqueValsFull=set(featValuesFull)  
  141.     del(labels[bestFeat])  
  142.     #针对bestFeat的每个取值,划分出一个子树。  
  143.     for value in uniqueVals:  
  144.         subLabels=labels[:]  
  145.         if type(dataSet[0][bestFeat]).__name__=='str':  
  146.             uniqueValsFull.remove(value)  
  147.         myTree[bestFeatLabel][value]=createTree(splitDataSet\  
  148.          (dataSet,bestFeat,value),subLabels,data_full,labels_full)  
  149.     if type(dataSet[0][bestFeat]).__name__=='str':  
  150.         for value in uniqueValsFull:  
  151.             myTree[bestFeatLabel][value]=majorityCnt(classList)  
  152.     return myTree  
# -*- coding: utf-8 -*-

from numpy import *
import numpy as np
import pandas as pd
from math import log
import operator



#计算数据集的香农熵
def calcShannonEnt(dataSet):
    numEntries=len(dataSet)
    labelCounts={}
    #给所有可能分类创建字典
    for featVec in dataSet:
        currentLabel=featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    shannonEnt=0.0
    #以2为底数计算香农熵
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt-=prob*log(prob,2)
    return shannonEnt


#对离散变量划分数据集,取出该特征取值为value的所有样本
def splitDataSet(dataSet,axis,value):
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis]==value:
            reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

#对连续变量划分数据集,direction规定划分的方向,
#决定是划分出小于value的数据样本还是大于value的数据样本集
def splitContinuousDataSet(dataSet,axis,value,direction):
    retDataSet=[]
    for featVec in dataSet:
        if direction==0:
            if featVec[axis]>value:
                reducedFeatVec=featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
        else:
            if featVec[axis]<=value:
                reducedFeatVec=featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
    return retDataSet

#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet,labels):
    numFeatures=len(dataSet[0])-1
    baseEntropy=calcShannonEnt(dataSet)
    bestInfoGain=0.0
    bestFeature=-1
    bestSplitDict={}
    for i in range(numFeatures):
        featList=[example[i] for example in dataSet]
        #对连续型特征进行处理
        if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':
            #产生n-1个候选划分点
            sortfeatList=sorted(featList)
            splitList=[]
            for j in range(len(sortfeatList)-1):
                splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0)
            
            bestSplitEntropy=10000
            slen=len(splitList)
            #求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点
            for j in range(slen):
                value=splitList[j]
                newEntropy=0.0
                subDataSet0=splitContinuousDataSet(dataSet,i,value,0)
                subDataSet1=splitContinuousDataSet(dataSet,i,value,1)
                prob0=len(subDataSet0)/float(len(dataSet))
                newEntropy+=prob0*calcShannonEnt(subDataSet0)
                prob1=len(subDataSet1)/float(len(dataSet))
                newEntropy+=prob1*calcShannonEnt(subDataSet1)
                if newEntropy<bestSplitEntropy:
                    bestSplitEntropy=newEntropy
                    bestSplit=j
            #用字典记录当前特征的最佳划分点
            bestSplitDict[labels[i]]=splitList[bestSplit]
            infoGain=baseEntropy-bestSplitEntropy
        #对离散型特征进行处理
        else:
            uniqueVals=set(featList)
            newEntropy=0.0
            #计算该特征下每种划分的信息熵
            for value in uniqueVals:
                subDataSet=splitDataSet(dataSet,i,value)
                prob=len(subDataSet)/float(len(dataSet))
                newEntropy+=prob*calcShannonEnt(subDataSet)
            infoGain=baseEntropy-newEntropy
        if infoGain>bestInfoGain:
            bestInfoGain=infoGain
            bestFeature=i
    #若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理
    #即是否小于等于bestSplitValue
    if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':      
        bestSplitValue=bestSplitDict[labels[bestFeature]]        
        labels[bestFeature]=labels[bestFeature]+'<='+str(bestSplitValue)
        for i in range(shape(dataSet)[0]):
            if dataSet[i][bestFeature]<=bestSplitValue:
                dataSet[i][bestFeature]=1
            else:
                dataSet[i][bestFeature]=0
    return bestFeature

#特征若已经划分完,节点下的样本还没有统一取值,则需要进行投票
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote]+=1
    return max(classCount)

#主程序,递归产生决策树
def createTree(dataSet,labels,data_full,labels_full):
    classList=[example[-1] for example in dataSet]
    if classList.count(classList[0])==len(classList):
        return classList[0]
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    bestFeat=chooseBestFeatureToSplit(dataSet,labels)
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    if type(dataSet[0][bestFeat]).__name__=='str':
        currentlabel=labels_full.index(labels[bestFeat])
        featValuesFull=[example[currentlabel] for example in data_full]
        uniqueValsFull=set(featValuesFull)
    del(labels[bestFeat])
    #针对bestFeat的每个取值,划分出一个子树。
    for value in uniqueVals:
        subLabels=labels[:]
        if type(dataSet[0][bestFeat]).__name__=='str':
            uniqueValsFull.remove(value)
        myTree[bestFeatLabel][value]=createTree(splitDataSet\
         (dataSet,bestFeat,value),subLabels,data_full,labels_full)
    if type(dataSet[0][bestFeat]).__name__=='str':
        for value in uniqueValsFull:
            myTree[bestFeatLabel][value]=majorityCnt(classList)
    return myTree

通过以下语句进行调用:

  1. df=pd.read_csv('watermelon_4_3.csv')  
  2. data=df.values[:,1:].tolist()  
  3. data_full=data[:]  
  4. labels=df.columns.values[1:-1].tolist()  
  5. labels_full=labels[:]  
  6. myTree=createTree(data,labels,data_full,labels_full)  
df=pd.read_csv('watermelon_4_3.csv')
data=df.values[:,1:].tolist()
data_full=data[:]
labels=df.columns.values[1:-1].tolist()
labels_full=labels[:]
myTree=createTree(data,labels,data_full,labels_full)

可以得到以下结果
>>> myTree
{'texture': {'distinct': {'density<=0.3815': {0: 1L, 1: 0L}}, 'little_blur': {'touch': {'hard_smooth': 0L, 'soft_stick': 1L}}, 'blur': 0L}}

以下为画图代码:
  1. import matplotlib.pyplot as plt  
  2. decisionNode=dict(boxstyle="sawtooth",fc="0.8")  
  3. leafNode=dict(boxstyle="round4",fc="0.8")  
  4. arrow_args=dict(arrowstyle="<-")  
  5.   
  6.   
  7. #计算树的叶子节点数量  
  8. def getNumLeafs(myTree):  
  9.     numLeafs=0  
  10.     firstStr=myTree.keys()[0]  
  11.     secondDict=myTree[firstStr]  
  12.     for key in secondDict.keys():  
  13.         if type(secondDict[key]).__name__=='dict':  
  14.             numLeafs+=getNumLeafs(secondDict[key])  
  15.         else: numLeafs+=1  
  16.     return numLeafs  
  17.   
  18. #计算树的最大深度  
  19. def getTreeDepth(myTree):  
  20.     maxDepth=0  
  21.     firstStr=myTree.keys()[0]  
  22.     secondDict=myTree[firstStr]  
  23.     for key in secondDict.keys():  
  24.         if type(secondDict[key]).__name__=='dict':  
  25.             thisDepth=1+getTreeDepth(secondDict[key])  
  26.         else: thisDepth=1  
  27.         if thisDepth>maxDepth:  
  28.             maxDepth=thisDepth  
  29.     return maxDepth  
  30.   
  31. #画节点  
  32. def plotNode(nodeTxt,centerPt,parentPt,nodeType):  
  33.     createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\  
  34.     xytext=centerPt,textcoords='axes fraction',va="center", ha="center",\  
  35.     bbox=nodeType,arrowprops=arrow_args)  
  36.   
  37. #画箭头上的文字  
  38. def plotMidText(cntrPt,parentPt,txtString):  
  39.     lens=len(txtString)  
  40.     xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002  
  41.     yMid=(parentPt[1]+cntrPt[1])/2.0  
  42.     createPlot.ax1.text(xMid,yMid,txtString)  
  43.       
  44. def plotTree(myTree,parentPt,nodeTxt):  
  45.     numLeafs=getNumLeafs(myTree)  
  46.     depth=getTreeDepth(myTree)  
  47.     firstStr=myTree.keys()[0]  
  48.     cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)  
  49.     plotMidText(cntrPt,parentPt,nodeTxt)  
  50.     plotNode(firstStr,cntrPt,parentPt,decisionNode)  
  51.     secondDict=myTree[firstStr]  
  52.     plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD  
  53.     for key in secondDict.keys():  
  54.         if type(secondDict[key]).__name__=='dict':  
  55.             plotTree(secondDict[key],cntrPt,str(key))  
  56.         else:  
  57.             plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW  
  58.             plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)  
  59.             plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))  
  60.     plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD  
  61.   
  62. def createPlot(inTree):  
  63.     fig=plt.figure(1,facecolor='white')  
  64.     fig.clf()  
  65.     axprops=dict(xticks=[],yticks=[])  
  66.     createPlot.ax1=plt.subplot(111,frameon=False,**axprops)  
  67.     plotTree.totalW=float(getNumLeafs(inTree))  
  68.     plotTree.totalD=float(getTreeDepth(inTree))  
  69.     plotTree.x0ff=-0.5/plotTree.totalW  
  70.     plotTree.y0ff=1.0  
  71.     plotTree(inTree,(0.5,1.0),'')  
  72.     plt.show()  
import matplotlib.pyplot as plt
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")


#计算树的叶子节点数量
def getNumLeafs(myTree):
    numLeafs=0
    firstStr=myTree.keys()[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs+=getNumLeafs(secondDict[key])
        else: numLeafs+=1
    return numLeafs

#计算树的最大深度
def getTreeDepth(myTree):
    maxDepth=0
    firstStr=myTree.keys()[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else: thisDepth=1
        if thisDepth>maxDepth:
            maxDepth=thisDepth
    return maxDepth

#画节点
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
    xytext=centerPt,textcoords='axes fraction',va="center", ha="center",\
    bbox=nodeType,arrowprops=arrow_args)

#画箭头上的文字
def plotMidText(cntrPt,parentPt,txtString):
    lens=len(txtString)
    xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002
    yMid=(parentPt[1]+cntrPt[1])/2.0
    createPlot.ax1.text(xMid,yMid,txtString)
    
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=myTree.keys()[0]
    cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
            plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
    plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD

def createPlot(inTree):
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    axprops=dict(xticks=[],yticks=[])
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.x0ff=-0.5/plotTree.totalW
    plotTree.y0ff=1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

调用方式为
  1. createPlot(myTree)  
createPlot(myTree)

以上的决策树计算代码以及画图代码可以放在不同的文件中进行调用,也可以直接放在一个py文件中。

得到的决策树如下图所示:


与 机器学习 教材P85页的图一致。

若文中或代码中有错误之处,烦请指正,不甚感激。

更新:

2016.4.3对决策树生成的createTree函数进行了更新(上文代码已经是更新后的代码)。

原来的代码为:
  1. #主程序,递归产生决策树    
  2. def createTree(dataSet,labels):    
  3.     classList=[example[-1for example in dataSet]    
  4.     if classList.count(classList[0])==len(classList):    
  5.         return classList[0]    
  6.     if len(dataSet[0])==1:    
  7.         return majorityCnt(classList)    
  8.     bestFeat=chooseBestFeatureToSplit(dataSet,labels)    
  9.     bestFeatLabel=labels[bestFeat]    
  10.     myTree={bestFeatLabel:{}}    
  11.     del(labels[bestFeat])    
  12.     featValues=[example[bestFeat] for example in dataSet]    
  13.     uniqueVals=set(featValues)    
  14.     #针对bestFeat的每个取值,划分出一个子树。    
  15.     for value in uniqueVals:    
  16.         subLabels=labels[:]    
  17.         myTree[bestFeatLabel][value]=createTree(splitDataSet\    
  18.          (dataSet,bestFeat,value),subLabels)    
  19.     return myTree   
#主程序,递归产生决策树  
def createTree(dataSet,labels):  
    classList=[example[-1] for example in dataSet]  
    if classList.count(classList[0])==len(classList):  
        return classList[0]  
    if len(dataSet[0])==1:  
        return majorityCnt(classList)  
    bestFeat=chooseBestFeatureToSplit(dataSet,labels)  
    bestFeatLabel=labels[bestFeat]  
    myTree={bestFeatLabel:{}}  
    del(labels[bestFeat])  
    featValues=[example[bestFeat] for example in dataSet]  
    uniqueVals=set(featValues)  
    #针对bestFeat的每个取值,划分出一个子树。  
    for value in uniqueVals:  
        subLabels=labels[:]  
        myTree[bestFeatLabel][value]=createTree(splitDataSet\  
         (dataSet,bestFeat,value),subLabels)  
    return myTree 

比如颜色有 dark_green, black, light_white 三种, 纹理有 distinct,little_blur, blur 这几种。若先按照纹理进行划分,则划分出distinct的子样本集中的颜色就没有light_white这个取值了。这使得得到的决策树在遇到新数据时可能无法进行决策(比如一个 texture:distinct; color:light_white的西瓜)。因此在递归的时候需要传递完整的训练数据集。从而产生完整的决策树。(缺失取值的类别划分选择当前数据集的多数类别(投票法))

如使用书上的表4.2(就是前面表格去掉密度和含糖量这两行)。使用之前代码得到的图为


以下为修改后的结果图,与书上P78的图4.4一致
可以看出,修改后左侧colo特征的划分是完整的
展开阅读全文

没有更多推荐了,返回首页