决策树—ID3(源码解析)

理论方面机器学习实战中讲的非常清楚,深入点的话在西瓜书可以参考,这里只把源码贴出来和学习中的一些困难。

这里主要主要是有这么几块:

  • 首先搞懂信息熵和其作用

  • 划分数据集

  • 递归构建决策树

  • Matplotlib注解绘制树形图

  • 测试和存储分类器

  • 示例:使用决策树预测隐形眼镜类型

构建一个决策树:

from math import log
import operator
#import pickle
#import tree_plot


# 自己建立的数据
def createDataSet():
    # 各元素也是列表
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing','flippers'] # 特征标签
    return dataSet, labels

# 计算给定数据集的香农熵
def calcShannonEnt(dataSet):
    # 数据集中实例的总数
    numEntries = len(dataSet)
    # 频数统计
    labelCounts = {}
    for featVec in dataSet: 
        currentLabel = featVec[-1] # 最后一列作为标签
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1 # 统计不同的标签数
    #print labelCounts
    shannonEnt = 0.0
    # 计算香农熵
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2)  # 以2为底
    return shannonEnt


# 按照给定特征划分数据集,返回剩余特征
# 三个参数:待划分的数据集,划分的特征,需返回的特征的值
def splitDataSet(dataSet, axis, value): 
    # 创建新的列表,避免影响原数据
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value: # 提取特定特征中的特定值
            reducedFeatVec = featVec[:axis] # 得到axis列之前列的特征
            # 在处理多个列表时,append()是添加的列表元素,而extend()添加的是元素
            reducedFeatVec.extend(featVec[axis+1:]) # 得到axis列之后列的特征
            retDataSet.append(reducedFeatVec)
    return retDataSet  # 返回指定划分特征外的特征

# 选择最好的数据集划分方式,根据熵计算
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # 留出来一列,把最后一列作为类别标签
    baseEntropy = calcShannonEnt(dataSet)  # 计算原始香浓熵,划分前
    bestInfoGain = 0.0; bestFeature = -1 # 赋初值
    for i in range(numFeatures):        # 迭代所有的特征
        # 使用列表推导来创建新的列表,featlist得到的是每列的特征元素
        featList = [example[i] for example in dataSet]
        print 'featList:',featList
        uniqueVals = set(featList)      # 转换为集合,以此确保其中的元素的唯一性
        print 'uniqueVals:',uniqueVals
        newEntropy = 0.0        
        for value in uniqueVals:
             # 每一列按照不重复的元素划分,返回剩余特征
            subDataSet = splitDataSet(dataSet, i, value) 
            print 'subDataSet:',subDataSet
            prob = len(subDataSet)/float(len(dataSet)) # 频率
            # 得到此次划分的熵,此处的prob和calcShannonEnt()中的prob不是同一种,第一个
            # 是实例数在整体数组的频率,第二个是部分数组中的标签频率,newEntropy求的是信息期望
            newEntropy += prob * calcShannonEnt(subDataSet)    
        infoGain = baseEntropy - newEntropy # 计算信息增益,即熵的减少
        print 'infogain:',infoGain
        if (infoGain > bestInfoGain):       #如果信息量减少,就把减少量作为基准
            bestInfoGain = infoGain          
            bestFeature = i   
    return bestFeature      # 返回信息信息增益最高的特征列

# 多数表决法决定叶节点分类
def majorityCnt(classList): # classList是分类名称的列表
    classCount={} # 存储每个类标签出现的频率
    for vote in classList:
        # 统计所有的不重复的key
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    # 对分类进行倒序排列
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1),\
                              reverse=True)
    # 返回出现次数最多的名称
    return sortedClassCount[0][0]

# 创建树的函数代码
def createTree(dataSet,labels):  # labels存储的是特征标签
    classList = [example[-1] for example in dataSet] # 数据集的最后一列作为类标签列表
    print 'classList:',classList
    # 判断类别是否完全相同,通过查看类标签的第一个的数目
    if classList.count(classList[0]) == len(classList): 
        return classList[0]
    # 判断是否遍历完所有的特征,通过查看剩下的特征数是不是剩下一个
    if len(dataSet[0]) == 1: 
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet) # 找到当前数据集最好的特征的索引
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}  # 嵌套字典,得到一个当前最好的特征标签
    del(labels[bestFeat])  # 删除当前最好的特征标签。即划分时类别特征数目减少
    # 列表推到式得到数据集的最好特征标签的那列
    featValues = [example[bestFeat] for example in dataSet] 
    uniqueVals = set(featValues)  # 集合确保元素的唯一性
    for value in uniqueVals:
        subLabels = labels[:]   # 得到当前剩余的特征标签  
        # 递归调用
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,\
                                                  bestFeat, value),subLabels)
    return myTree       

# 创建自己的数据集
mydat,labels=createDataSet() # mydat,lables相当于全局变量
print calcShannonEnt(mydat)
mydat[0][-1]='maybe'
print calcShannonEnt(mydat)

print(splitDataSet(mydat,1,1))
print(splitDataSet(mydat,0,1))
print('bestFeature:',chooseBestFeatureToSplit(mydat))
print '.............................'
print(createTree(mydat,labels))

看运行的结果:

0.970950594455
1.37095059445
[[1, 'maybe'], [1, 'yes'], [0, 'no'], [0, 'no']]
[[1, 'maybe'], [1, 'yes'], [0, 'no']]
featList: [1, 1, 1, 0, 0]
uniqueVals: set([0, 1])
subDataSet: [[1, 'no'], [1, 'no']]
subDataSet: [[1, 'maybe'], [1, 'yes'], [0, 'no']]
infogain: 0.419973094022
featList: [1, 1, 0, 1, 1]
uniqueVals: set([0, 1])
subDataSet: [[1, 'no']]
subDataSet: [[1, 'maybe'], [1, 'yes'], [0, 'no'], [0, 'no']]
infogain: 0.170950594455
('bestFeature:', 0)
.............................
classList: ['maybe', 'yes', 'no', 'no', 'no']
featList: [1, 1, 1, 0, 0]
uniqueVals: set([0, 1])
subDataSet: [[1, 'no'], [1, 'no']]
subDataSet: [[1, 'maybe'], [1, 'yes'], [0, 'no']]
infogain: 0.419973094022
featList: [1, 1, 0, 1, 1]
uniqueVals: set([0, 1])
subDataSet: [[1, 'no']]
subDataSet: [[1, 'maybe'], [1, 'yes'], [0, 'no'], [0, 'no']]
infogain: 0.170950594455
classList: ['no', 'no']
classList: ['maybe', 'yes', 'no']
featList: [1, 1, 0]
uniqueVals: set([0, 1])
subDataSet: [['no']]
subDataSet: [['maybe'], ['yes']]
infogain: 0.918295834054
classList: ['no']
classList: ['maybe', 'yes']
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'maybe'}}}}

上述结果显示了算法运行的具体过程,最后一行得到了决策树的数据结构,用的是python的字典来存储的,可以看出是一种层级关系。

在函数createTree(dataSet,labels)中:

myTree = {bestFeatLabel:{}}  # 嵌套字典,得到一个当前最好的特征标签
# 递归调用
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,\
                                                  bestFeat, value),subLabels)

这两行我起初觉得myTree字典被重置,其实是递归的作用,我的理解是这样的
字典递归

还有就是递归终止条件的判断要注意。。

使用Matplotlib绘制树形图:

可能会用到的函数:
matplotlib(1)
matplotlib(2)
matplotlib(3)

# -*- coding: utf-8 -*-
"""
绘制树节点
Created on Thu Aug 10 10:37:02 2017
@author: LiLong
"""
#import decision_tree.py
import matplotlib.pyplot as plt



# boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细
decisionNode = dict(boxstyle="sawtooth", fc="0.8") 
leafNode = dict(boxstyle="round4", fc="0.8") # 定义决策树的叶子结点的描述属性
arrow_args = dict(arrowstyle="<-") # 定义箭头属性,也可以是<->,效果就变成双箭头的了


# 绘制结点文本和指向
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    #nodeTxt为要显示的文本,xytext是文本的坐标,
    #xy是注释点的坐标 ,nodeType是注释边框的属性,arrowprops连接线的属性
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


# 获取叶节点的数目    
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]  # 得到第一个键    
    secondDict = myTree[firstStr] # 得到第一个键对应的值    
    for key in secondDict.keys(): 
        # 测试节点的数据类型是否是字典
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key]) # 又是递归调用
        else:   numLeafs +=1
    return numLeafs # 返回叶节点数

# 获取树的层数(递归在此就像是一层一层的剥到最里面,然后再从里到外加起来)
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr] 
    for key in secondDict.keys(): #keys()函数得到的是key,是一个列表
        #print'key:',key
        # 测试节点的数据类型是否是字典,如果是字典说明是可以再分的,深度+1
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getTreeDepth(secondDict[key]) # 递归调用,层层剥离字典
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

# 绘制中间文本的坐标和显示内容,即父子之间的填充文本
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  # 求中间点的横坐标  
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] 
    # 绘制出来此文本
    createPlot.ax1.text(xMid, yMid, txtString)


# 绘制树形图
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # 得到叶节点的数,宽
    print 'numLeafs:',numLeafs
    depth = getTreeDepth(myTree)  # 获得树的层数,高
    firstStr = myTree.keys()[0]    # 得到第一个划分的特征
    # 计算坐标
    print 'plotTree.xOff:',plotTree.xOff
    print 'plotTree.totalW:',plotTree.totalW
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, \
                plotTree.yOff)
    print 'cntrPt:',cntrPt 
    # cntrPt是刚计算的坐标,parentPt是父节点坐标,nodeTxt目前为空字符
    plotMidText(cntrPt, parentPt, nodeTxt) # 绘制连接线上的文本
    plotNode(firstStr, cntrPt, parentPt, decisionNode) # 绘制树节点
    secondDict = myTree[firstStr] # 下一级字典,即下一层
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD # 纵坐标降低

    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':   # 如果是树节点
            plotTree(secondDict[key],cntrPt,str(key))  #递归,绘制      
        else:   #如果是一个叶节点,就绘制出来
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW # 定x坐标
            # secondDict[key]叶节点文本,(plotTree.xOff, plotTree.yOff)箭头指向的坐标
            # cntrPt注释(父节点)的坐标
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) #绘制文本及连线
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))  # 绘制父子填充文本
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD


# 预留树信息
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]


#  Axis为坐标轴,Label为坐标轴标注。Tick为刻度线,ax是坐标系区域 
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white') 
    fig.clf()
    # 横纵坐标轴的刻度线,应该为空,加上范围后,父子间的节点连线的填充文本位置错乱
    axprops = dict(xticks=[], yticks=[]) # {'xticks': [], 'yticks': []}
    # createPlot.ax1创建绘图区,无边框,无刻度值
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) 
    #createPlot.ax1 = plt.subplot(111, frameon=False) 
    # 计算树形图的全局变量,用于计算树节点的摆放位置,将树绘制在中心位置
    plotTree.totalW = float(getNumLeafs(inTree)) # plotTree.totalW保存的是树的宽
    plotTree.totalD = float(getTreeDepth(inTree)) # plotTree.totalD保存的是树的高
    plotTree.xOff = -0.5/plotTree.totalW # 决策树起始横坐标
    plotTree.yOff = 1.0  # 决策树的起始纵坐标  
    plotTree(inTree, (0.5,1.0), '') # 绘制树形图
    plt.show() # 显示


mytree=retrieveTree(0)
getNumLeafs(mytree) 
getTreeDepth(mytree)
createPlot(mytree)

运行结果:

numLeafs: 3
plotTree.xOff: -0.166666666667
plotTree.totalW: 3.0
cntrPt: (0.5, 1.0)
numLeafs: 2
plotTree.xOff: 0.166666666667
plotTree.totalW: 3.0
cntrPt: (0.6666666666666666, 0.5)

这里写图片描述

决策树图的上方代码是算法过程中的一些参数变化,有助于理解。其中决策树绘制过程中坐标的计算有点复杂。。。

下面是一些简单的知识点:

  函数也是对象,给一个对象绑定一个属性就是这样的:
    def f():
        pass 
    f.a = 1
    print f.a
>>> os.getcwd()
'C:\\Users\\LiLong'
>>> os.chdir('C:\\Users\\LiLong\\Desktop\\decision_tree')
>>> os.getcwd()
'C:\\Users\\LiLong\\Desktop\\decision_tree'
>>> 

使用决策树分类并预测隐形眼镜类型

tree_plot.py

# -*- coding: utf-8 -*-
"""
绘制树节点
Created on Thu Aug 10 10:37:02 2017
@author: LiLong
"""
#import decision_tree.py
import matplotlib.pyplot as plt



# boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细
decisionNode = dict(boxstyle="sawtooth", fc="0.8") 
leafNode = dict(boxstyle="round4", fc="0.8") # 定义决策树的叶子结点的描述属性
arrow_args = dict(arrowstyle="<-") # 定义箭头属性,也可以是<->,效果就变成双箭头的了


# 绘制结点文本和指向
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    #nodeTxt为要显示的文本,xytext是文本的坐标,
    #xy是注释点的坐标 ,nodeType是注释边框的属性,arrowprops连接线的属性
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


# 获取叶节点的数目    
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]  # 得到第一个键    
    secondDict = myTree[firstStr] # 得到第一个键对应的值    
    for key in secondDict.keys(): 
        # 测试节点的数据类型是否是字典
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key]) # 又是递归调用
        else:   numLeafs +=1
    return numLeafs # 返回叶节点数

# 获取树的层数(递归在此就像是一层一层的剥到最里面,然后再从里到外加起来)
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr] 
    for key in secondDict.keys(): #keys()函数得到的是key,是一个列表
        #print'key:',key
        # 测试节点的数据类型是否是字典,如果是字典说明是可以再分的,深度+1
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getTreeDepth(secondDict[key]) # 递归调用,层层剥离字典
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

# 绘制中间文本的坐标和显示内容,即父子之间的填充文本
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  # 求中间点的横坐标  
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] 
    # 绘制出来此文本
    createPlot.ax1.text(xMid, yMid, txtString)


# 绘制树形图
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # 得到叶节点的数,宽
    print 'numLeafs:',numLeafs
    depth = getTreeDepth(myTree)  # 获得树的层数,高
    firstStr = myTree.keys()[0]    # 得到第一个划分的特征
    # 计算坐标
    print 'plotTree.xOff:',plotTree.xOff
    print 'plotTree.totalW:',plotTree.totalW
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, \
                plotTree.yOff)
    #print 'cntrPt:',cntrPt 
    # cntrPt是刚计算的坐标,parentPt是父节点坐标,nodeTxt目前为空字符
    plotMidText(cntrPt, parentPt, nodeTxt) # 绘制连接线上的文本
    plotNode(firstStr, cntrPt, parentPt, decisionNode) # 绘制树节点
    secondDict = myTree[firstStr] # 下一级字典,即下一层
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD # 纵坐标降低

    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':   # 如果是树节点
            plotTree(secondDict[key],cntrPt,str(key))  #递归,绘制      
        else:   #如果是一个叶节点,就绘制出来
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW # 定x坐标
            # secondDict[key]叶节点文本,(plotTree.xOff, plotTree.yOff)箭头指向的坐标
            # cntrPt注释(父节点)的坐标
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) #绘制文本及连线
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))  # 绘制父子填充文本
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

# 预留树信息
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]  

#  Axis为坐标轴,Label为坐标轴标注。Tick为刻度线,ax是坐标系区域 
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white') 
    fig.clf()
    # 横纵坐标轴的刻度线,应该为空,加上范围后,父子间的节点连线的填充文本位置错乱
    axprops = dict(xticks=[], yticks=[]) # {'xticks': [], 'yticks': []}
    # createPlot.ax1创建绘图区,无边框,无刻度值
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) 
    #createPlot.ax1 = plt.subplot(111, frameon=False) 
    # 计算树形图的全局变量,用于计算树节点的摆放位置,将树绘制在中心位置
    plotTree.totalW = float(getNumLeafs(inTree)) # plotTree.totalW保存的是树的宽
    plotTree.totalD = float(getTreeDepth(inTree)) # plotTree.totalD保存的是树的高
    plotTree.xOff = -0.5/plotTree.totalW # 决策树起始横坐标
    plotTree.yOff = 1.0  # 决策树的起始纵坐标  
    plotTree(inTree, (0.5,1.0), '') # 绘制树形图
    plt.show() # 显示

decision_tree.py

# coding=utf-8
from math import log
import operator
import pickle
import tree_plot  # 导入decision_tree.py

# 自己建立的数据
def createDataSet():
    # 各元素也是列表
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing','flippers'] # 特征标签
    return dataSet, labels

# 计算给定数据集的香农熵
def calcShannonEnt(dataSet):
    # 数据集中实例的总数
    numEntries = len(dataSet)
    # 频数统计
    labelCounts = {}
    for featVec in dataSet: 
        currentLabel = featVec[-1] # 最后一列作为标签
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1 # 统计不同的标签数
    #print labelCounts
    shannonEnt = 0.0
    # 计算香农熵
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2)  # 以2为底
    return shannonEnt


# 按照给定特征划分数据集,返回剩余特征
# 三个参数:待划分的数据集,划分的特征,需返回的特征的值
def splitDataSet(dataSet, axis, value): 
    # 创建新的列表,避免影响原数据
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value: # 提取特定特征中的特定值
            reducedFeatVec = featVec[:axis] # 得到axis列之前列的特征
            # 在处理多个列表时,append()是添加的列表元素,而extend()添加的是元素
            reducedFeatVec.extend(featVec[axis+1:]) # 得到axis列之后列的特征
            retDataSet.append(reducedFeatVec)
    return retDataSet  # 返回指定划分特征外的特征

# 选择最好的数据集划分方式,根据熵计算
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # 留出来一列,把最后一列作为类别标签
    baseEntropy = calcShannonEnt(dataSet)  # 计算原始香浓熵,划分前
    bestInfoGain = 0.0; bestFeature = -1 # 赋初值
    for i in range(numFeatures):        # 迭代所有的特征
        # 使用列表推导来创建新的列表,featlist得到的是每列的特征元素
        featList = [example[i] for example in dataSet]
        print 'featList:',featList
        uniqueVals = set(featList)      # 转换为集合,以此确保其中的元素的唯一性
        print 'uniqueVals:',uniqueVals
        newEntropy = 0.0        
        for value in uniqueVals:
             # 每一列按照不重复的元素划分,返回剩余特征
            subDataSet = splitDataSet(dataSet, i, value) 
            print 'subDataSet:',subDataSet
            prob = len(subDataSet)/float(len(dataSet)) # 频率
            # 得到此次划分的熵,此处的prob和calcShannonEnt()中的prob不是同一种,第一个
            # 是实例数在整体数组的频率,第二个是部分数组中的标签频率,newEntropy求的是信息期望
            newEntropy += prob * calcShannonEnt(subDataSet)    
        infoGain = baseEntropy - newEntropy # 计算信息增益,即熵的减少
        print 'infogain:',infoGain
        if (infoGain > bestInfoGain):       #如果信息量减少,就把减少量作为基准
            bestInfoGain = infoGain          
            bestFeature = i   
    return bestFeature      # 返回信息信息增益最高的特征列

# 多数表决法决定叶节点分类
def majorityCnt(classList): # classList是分类名称的列表
    classCount={} # 存储每个类标签出现的频率
    for vote in classList:
        # 统计所有的不重复的key
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    # 对分类进行倒序排列
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1),\
                              reverse=True)
    # 返回出现次数最多的名称
    return sortedClassCount[0][0]

# 创建树的函数代码
def createTree(dataSet,labels):  # labels存储的是特征标签
    classList = [example[-1] for example in dataSet] # 数据集的最后一列作为类标签列表
    print 'classList:',classList
    # 判断类别是否完全相同,通过查看类标签的第一个的数目
    if classList.count(classList[0]) == len(classList): 
        return classList[0]
    # 判断是否遍历完所有的特征,通过查看剩下的特征数是不是剩下一个
    if len(dataSet[0]) == 1: 
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet) # 找到当前数据集最好的特征的索引
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}  # 嵌套字典,得到一个当前最好的特征标签
    del(labels[bestFeat])  # 删除当前最好的特征标签。即划分时类别特征数目减少
    # 列表推到式得到数据集的最好特征标签的那列
    featValues = [example[bestFeat] for example in dataSet] 
    uniqueVals = set(featValues)  # 集合确保元素的唯一性
    for value in uniqueVals:
        subLabels = labels[:]   # 得到当前剩余的特征标签  
        # 递归调用
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,\
                                                  bestFeat, value),subLabels)
    return myTree       


# 使用决策树的分类函数
def classify(inputTree,featLabels,testVec):
    firstStr = inputTree.keys()[0]   # 树的第一个键
    secondDict = inputTree[firstStr] # 第一个键对应的值(字典)
    featIndex = featLabels.index(firstStr) # 第一个键(特征)在特征列表中的索引
    print 'featIndex:',featIndex
    key = testVec[featIndex]  # key是相应特征对应测试列表中的的取值,也即是父子节点间的判断
    print 'key:',key
    valueOfFeat = secondDict[key] #
    print 'valueOfFeat:',valueOfFeat
    if isinstance(valueOfFeat, dict): 
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat
    return classLabel


# 使用pickle模块储存决策树
def storeTree(inputTree,filename):
    with  open(filename,'w') as fw:    
        pickle.dump(inputTree,fw)

def grabTree(filename):
    with  open(filename,'r') as fr:
        return pickle.load(fr)

# 执行分类
mydat,labels=createDataSet() # mydat,lables相当于全局变量
myTree=tree_plot.retrieveTree(0) # 树字典
print classify(myTree,labels,[1,0])  # 输出预测类型

# 预测隐形眼镜类型
with open('lenses.txt','r') as fr:
    # '\t'是tab分隔符,得到的是数组[[],[]....]
    lenses=[inst.strip().split('\t') for inst in fr.readlines()] 
    lensesLabels=['age','prescript','astigmatic','tearRate']
    lensesTree=createTree(lenses,lensesLabels)
    #storeTree(lensesTree,'clf.txt')
    print 'load:',grabTree('clf.txt')
    tree_plot.createPlot(lensesTree)

运行结果:

no

load: {'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}}}

这里写图片描述

由此得到了决策树。。

此处还有一个问题没有解决:就是

myTree=tree_plot.retrieveTree(0) # 树字典

数字典用的是写好的,也可以说是运行得到的树字典,但是间接的。
如果直接用运行得到的字典

# 执行分类'
mydat,labels=createDataSet() # mydat,lables相当于全局变量
#myTree=tree_plot.retrieveTree(0) # 树字典
myTree=createTree(dataSet,labels)
print classify(myTree,labels,[1,0])  # 输出预测类型

报错:

 featIndex = featLabel.index(str(firstStr)) # 第一个键(特征)在特征列表中的索引
ValueError: 'no surfacing' is not in list

这个问题还没解决。。

  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值