# Chapter 3: Decision Trees

#第三章 决策树

from math import log
import operator
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif']=['SimHei'] # display Chinese labels (SimHei font)
plt.rcParams['axes.unicode_minus']=False   # these two rcParams lines have to be set manually


# Compute the Shannon entropy of a data set:
#   H(X) = -sum_i p(x_i) * log2(p(x_i))
def calcShannonEnt(dataset):
    """Return the Shannon entropy (base 2) of the class labels in *dataset*.

    The last element of every row is taken as that row's class label.

    Args:
        dataset: sequence of rows (indexable sequences); must support len().

    Returns:
        The entropy as a float. An empty data set yields 0.0.
    """
    numEntries = len(dataset)
    if numEntries == 0:
        # No rows means no uncertainty; also avoids dividing by zero below.
        return 0.0

    # Count how many rows carry each class label.
    labelCounts = {}
    for featVec in dataset:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1

    # Accumulate -p * log2(p) once, after all labels have been counted.
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

def createDataSet():
    """Return the small sample data set and the names of its two features.

    Each row is [feature 0, feature 1, class label]; the feature names are
    returned in column order.
    """
    featureNames = ['no surfacing', 'flippers']
    rows = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    return rows, featureNames

# Split the data set on a given feature.
def splitDataSet(dataSet, axis, value):
    """Return the rows whose column *axis* equals *value*, with that column removed.

    Args:
        dataSet: the data set to split (a sequence of list rows).
        axis: index of the feature column to split on.
        value: the feature value a row must have to be kept.

    Returns:
        A new list of new rows; the input rows are not modified.
    """
    subset = []
    matchingRows = (row for row in dataSet if row[axis] == value)
    for row in matchingRows:
        # Everything before the split column, then everything after it.
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        subset.append(trimmed)
    return subset

# Choose the best way to split the data set.
def chooseBestFeatureToSplit(dataset):
    """Return the index of the feature whose split yields the largest information gain.

    The last column of every row is the class label; all earlier columns are
    candidate features.

    Args:
        dataset: non-empty list of rows.

    Returns:
        The column index of the best feature, or -1 when no feature gives a
        strictly positive information gain (including when there are no
        feature columns).
    """
    numFeatures = len(dataset[0]) - 1
    # Entropy of the whole data set before any split.
    baseEntropy = calcShannonEnt(dataset)
    totalRows = float(len(dataset))  # loop-invariant, so computed once
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # Distinct values taken by feature i.
        uniqueVals = {example[i] for example in dataset}
        # Weighted entropy of the partitions produced by splitting on feature i.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataset, i, value)
            prob = len(subDataSet) / totalRows
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        # Keep the feature with the largest gain seen so far.
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the class label that occurs most often in *classList*.

    When several labels share the highest count, the one that appears first
    in *classList* is returned.
    """
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    # Stable sort by count, highest first.
    ranked = sorted(tally.items(), key=operator.itemgetter(1), reverse=True)
    winner = ranked[0]
    return winner[0]

# Build the decision tree.
def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dicts.

    Args:
        dataSet: non-empty list of rows; the last column is the class label.
        labels: feature names, one per feature column. This list is NOT
            modified, so the caller can keep using it afterwards.

    Returns:
        Either a class label (leaf) or a dict of the form
        {feature_name: {feature_value: subtree_or_label, ...}}.
    """
    classList = [example[-1] for example in dataSet]
    # Stop 1: every row has the same class, so return that class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop 2: no features left but classes still differ, so take a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)
    # Stop 3: no feature separates the classes (the chooser returned -1).
    # Indexing with -1 would split on the class column itself, so fall back
    # to a majority vote instead.
    if bestFeat < 0:
        return majorityCnt(classList)

    bestFeatLabel = labels[bestFeat]
    # Build the remaining feature names without mutating the caller's list.
    subLabels = list(labels[:bestFeat]) + list(labels[bestFeat + 1:])
    myTree = {bestFeatLabel: {}}
    # One branch per distinct value of the chosen feature.
    uniqueVals = {example[bestFeat] for example in dataSet}
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

# Tree nodes are drawn with text annotations.
# Text-box and arrow styles used by those annotations.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
leafNode = {"boxstyle": "round4", "fc": "0.8"}
arrow_args = {"arrowstyle": "<-"}
# Draw an annotation with an arrow (earlier draft, kept commented out):
# def plotNode(nodeTxt,centerPt,parentPt,nodeType):
#     createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',
#                             va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)

# def createPlot():
#     fig=plt.figure(1,facecolor='white')
#     fig.clf()
#     createPlot.ax1=plt.subplot(111,frameon=False)
#     plotNode(U'决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
#     plotNode(U'叶节点',(0.8,0.1),(0.3,0.8),leafNode)
#     plt.show()

# Count the leaf nodes of a tree.
def getNumLeafs(myTree):
    """Return the number of leaves in *myTree*.

    *myTree* is a dict whose first key is the root feature name and whose
    value maps feature values to subtrees (dicts) or leaves (anything else).
    """
    rootLabel = list(myTree)[0]
    branches = myTree[rootLabel]
    # A dict child is a subtree to recurse into; any other child is one leaf.
    return sum(
        getNumLeafs(branches[key]) if type(branches[key]).__name__ == 'dict' else 1
        for key in branches.keys()
    )

# Count the levels of a tree.
def getTreeDepth(myTree):
    """Return the depth of *myTree*, counted in decision levels.

    A root whose children are all leaves has depth 1; a root with no
    children has depth 0.
    """
    rootLabel = list(myTree)[0]
    branches = myTree[rootLabel]
    # A dict child is itself a decision node, so recurse and add one level;
    # a leaf child contributes exactly one level.
    branchDepths = [
        1 + getTreeDepth(branches[key]) if type(branches[key]).__name__ == 'dict' else 1
        for key in branches.keys()
    ]
    return max(branchDepths, default=0)

# Return one of the pre-built sample trees.
def retrieveTree(i):
    """Return the *i*-th hard-coded sample tree (a fresh object on every call)."""
    flippersLeafOnly = {'flippers': {0: 'no', 1: 'yes'}}
    flippersThenHead = {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: flippersLeafOnly}},
        {'no surfacing': {0: 'no', 1: flippersThenHead}},
    ]
    return listOfTrees[i]

# Helpers for plotTree.
# Write the branch text between a parent node and its child.
def plotMidText(cntrPt, parentPt, txtString):
    """Draw *txtString* at the midpoint between *cntrPt* and *parentPt*.

    Both points are indexable (x, y) pairs; the text is placed on
    ``createPlot.ax1``, which createPlot() must have set beforehand.
    """
    midpoint = [
        (parentPt[axis] - cntrPt[axis]) / 2.0 + cntrPt[axis]
        for axis in (0, 1)
    ]
    createPlot.ax1.text(midpoint[0], midpoint[1], txtString)
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one tree node as a boxed annotation with an arrow from its parent.

    Args:
        nodeTxt: text shown inside the node's box.
        centerPt: (x, y) of the node, in axes-fraction coordinates.
        parentPt: (x, y) of the parent node, in axes-fraction coordinates.
        nodeType: bbox style dict for this node (e.g. decisionNode or
            leafNode), so decision nodes and leaves are drawn differently.

    Draws on ``createPlot.ax1``, which createPlot() must have set beforehand.
    """
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt, xycoords='axes fraction',
        xytext=centerPt, textcoords='axes fraction',
        va='center', ha='center',
        # Honour the caller's box style and the shared arrow style rather
        # than hard-coding one look for every node.
        bbox=nodeType, arrowprops=arrow_args,
    )


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw *myTree* below *parentPt*.

    Layout state is kept in function attributes that createPlot() initialises:
    ``plotTree.totalW`` / ``plotTree.totalD`` (leaf count and depth of the
    whole tree) and ``plotTree.xOff`` / ``plotTree.yOff`` (the current drawing
    cursor).

    Args:
        myTree: nested-dict tree whose first key is the root feature name.
        parentPt: (x, y) of the parent node, in axes-fraction coordinates.
        nodeTxt: text for the branch leading into this subtree.
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree.keys())[0]
    # Centre this decision node horizontally above the leaves it spans.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step one level down before drawing the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict:
        child = secondDict[key]
        # Same subtree test as getNumLeafs/getTreeDepth so the layout agrees
        # with the counts it is based on.
        if type(child).__name__ == 'dict':
            plotTree(child, cntrPt, str(key))
        else:
            # Leaves are laid out left to right, one slot each.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the vertical cursor for the caller's remaining siblings.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    """Draw *inTree* in matplotlib figure 1 and show it.

    Sets ``createPlot.ax1`` (the axes the node helpers draw on) and the
    ``plotTree`` layout attributes before starting the recursive drawing.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    # Frameless axes without tick marks.
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    leafCount = float(getNumLeafs(inTree))
    plotTree.totalW = leafCount
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot to the left of the axes; top of the axes in y.
    plotTree.xOff = -0.5 / leafCount
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

# Classify a sample with a decision tree.
def classify(inputTree, featLabels, testVec):
    """Return the class label that *inputTree* assigns to *testVec*.

    Args:
        inputTree: nested-dict tree whose first key is the root feature name.
        featLabels: feature names, in the same column order as *testVec*.
        testVec: feature values of the sample to classify.

    Returns:
        The label stored at the leaf reached by following *testVec*.

    Raises:
        ValueError: if the root feature name is not in *featLabels*, or if
            the sample's value for a feature has no branch in the tree.
    """
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    # Column of the test vector that holds this node's feature.
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]
    for key, subtree in secondDict.items():
        if featValue == key:
            # A dict child is another decision node; anything else is a leaf.
            if isinstance(subtree, dict):
                return classify(subtree, featLabels, testVec)
            return subtree
    # No branch matched: report it instead of failing on an unbound local.
    raise ValueError(
        'no branch for value %r of feature %r' % (featValue, firstStr))

# Persisting the decision tree.
# The tree is stored with the pickle module.
def storeTree(inputTree, filename):
    """Serialize *inputTree* to *filename* with pickle (protocol 0).

    Args:
        inputTree: the tree (any picklable object) to store.
        filename: path of the file to create or overwrite.
    """
    import pickle
    # The with-block closes the file even if pickle.dump raises.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw, 0)

def grabTree(filename):
    """Load and return the object pickled in *filename* (e.g. by storeTree).

    Args:
        filename: path of a pickle file.

    Returns:
        The unpickled object.
    """
    import pickle
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # Pickle data must be read in binary mode on Python 3, and the with-block
    # makes sure the file is closed afterwards.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

if __name__=='__main__':
    # Demo driver: the commented-out calls below exercise the individual steps.
    myDat,labels=createDataSet()
    # myDat[0][-1]='maybe'
    # print(myDat)
    # print(calcShannonEnt(myDat))
    # print(splitDataSet(myDat,0,1))
    # print(splitDataSet(myDat, 0, 0))
    # print(chooseBestFeatureToSplit(myDat))
    # print(myDat)
    # myTree=createTree(myDat,labels)
    # print(myTree)


    myTree=retrieveTree(0)

    # myTree['no surfacing'][3]='maybe'
    print(myTree)
    # print(getNumLeafs(myTree))
    # print(getTreeDepth(myTree))
    # # createPlot(myTree)
    # print(classify(myTree,labels,[1,0]))
    # print(classify(myTree,labels,[1,1]))

    # Store and reload through the SAME path so the round trip reads the file
    # that was just written, then print the reloaded tree (storeTree itself
    # returns nothing worth printing).
    treeFile='../MLinAction_source/classifierStorage.txt'
    storeTree(myTree,treeFile)
    print(grabTree(treeFile))
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值