第三章 决策树

本文深入介绍了决策树的概念、构造过程和优缺点,包括如何使用信息增益选择最佳划分特征,递归构建决策树,并使用`Matplotlib`绘制树形图。此外,还讨论了测试和存储分类器的方法,以及应用决策树预测隐形眼镜类型的案例。
摘要由CSDN通过智能技术生成

决策树:本章内容

  • 决策树简介
  • 在数据集中度量一致性
  • 使用递归构造决策树
  • 使用Matplotlib绘制树形图

一、决策树简介

  • 如下图所示的流程图就是一个决策树。正方形代表判断模块;椭圆形代表终止模块,可以已经得出结论,可以终止运行。从判断模块引出的左右箭头称作分支,他可以到达另一个判断模块或者终止模块
  • 第二章介绍的k近邻可以完成很多的分类任务,但是其最大的缺点是无法给出数据的内在含义,决策树的主要优势在于数据形式非常容易理解

image-20210602153805671

二、决策树的构造

  • 优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据

  • 缺点:可能会产生过度匹配问题

  • 使用数据类型:数值型和标称型

  • 在构造决策树时,需要解决的第一个问题是:当前数据集上哪个特征在划分数据分类时起决定性作用。为了找到决定性的特征,划分最好的结果,我们必须评估每个特征。完成测试之后,原始数据集就被划分为几个数据子集。这些数据子集会分布在第一个决策点的所有分支上。如果某个分支下的数据属于同一类型,则当前分类无需进一步分割;如果数据子集内的数据不属于同一类型,则需要重复划分数据子集的过程。如何划分数据子集的算法和划分原始数据集的方法相同,直到所有具有相同类型的数据均在一个数据子集内

  • 创建分支的伪代码如下:

检测数据集中的每个子项是否属于同一个分支createBranch:
	if so return 类标签;
	else
		寻找划分数据集的最好特征
		划分数据集
		创建分支节点
			for 每个划分的子集
				递归调用函数createBranch并增加返回结果到分支节点中
		return 分支节点
  • 决策树的一般流程
    • 收集数据:
    • 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化
    • 分析数据:构造树之后,检查图形是否符合预期
    • 训练算法:构造树的数据结构
    • 测试算法:使用经验树计算错误率
    • 使用算法:
划分数据
  • 二分法
  • ID3算法
信息增益
  • 划分数据的大原则是:将无序的数据数据变得更加有序。

  • 组织杂乱无章数据的一种方法就是使用信息论度量信息,信息论是量化处理信息的分支学科。可以在划分数据之前使用信息论量化度量信息的内容

  • 在划分数据集之前之后信息发生的变化称为信息增益,知道如何计算信息增益,就可以计算每个特征值划分数据集获得的信息增益。获得信息增益最高的特征就是最好的选择

  • 集合信息的度量方式称为香农熵或者简称为熵

    • 熵定义为信息的期望值,在明晰这个概念之前,需要知道如何定义信息
      • 如果待分类的事务可能划分在多个分类中,则符号 x i {x_i} xi的信息定义为: l ( x i ) = − l o g 2 p ( x i ) l(x_i)=-log_2p(x_i) l(xi)=log2p(xi),其中 p ( x i ) {p(x_i)} p(xi)是选择该分类的概率
    • 为了计算熵,需要计算所有类别所有可能值包含的信息期望值,通过下面的公式计算得到: H = − ∑ i = 1 n p ( x i ) l o g 2 p ( x i ) {H=-\sum_{i=1}^np(x_i)log_2p(x_i)} H=i=1np(xi)log2p(xi)。其中n是分类的数目
  • 使用Python计算信息熵

from math import *

# 计算给定数据集的香农熵
def calcShannonEnt(dataSet):
    numEntries=len(dataSet)
    labelCounts={}
    for featVec in dataSet: # 为所有可能分类创建字典
        currentLabel=featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    shannonEnt=0
    for key in labelCounts:
        prob=float(labelCounts[key]*1.0/numEntries)
        shannonEnt-=prob*log(prob,2)
    return shannonEnt

def createDataSet():
    dataSet=[
        [1,1,'yes'],
        [1,1,'yes'],
        [1,0,'no'],
        [0,1,'no'],
        [0,1,'no']
    ]
    labels=['no surfacing','flippers']
    return dataSet,labels


if __name__ == '__main__':
    # 测试1:计算香农熵
    myDat,labels=createDataSet()
    shannoEnt=calcShannonEnt(myDat)
    print(shannoEnt)
  • 熵越高,则混合的数据越多
划分数据集
  • 分类算法除了需要测量信息熵,还需要划分数据集,度量花费数据集的熵,以便判断当前是否是正确的划分了数据集。
  • 我们将对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式
def splitDataSet(dataSet,axis,value):
    '''

    :param dataSet: 待划分的数据集
    :param axis: 划分数据集的特征
    :param value: 特征的返回值
    :return:
    '''
    retDataSet=[] #创建新的list对象
    for featVec in dataSet:
        if(featVec[axis]==value): #抽取:当我们按照某个特征划分数据集时,就需要将所有符合要去的元素抽取出来
            reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

  • 选择最好的方式划分数据集
def chooseBestFeatureToSplit(dataSet):
    '''
    实现选取特征,划分数据集,计算得出最好的划分数据集的特征
    该函数使用对数据的格式有一定的要求:
    :param dataSet:
    :return:
    '''
    numFeatures=len(dataSet[0])-1
    baseEntropy=calcShannonEnt(dataSet)
    bestInfoGain=0.0
    bestFeature=-1
    for i in range(numFeatures):
        featList=[example[i] for example in dataSet]
        uniqueVals=set(featList)  #创建唯一的分类标签列表
        newEntropy=0.0
        for value in uniqueVals: #计算每种划分方式的信息熵
            subDataSet=splitDataSet(dataSet,i,value)
            prob=len(subDataSet)/float(len(dataSet))
            newEntropy+=prob*calcShannonEnt(subDataSet)
        infoGain=baseEntropy-newEntropy
        if(infoGain>bestInfoGain): # 计算最好的信息增益
            bestInfoGain=infoGain
            bestFeature=i
        print("使用特征:{} 进行划分,信息熵为:{}".format(i,newEntropy))
    return bestFeature
    
    
print(chooseBestFeatureToSplit(myDat))
print(myDat)
  • 上面的代码实现了选取特征、划分数据集,计算得出最好的划分数据集的特征。在函数中调用的数据需要满足一定的需求:第一个要去是,数据必须是一种由列表元素组成的列表,而且所有的列表元素都要具有相同的数据长度;第二个要求是,数据的最后一列或者每个实例的最后一个元素是当前实例的类别标签。数据集一旦满足上述要求,我们就可以在函数的第一行判定当前数据集包含多少特征属性

  • 上面的代码中,开始部分计算了原始数据的香农熵,保存最初的无序度量值,用于与划分之后的数据集计算的熵值进行比较

  • 第一个for循环遍历数据集中的所有特征。将数据集中所有第i个特征值或者所有可能存在的值写入新的list

  • 遍历当前特征中的所有唯一属性,对每个特征划分一次数据集,然后计算数据集的新熵值,并对所有唯一的特征值取得的熵求和。信息增益是熵的减少或者数据无序度的减少。

  • 最后,比较所有特征中的信息增益,返回最好特征划分的索引值

  • 针对以上的数据和划分,其运行结果为

0
递归构建决策树
  • 工作原理为:
    • 得到原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据将被向下传递到树分支的下一个节点。在这个节点上,我们可以再次划分数据。因此可以采用递归的原则处理数据集
    • 递归终结的条件是:程序遍历完所有划分数据集的属性,或者每个分支下的每个实例都具有相同的分类。如果所有的实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据必然属于叶子节点的分类
  • 对于类标签不是唯一的,此时需要决定如何定义该叶子节点,这种情况下,通常会采用多数表决(投票)的方法决定该叶子节点的分类
def majorityCnt(classList):
    '''
    对于类标签不是唯一的,此时需要决定如何定义该叶子节点,这种情况下,通常会采用多数表决的方法决定该叶子节点的分类
    :param classList:
    :return:
    '''
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():classCount[vote]=0
        classCount[vote]+=1
    sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]
  • 创建树的代码如下:
def createTree(dataSet,labels):
    classList=[example[-1] for example in dataSet]
    if classList.count(classList[0])==len(classList):
        return classList[0] #类别完全相同则停止继续划分
    if len(dataSet[0])==1:
        return majorityCnt(classList) #遍历完所有特征时返回出现次数最多的
    bestFeat=chooseBestFeatureToSplit(dataSet)
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    for value in uniqueVals:
        subLabels=labels[:]
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree

myTree=createTree(myDat,labels)
print(myTree)
  • 上面创建树的代码中,列表变量classList包含了数据集的所有类标签。递归函数的第一个停止条件是所有的类标签完全相同,则直接返回该类标签;递归函数的第二个停止条件是使用完了所有特征,仍然不能将数据集划分为仅包含唯一类别的分组。由于第二个条件无法简单的返回唯一的类标签,这里所使用上面写过的投票的方法选择出现次数最多的类别作为返回值
  • 下一步程序开始创建树,使用字典类型存储树
使用Matplotlib注解绘制树形图
  • 使用Matplotlib的注解功能绘制树形图
# 使用文本注解绘制树节点
import matplotlib.pyplot as plt
'''定义文本框和箭头格式'''
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")

'''绘制带箭头的注解'''
def plotNode(nodeText,centerPt,parentPt,nodeType):
    createPlot.axl.annotate(nodeText,xy=parentPt,xycoords="axes fraction",xytext=centerPt,textcoords="axes fraction",
                            va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)

def createPlot():
    fig=plt.figure(1,facecolor="white")
    fig.clf()
    createPlot.axl=plt.subplot(111,frameon=False)
    plotNode('a decision node',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('a deaf node',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

createPlot()    

代码运行结果如下图所示:

image-20210605202852270

构造注解树
  • 构造一棵完整的树,必须知道有多少了叶节点,以便可以正确确定x轴的长度;需要知道树的高度,以便可以正确确定y轴的高度
# 构造注解树
def getNumLeafs(myTree):
    numLeafs=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs+=getNumLeafs(secondDict[key])
        else:numLeafs+=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else:thisDepth=1
        if thisDepth>maxDepth:maxDepth=thisDepth
    return maxDepth

def retrieveTree(i):
    listOfTrees=[
        {'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
        {'no surfacing':{0:'no',1:{'flippers':{0:{"head":{0:'no',1:'yes'}},1:'no'}}}},

    ]
    return listOfTrees[i]


def plotMidText(cntrPt,parentPt,txtString):
    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
    createPlot.axl.text(xMid,yMid,txtString)

def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=list(myTree.keys())[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict =myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
def createPlot(inTree):
    fig=plt.figure(1,facecolor="white")
    fig.clf()
    axprops=dict(xticks=[],yticks=[])
    createPlot.axl=plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.xOff=-0.5/plotTree.totalW
    plotTree.yOff=1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()
    
    
myTree=retrieveTree(0)
createPlot(myTree)
myTree['no surfacing'][3]='maybe'
print(myTree)
createPlot(myTree)    

image-20210606143405800

image-20210606143430116

三、测试和存储分类器

测试算法:使用决策树执行分类
  • 依靠训练数据构造了决策树以后,我们就可以将其用于实际数据的分类。在执行数据分类时,需要决策树以及用于构造树的标签向量。然后程序比较测试数据与决策树上的数值,递归执行该过程直到进入叶子节点;最后将测试数据定义为叶子节点所属类型

  • 下面代码为使用决策树的分类函数

def classify(InputTree,featLabels,testVec):
    firstStr=list(InputTree.keys())[0]
    secondDict=InputTree[firstStr]
    featIndex=featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex]==key:
            if type(secondDict[key]).__name__=='dict':
                classLabel=classify(secondDict[key],featLabels,testVec)
            else:classLabel=secondDict[key]
    return classLabel

myDat,labels=createDataSet()
print(labels)
myTree=retrieveTree(0)
print(myTree)
res=classify(myTree,labels,[1,0])
print(res)
print(classify(myTree,labels,[1,1]))
使用算法:决策树的存储
  • 构造决策树是非常耗时的任务,因此构造好决策树以后,需要序列化保存到磁盘中
  • 使用pickle模块存储决策树
def storeTree(inputTree,filename):
    import pickle
    fw=open(filename,'wb')
    pickle.dump(inputTree,fw)
    fw.close()
def grabTree(filename):
    import pickle
    fr=open(filename,"rb")
    return pickle.load(fr)

四、示例:使用决策树预测隐形眼镜类型

  • 隐形眼镜数据集包含很多患者眼部状况的观察条件以及医生推荐的隐形眼镜类型
  • 加载数据和可视化决策树
def loadLenses(filename):
    fr=open(filename)
    lenses=[inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels=['age','prescript','astigmatic','tearRate']
    lensesTree=createTree(lenses,lensesLabels)
    print(lensesTree)
    createPlot(lensesTree)
    
    
loadLenses('lenses.txt')    

image-20210606150901488

五、附录:代码

from math import *
import operator
# 计算给定数据集的香农熵
def calcShannonEnt(dataSet):
    numEntries=len(dataSet)
    labelCounts={}
    for featVec in dataSet: # 为所有可能分类创建字典
        currentLabel=featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    shannonEnt=0
    for key in labelCounts:
        prob=float(labelCounts[key]*1.0/numEntries)
        shannonEnt-=prob*log(prob,2)
    return shannonEnt

def createDataSet():
    dataSet=[
        [1,1,'yes'],
        [1,1,'yes'],
        [1,0,'no'],
        [0,1,'no'],
        [0,1,'no']
    ]
    labels=['no surfacing','flippers']
    return dataSet,labels

def splitDataSet(dataSet,axis,value):
    '''

    :param dataSet: 待划分的数据集
    :param axis: 划分数据集的特征
    :param value: 特征的返回值
    :return:
    '''
    retDataSet=[] #创建新的list对象
    for featVec in dataSet:
        if(featVec[axis]==value): #抽取:当我们按照某个特征划分数据集时,就需要将所有符合要去的元素抽取出来
            reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

def chooseBestFeatureToSplit(dataSet):
    '''
    实现选取特征,划分数据集,计算得出最好的划分数据集的特征
    该函数使用对数据的格式有一定的要求:
    :param dataSet:
    :return:
    '''
    numFeatures=len(dataSet[0])-1
    baseEntropy=calcShannonEnt(dataSet)
    bestInfoGain=0.0
    bestFeature=-1
    for i in range(numFeatures):
        featList=[example[i] for example in dataSet]
        uniqueVals=set(featList)  #创建唯一的分类标签列表
        newEntropy=0.0
        for value in uniqueVals: #计算每种划分方式的信息熵
            subDataSet=splitDataSet(dataSet,i,value)
            prob=len(subDataSet)/float(len(dataSet))
            newEntropy+=prob*calcShannonEnt(subDataSet)
        infoGain=baseEntropy-newEntropy
        if(infoGain>bestInfoGain): # 计算最好的信息增益
            bestInfoGain=infoGain
            bestFeature=i
        print("使用特征:{} 进行划分,信息熵为:{}".format(i,newEntropy))
    return bestFeature

def majorityCnt(classList):
    '''
    对于类标签不是唯一的,此时需要决定如何定义该叶子节点,这种情况下,通常会采用多数表决的方法决定该叶子节点的分类
    :param classList:
    :return:
    '''
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():classCount[vote]=0
        classCount[vote]+=1
    sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet,labels):
    classList=[example[-1] for example in dataSet]
    if classList.count(classList[0])==len(classList):
        return classList[0] #类别完全相同则停止继续划分
    if len(dataSet[0])==1:
        return majorityCnt(classList) #遍历完所有特征时返回出现次数最多的
    bestFeat=chooseBestFeatureToSplit(dataSet)
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    for value in uniqueVals:
        subLabels=labels[:]
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree

# 使用文本注解绘制树节点
import matplotlib.pyplot as plt
'''定义文本框和箭头格式'''
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")

'''绘制带箭头的注解'''
def plotNode(nodeText,centerPt,parentPt,nodeType):
    createPlot.axl.annotate(nodeText,xy=parentPt,xycoords="axes fraction",xytext=centerPt,textcoords="axes fraction",
                            va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)

def createPlot():
    fig=plt.figure(1,facecolor="white")
    fig.clf()
    createPlot.axl=plt.subplot(111,frameon=False)
    plotNode('a decision node',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('a deaf node',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()



# 构造注解树
def getNumLeafs(myTree):
    numLeafs=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs+=getNumLeafs(secondDict[key])
        else:numLeafs+=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else:thisDepth=1
        if thisDepth>maxDepth:maxDepth=thisDepth
    return maxDepth

def retrieveTree(i):
    listOfTrees=[
        {'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
        {'no surfacing':{0:'no',1:{'flippers':{0:{"head":{0:'no',1:'yes'}},1:'no'}}}},

    ]
    return listOfTrees[i]


def plotMidText(cntrPt,parentPt,txtString):
    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
    createPlot.axl.text(xMid,yMid,txtString)

def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=list(myTree.keys())[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict =myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
def createPlot(inTree):
    fig=plt.figure(1,facecolor="white")
    fig.clf()
    axprops=dict(xticks=[],yticks=[])
    createPlot.axl=plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.xOff=-0.5/plotTree.totalW
    plotTree.yOff=1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

def classify(InputTree,featLabels,testVec):
    firstStr=list(InputTree.keys())[0]
    secondDict=InputTree[firstStr]
    featIndex=featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex]==key:
            if type(secondDict[key]).__name__=='dict':
                classLabel=classify(secondDict[key],featLabels,testVec)
            else:classLabel=secondDict[key]
    return classLabel

def storeTree(inputTree,filename):
    import pickle
    fw=open(filename,'wb')
    pickle.dump(inputTree,fw)
    fw.close()
def grabTree(filename):
    import pickle
    fr=open(filename,"rb")
    return pickle.load(fr)

def loadLenses(filename):
    fr=open(filename)
    lenses=[inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels=['age','prescript','astigmatic','tearRate']
    lensesTree=createTree(lenses,lensesLabels)
    print(lensesTree)
    createPlot(lensesTree)
if __name__ == '__main__':
    # 测试1:计算香农熵
    # myDat,labels=createDataSet()
    # shannoEnt=calcShannonEnt(myDat)
    # print(shannoEnt)

    # 测试2:测试划分数据集划分
    # splitdata=splitDataSet(myDat,0,1)
    # print(splitdata)

    #测试3:测试选择最好的数据划分方式
    # print(chooseBestFeatureToSplit(myDat))
    # print(myDat)

    #测试4:创建树
    # myTree=createTree(myDat,labels)
    # print(myTree)

    # 测试绘制树
    # createPlot()

    # 测试get
    # mytrees=retrieveTree(0)
    # print(mytrees)
    # print(getNumLeafs(mytrees))
    # print(getTreeDepth(mytrees))

    # 测试绘制完整的树
    # myTree=retrieveTree(0)
    # createPlot(myTree)
    # myTree['no surfacing'][3]='maybe'
    # print(myTree)
    # createPlot(myTree)

    # 测试算法
    # myDat,labels=createDataSet()
    # print(labels)
    # myTree=retrieveTree(0)
    # print(myTree)
    # res=classify(myTree,labels,[1,0])
    # print(res)
    # print(classify(myTree,labels,[1,1]))

    #测试保存树
    # storeTree(myTree,"classify.txt")
    # print(grabTree("classify.txt"))

    #测试示例
    loadLenses('lenses.txt')

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值