(转)机器学习实战第三章——决策树(源码解析)

转载自:http://blog.csdn.net/quincuntial/article/details/50477508
创建树

#coding=utf-8  
''''' 
Created on 2016年1月5日 


@author: ltc 
'''  
from math import log  
import operator  
from ScrolledText import example  


# 计算信息熵  
def CalcShannonEnt(dataSet):  
    #计算数据集的输入个数  
    numEntries = len(dataSet)  
    #[]列表,{}元字典,()元组  
    # 创建存储标签的元字典  
    labelCounts = {}  
    #对数据集dataSet中的每一行featVec进行循环遍历  
    for featVec in dataSet:  
        # currentLabels为featVec的最后一个元素  
        currentLabels =featVec[-1]  
        # 如果标签currentLabels不在元字典对应的key中  
        if currentLabels not in labelCounts.keys():  
            # 将标签currentLabels放到字典中作为key,并将值赋为0  
            labelCounts[currentLabels] = 0  
        # 将currentLabels对应的值加1  
        labelCounts[currentLabels] += 1  
    # 定义香农熵shannonEnt  
    shannonEnt = 0.0  
    # 遍历元字典labelCounts中的key,即标签  
    for key in labelCounts:  
        # 计算每一个标签出现的频率,即概率  
        prob = float(labelCounts[key])/numEntries  
        # 根据信息熵公式计算每个标签信息熵并累加到shannonEnt上  
        shannonEnt -= prob*log(prob,2)  
    # 返回求得的整个标签对应的信息熵  
    return shannonEnt  


# 创建数据集  
def createDataSet():  
    dataSet = [[1,1,'yes'],  
               [1,1,'yes'],  
               [1,0,'no'],  
               [0,1,'no'],  
               [0,1,'no']]  
    labels=['no surfacing','flippers']  
    return dataSet,labels  


# 分割数据集  
# dataSet数据集,axis是对应的要分割数据的列,value是要分割的列按哪个值分割,即找到含有该值的数据  
def splitDataSet(dataSet,axis,value):  
    # 定义要返回的数据集  
    retDataSet = []  
    # 遍历数据集中的每个特征,即输入数据  
    for featVec in dataSet:  
        # 如果列标签对应的值为value,则将该条(行)数据加入到retDataSet中  
        if featVec[axis] == value:  
            # 取featVec的0-axis个数据,不包括axis,放到reducedFeatVec中  
            reducedFeatVec = featVec[:axis]  
            # 取featVec的axis+1到最后的数据,放到reducedFeatVec的后面  
            reducedFeatVec.extend(featVec[axis+1:])  
            # 将reducedFeatVec添加到分割后的数据集retDataSet中,同时reducedFeatVec,retDataSet中没有了axis列的数据  
            retDataSet.append(reducedFeatVec)  
    # 返回分割后的数据集  
    return retDataSet  


# 选择使分割后信息增益最大的特征,即对应的列  
def chooseBestFeatureToSplit(dataSet):  
    # 获取特征的数目,从0开始,dataSet[0]是一条数据  
    numFeatures = len(dataSet[0]) - 1  
    # 计算数据集当前的信息熵  
    baseEntropy = CalcShannonEnt(dataSet)  
    # 定义最大的信息增益  
    bestInfoGain = 0.0  
    # 定义分割后信息增益最大的特征  
    bestFeature = -1  
    # 遍历特征,即所有的列,计算每一列分割后的信息增益,找出信息增益最大的列  
    for i in range(numFeatures):  
        # 取出第i列特征赋给featList  
        featList = [example[i] for example in dataSet]  
        # 将特征对应的值放到一个集合中,使得特征列的数据具有唯一性  
        uniqueVals = set(featList)  
        # 定义分割后的信息熵  
        newEntropy = 0.0  
        # 遍历特征列的所有值(值是唯一的,重复值已经合并),分割并计算信息增益  
        for value in uniqueVals:  
            # 按照特征列的每个值进行数据集分割  
            subDataSet = splitDataSet(dataSet, i, value)   
            # 计算分割后的每个子集的概率(频率)  
            prob = len(subDataSet) / float(len(dataSet))  
            # 计算分割后的子集的信息熵并相加,得到分割后的整个数据集的信息熵  
            newEntropy +=prob * CalcShannonEnt(subDataSet)  
        # 计算分割后的信息增益  
        infoGain = baseEntropy - newEntropy  
        # 如果分割后信息增益大于最好的信息增益  
        if(infoGain > bestInfoGain):  
            # 将当前的分割的信息增益赋值为最好信息增益  
            bestInfoGain = infoGain  
            # 分割的最好特征列赋为i  
            bestFeature = i  
    # 返回分割后信息增益最大的特征列  
    return bestFeature  


# 对类标签进行投票 ,找标签数目最多的标签  
def majorityCnt(classList):  
    # 定义标签元字典,key为标签,value为标签的数目  
    classCount = {}  
    # 遍历所有标签  
    for vote in classList:  
        #如果标签不在元字典对应的key中  
        if vote not in classCount.keys():  
            # 将标签放到字典中作为key,并将值赋为0  
            classCount[vote] = 0  
        # 对应标签的数目加1  
        classCount[vote] += 1  
    # 对所有标签按数目排序  
    sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)  
    # 返回数目最多的标签  
    return sortedClassCount[0][0]  


# 创建决策树  
def createTree(dataSet,labels):  
    # 将dataSet的最后一列数据(标签)取出赋给classList,classList存储的是标签列  
    classList = [example[-1] for example in dataSet]  
    # 判断是否所有的列的标签都一致  
    if classList.count(classList[0]) == len(classList):  
        # 直接返回标签列的第一个数据  
        return classList[0]  
    # 判断dataSet是否只有一条数据  
    if len(dataSet) == 1:  
        # 返回标签列数据最多的标签  
        return majorityCnt(classList)  
    # 选择一个使数据集分割后最大的特征列的索引  
    bestFeat = chooseBestFeatureToSplit(dataSet)  
    # 找到最好的标签  
    bestFeatLabel = labels[bestFeat]  
    # 定义决策树,key为bestFeatLabel,value为空  
    myTree = {bestFeatLabel:{}}  
    # 删除labels[bestFeat]对应的元素  
    del(labels[bestFeat])  
    # 取出dataSet中bestFeat列的所有值  
    featValues = [example[bestFeat] for example in dataSet]  
    # 将特征对应的值放到一个集合中,使得特征列的数据具有唯一性  
    uniqueVals = set(featValues)  
    # 遍历uniqueVals中的值  
    for value in uniqueVals:  
        # 子标签subLabels为labels删除bestFeat标签后剩余的标签  
        subLabels = labels[:]  
        # myTree为key为bestFeatLabel时的决策树  
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat, value), subLabels)  
    # 返回决策树  
    return myTree  


# 决策树分类函数  
def classify(inputTree,featLabels,testVec):  
    # 得到树中的第一个特征  
    firstStr = inputTree.keys()[0]  
    # 得到第一个对应的值  
    secondDict = inputTree[firstStr]  
    # 得到树中第一个特征对应的索引  
    featIndex = featLabels.index(firstStr)  
    # 遍历树  
    for key in secondDict.keys():  
        # 如果在secondDict[key]中找到testVec[featIndex]  
        if testVec[featIndex] == key:  
            # 判断secondDict[key]是否为字典  
            if type(secondDict[key]).__name__ == 'dict':  
                # 若为字典,递归的寻找testVec  
                classLabel = classify(secondDict[key], featLabels, testVec)  
            else:  
                # 若secondDict[key]为标签值,则将secondDict[key]赋给classLabel  
                classLabel = secondDict[key]  
    # 返回类标签  
    return classLabel  


# 决策树的序列化  
def storeTree(inputTree,filename):  
    # 导入pyton模块  
    import pickle  
    # 以写的方式打开文件  
    fw = open(filename,'w')  
    # 决策树序列化  
    pickle.dump(inputTree,fw)          
# 读取序列化的树          
def grabTree(filename):  
    import pickle  
    fr = open(filename)  
    # 返回读到的树  
    return pickle.load(fr)  

matplotlib绘制树

import matplotlib.pyplot  as plt  


# 定义决策树决策结果的属性,用字典来定义  
# 下面的字典定义也可写作 decisionNode={boxstyle:'sawtooth',fc:'0.8'}  
# boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细  
decisionNode = dict(boxstyle="sawtooth",fc="0.8")  
# 定义决策树的叶子结点的描述属性  
leafNode = dict(boxstyle="round4",fc="0.8")  
# 定义决策树的箭头属性  
arrow_args = dict(arrowstyle="<-")  

# 绘制结点  
def plotNode(nodeTxt,centerPt,parentPt,nodeType):  
    # annotate是关于一个数据点的文本  
    # nodeTxt为要显示的文本,centerPt为文本的中心点,箭头所在的点,parentPt为指向文本的点  
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',\  
                            va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)  
''''' 
# 创建绘图 
def createPlot(): 
    # 类似于Matlab的figure,定义一个画布(暂且这么称呼吧),背景为白色 
    fig = plt.figure(1,facecolor='white') 
    # 把画布清空 
    fig.clf() 
    # createPlot.ax1为全局变量,绘制图像的句柄,subplot为定义了一个绘图,111表示figure中的图有1行1列,即1个,最后的1代表第一个图 
    # frameon表示是否绘制坐标轴矩形 
    createPlot.ax1 = plt.subplot(111,frameon=False) 
    # 绘制结点 
    plotNode('a decision node',(0.5,0.1),(0.1,0.5),decisionNode) 
    # 绘制结点 
    plotNode('a leaf node',(0.8,0.1),(0.3,0.8),leafNode) 
    plt.show() 
'''        
# 获得决策树的叶子结点数目  
def getNumLeafs(myTree):  
    # 定义叶子结点数目  
    numLeafs = 0  
    # 获得myTree的第一个键值,即第一个特征,分割的标签  
    firstStr = myTree.keys()[0]  
    # 根据键值得到对应的值,即根据第一个特征分类的结果  
    secondDict = myTree[firstStr]  
    # 遍历得到的secondDict  
    for key in secondDict.keys():  
        # 如果secondDict[key]为一个字典,即决策树结点  
        if type(secondDict[key]).__name__ == 'dict':  
            # 则递归的计算secondDict中的叶子结点数,并加到numLeafs上  
            numLeafs += getNumLeafs(secondDict[key])  
        # 如果secondDict[key]为叶子结点  
        else:  
            # 则将叶子结点数加1      
            numLeafs += 1  
    # 返回求的叶子结点数目  
    return numLeafs  

# 获得决策树的深度  
def getTreeDepth(myTree):  
    # 定义树的深度  
    maxDepth = 0  
    # 获得myTree的第一个键值,即第一个特征,分割的标签  
    firstStr = myTree.keys()[0]  
    # 根据键值得到对应的值,即根据第一个特征分类的结果  
    secondDict = myTree[firstStr]  
    for key in secondDict.keys():  
        # 如果secondDict[key]为一个字典  
        if type(secondDict[key]).__name__ == 'dict':  
            # 则当前树的深度等于1加上secondDict的深度,只有当前点为决策树点深度才会加1  
            thisDepth = 1 + getTreeDepth(secondDict[key])  
            # 如果secondDict[key]为叶子结点  
        else:  
            # 则将当前树的深度设为1      
            thisDepth = 1  
    # 如果当前树的深度比最大数的深度  
        if thisDepth > maxDepth:  
            maxDepth = thisDepth  
    # 返回树的深度  
    return maxDepth   

# 绘制中间文本  
def plotMidText(cntrPt,parentPt,txtString):  
    # 求中间点的横坐标  
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]  
    # 求中间点的纵坐标  
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]  
    # 绘制树结点  
    createPlot.ax1.text(xMid,yMid,txtString)  

# 绘制决策树  
def plotTree(myTree,parentPt,nodeTxt):  
    # 定义并获得决策树的叶子结点数  
    numLeafs = getNumLeafs(myTree)  
    #depth =   
    getTreeDepth(myTree)  
    # 得到第一个特征  
    firstStr = myTree.keys()[0]  
    # 计算坐标,x坐标为当前树的叶子结点数目除以整个树的叶子结点数再除以2,y为起点  
    cntrPt = (plotTree.xOff + (1.0 +float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)  
    # 绘制中间结点,即决策树结点,也是当前树的根结点,这句话没感觉出有用来,注释掉照样建立决策树,理解浅陋了,理解错了这句话的意思,下面有说明  
    plotMidText(cntrPt, parentPt, nodeTxt)  
    # 绘制决策树结点  
    plotNode(firstStr,cntrPt,parentPt,decisionNode)  
    # 根据firstStr找到对应的值  
    secondDict = myTree[firstStr]  
    # 因为进入了下一层,所以y的坐标要变 ,图像坐标是从左上角为原点  
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  
    # 遍历secondDict  
    for key in secondDict.keys():  
        # 如果secondDict[key]为一棵子决策树,即字典  
        if type(secondDict[key]).__name__ == 'dict':  
            # 递归的绘制决策树  
            plotTree(secondDict[key],cntrPt,str(key))  
        # 若secondDict[key]为叶子结点  
        else:  
            # 计算叶子结点的横坐标  
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW  
            # 绘制叶子结点  
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt, leafNode)  
            # 这句注释掉也不影响决策树的绘制,自己理解的浅陋了,这行代码是特征的值  
            plotMidText((plotTree.xOff,plotTree.yOff), cntrPt, str(key))  
    # 计算纵坐标  
    plotTree.yOff = plotTree.yOff +1.0/plotTree.totalD  

def createPlot(inTree):  
    # 定义一块画布(画布是自己的理解)  
    fig = plt.figure(1,facecolor='white')  
    # 清空画布  
    fig.clf()  
    # 定义横纵坐标轴,无内容  
    axprops = dict(xticks=[],yticks=[])  
    # 绘制图像,无边框,无坐标轴  
    createPlot.ax1 = plt.subplot(111,frameon=True,**axprops)  
    # plotTree.totalW保存的是树的宽  
    plotTree.totalW = float(getNumLeafs(inTree))  
    # plotTree.totalD保存的是树的高  
    plotTree.totalD = float(getTreeDepth(inTree))  
    # 决策树起始横坐标  
    plotTree.xOff = - 0.5 / plotTree.totalW #从0开始会偏右  
    print  plotTree.xOff  
    # 决策树的起始纵坐标  
    plotTree.yOff = 1.0  
    # 绘制决策树  
    plotTree(inTree,(0.5,1.0),'')  
    # 显示图像  
    plt.show()  

# 预定义的树,用来测试  
def retrieveTree(i):  
    listOfTree = [{'no surfacing':{ 0:'no',1:{'flippers': \  
                       {0:'no',1:'yes'}}}},  
                   {'no surfacing':{ 0:'no',1:{'flippers': \  
                    {0:{'head':{0:'no',1:'yes'}},1:'no'}}}}  
                  ]  
    return listOfTree[i]  
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值