决策树_Python3实现代码及注释

一.决策树的构造及绘制整体代码

数据表:

							海洋生物数据
							
		不浮出水面是否可以生存		是否有脚蹼		属于鱼类		
1				是						是				是
2				是						是				是
3				是						否				否
4				否						是				否
5				否						是				否

							
		no surfacing  		flippers		fish		
1			1					1			 yes
2			1					1			 yes
3			1					0			 no
4			0					1			 no
5			0					1			 no

代码0:

from math import log

def createDataSet(): #构建数据集
    dataSet = [[1, 1, 'yes'],
              [1, 1, 'yes'],
              [1, 0, 'no'],
              [0, 1, 'no'],
              [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

def calcShannonEnt(dataSet): #计算香农熵
    numEntries = len(dataSet) #数据集中数据的个数
    labelCounts = {} #记录各分类个数的字典
    for featVec in dataSet: #对数据集中的每个向量
        currentLabel = featVec[-1] #当前标记为当前向量的最后一位
        if currentLabel not in labelCounts.keys(): #如果当前标记尚未收录
            labelCounts[currentLabel] = 0 #初始化为0
        labelCounts[currentLabel] += 1 #已存在则计数加1
    shannonEnt = 0.0 #香农熵初始化为0
    for key in labelCounts: #对各种标记
        prob = float(labelCounts[key])/numEntries #把该标记出现的次数与总数的比看作概率
        shannonEnt -= prob*log(prob,2) # H = -∑P(xi)*logP(xi)
    return shannonEnt

def splitDataSet(dataSet, axis, value): #划分数据集 #待划分的数据集,划分数据集的特征,需要返回的特征的值
    retDataSet = [] #划分后的数据集列表
    for featVec in dataSet:
        if featVec[axis] == value: #如果当前数据此特征的值恰好为要返回的值
            reducedFeatVec = featVec[:axis] #axis前的元素
            reducedFeatVec.extend(featVec[axis+1:]) #axis之后的元素 #.extend()是把列表中元素逐个加入 #这两步即删除axis
            retDataSet.append(reducedFeatVec) #处理过的向量加入划分后的数据集列表 #.append()是把列表作为元素加入
    return retDataSet

def chooseBestFeatureToSplit(dataSet): #选择最好的方式划分数据集
    numFeatures = len(dataSet[0]) - 1 #特征的个数为向量长度减1 #减去标记项
    baseEntropy = calcShannonEnt(dataSet) #作为比较的基准熵为数据集本身的熵
    bestInfoGain = 0.0; bestFeature = -1 #最大信息增益初始化为0,最佳划分特征编号初始化为-1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet] #数据集中每个数据i号特征构成的列表
        uniqueVals = set(featList) #把列表转换为集合 #即剔除其中重复值
        newEntropy = 0.0 #用i号特征划分数据集得到的熵初始化为0
        for value in uniqueVals: #对i号特征各个不同的值
            subDataSet = splitDataSet(dataSet, i, value) #选出i号特征为当前值的数据
            prob = len(subDataSet)/float(len(dataSet)) #i号特征当前值出现的概率
            newEntropy += prob*calcShannonEnt(subDataSet) #以i号特征作为划分的熵加上当前值出现的概率与i号特征为当前值的数据集的熵的乘积
        infoGain = baseEntropy - newEntropy #以i号特征划分得到的信息增益 #即基准熵减去当前熵
        if infoGain>bestInfoGain: #如果当前信息增益大于当前最大信息增益
            bestInfoGain = infoGain #更新最大信息增益
            bestFeature = i #最佳划分方式特征编号更新为i
        return bestFeature

def majorityCnt(classList): #投票
    classCount = {} #存储各分类结果的票数
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=lambda item:item[1], reverse=True)
    #字典元素按键值倒序排序 #sorted(待排元素,排序关键字,是否倒序) #lambda item:item[1]取每个元素编号为1的位置,即value值
    return sortedClassCount[0][0] #输出排序后的第一个元素的key值 #即出现次数最多的标记值

def createTree(dataSet, labels): #创建树 #递归函数,数据集不断划分不断调用
    classList = [example[-1] for example in dataSet] #当前数据集所有的标记值
    if classList.count(classList[0])==len(classList): #递归结束条件之一:第一个标记值出现次数为列表长度,即所有类别完全相同
        return classList[0] #直接返回该标记值即可
    if len(dataSet[0])==1: #递归结束条件之二:数据集中数据向量只剩标记值一项,即所有特征都已用过但类别仍不完全相同 #若相同则之前已返回
        return majorityCnt(classList) #调用投票函数,返回投票结果
    bestFeat = chooseBestFeatureToSplit(dataSet) #最佳划分特征的编号
    bestFeatLabel = labels[bestFeat] #最佳划分特征
    myTree = {bestFeatLabel:{}} #第一个选中的特征作为根节点 #其后依然是树,故用字典嵌套
    del(labels[bestFeat]) #删除特征集中已用过的特征
    featValues = [example[bestFeat] for example in dataSet] #数据集中各个数据此特征值组成的列表
    uniqueVals = set(featValues) #集合化,去除相同值
    for value in uniqueVals:
        subLabels = labels[:] #特征集已删除了选中特征,把现在的特征集作为递归调用的特征集
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value), subLabels)
        #树中当前特征当前值对应的子树为递归调用数据集如此划分后创建的树
    return myTree

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth",fc="0.8") #决策节点的格式
leafNode = dict(boxstyle="round4",fc="0.8") #叶节点的格式
arrow_args = dict(arrowstyle="<-") #箭头格式

def plotNode(nodeTxt, centerPt, parentPt, nodeType): #节点文本,节点坐标,父节点坐标,节点类型
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', \
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args) #绘制带箭头的注解 
    #关于pyplot.annatate()可参考https://blog.csdn.net/u013457382/article/details/50956459                           

def getNumLeafs(myTree): #获取叶节点数目
    numLeafs = 0 #叶节点数初始化为0
    firstStr = list(myTree.keys())[0] #第一个节点为树的第一个键值
    secondDict = myTree[firstStr] #第一个key对应的value为其子树
    for key in secondDict.keys(): #对子树的每个孩子节点
        if type(secondDict[key])==dict: #如果当前子节点仍有子树
            numLeafs += getNumLeafs(secondDict[key]) #对该子节点递归调用此函数
        else: #否则说明是叶节点
            numLeafs += 1 #叶节点数加1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0] #第一个节点为树的第一个键值
    secondDict = myTree[firstStr] #第一个key对应的value为其子树
    for key in secondDict.keys(): #对子树的每个孩子节点
        if type(secondDict[key])==dict: #如果当前子节点仍有子树
            thisDepth = 1 + getTreeDepth(secondDict[key]) 
        else: #否则说明是叶节点
            thisDepth = 1 
        if thisDepth > maxDepth : maxDepth = thisDepth
    return maxDepth
    
def plotMidText(cntrPt, parentPt, txtString): #在父子节点间填充文本信息
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] #横坐标中值
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] #纵坐标中值
    createPlot.ax1.text(xMid, yMid, txtString) #在中间位置添加文本
    
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree) #叶节点数
    depth = getTreeDepth(myTree) #树高
    firstStr = list(myTree.keys())[0] #当前树的根节点
    cntrPt = (plotTree.xOff + (1.0+float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys(): #对当前树的每个子树
        if type(secondDict[key])==dict: #如果其仍有子树
            plotTree(secondDict[key], cntrPt, str(key)) #递归调用此函数
        else: #否则为叶节点,直接输出
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #
    
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree)) #宽度为叶节点数
    plotTree.totalD = float(getTreeDepth(inTree)) #高度为树高
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
    plotTree(inTree, (0.5,1.0), '')
    plt.show()
    
dataSet, labels = createDataSet()
myTree = createTree(dataSet, labels)
createPlot(myTree)

代码0运行结果:
在这里插入图片描述

二.构造及绘制的各部分函数

没有写到文件里再调用,是写在Jupyter代码框里直接运行的

1.计算给定数据集的香农熵

代码1.0:

from math import log

def createDataSet(): #构建数据集
    dataSet = [[1, 1, 'yes'],
              [1, 1, 'yes'],
              [1, 0, 'no'],
              [0, 1, 'no'],
              [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

def calcShannonEnt(dataSet):
    numEntries = len(dataSet) #数据集中数据的个数
    labelCounts = {} #记录各分类个数的字典
    for featVec in dataSet: #对数据集中的每个向量
        currentLabel = featVec[-1] #当前标记为当前向量的最后一位
        if currentLabel not in labelCounts.keys(): #如果当前标记尚未收录
            labelCounts[currentLabel] = 0 #初始化为0
        labelCounts[currentLabel] += 1 #已存在则计数加1
    shannonEnt = 0.0 #香农熵初始化为0
    for key in labelCounts: #对各种标记
        prob = float(labelCounts[key])/numEntries #把该标记出现的次数与总数的比看作概率
        shannonEnt -= prob*log(prob,2) # H = -∑P(xi)*logP(xi)
    return shannonEnt

if __name__=='__main__': #关于“__name__=='__main__'”可参考https://blog.konghy.cn/2017/04/24/python-entry-program/
    dataSet, labels = createDataSet()
    print(dataSet)
    print(calcShannonEnt(dataSet))

代码1.0运行结果:

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
0.9709505944546686 

熵越高则混合的数据越多,增加第三个名为maybe的分类测试熵变化:
代码1.1:

dataSet[0][-1] = 'maybe'
print(dataSet)
print(calcShannonEnt(dataSet))

代码1.1运行结果:

[[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
1.3709505944546687

2.按照给定特征划分数据集

在代码1.1中改变了dataSet的值,需要把值变回来:
代码2.0:

dataSet[0][-1] = 'yes'
print(dataSet)

代码2.0运行结果:

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

代码2.1:

def splitDataSet(dataSet, axis, value): #待划分的数据集,划分数据集的特征,需要返回的特征的值
    retDataSet = [] #划分后的数据集列表
    for featVec in dataSet:
        if featVec[axis] == value: #如果当前数据此特征的值恰好为要返回的值
            reducedFeatVec = featVec[:axis] #axis前的元素
            reducedFeatVec.extend(featVec[axis+1:]) #axis之后的元素 #.extend()是把列表中元素逐个加入 #这两步即删除axis
            retDataSet.append(reducedFeatVec) #处理过的向量加入划分后的数据集列表 #.append()是把列表作为元素加入
    return retDataSet

print(dataSet)
print(splitDataSet(dataSet, 0, 0)) #划分出0号位置值为0的数据

代码2.1运行结果:

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
[[1, 'no'], [1, 'no']]

3.选择最好的数据集划分方式

遍历整个数据集,循环计算香农熵和splitDataSet()函数:
代码3:

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1 #特征的个数为向量长度减1 #减去标记项
    baseEntropy = calcShannonEnt(dataSet) #作为比较的基准熵为数据集本身的熵
    bestInfoGain = 0.0; bestFeature = -1 #最大信息增益初始化为0,最佳划分特征编号初始化为-1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet] #数据集中每个数据i号特征构成的列表
        uniqueVals = set(featList) #把列表转换为集合 #即剔除其中重复值
        newEntropy = 0.0 #用i号特征划分数据集得到的熵初始化为0
        for value in uniqueVals: #对i号特征各个不同的值
            subDataSet = splitDataSet(dataSet, i, value) #选出i号特征为当前值的数据
            prob = len(subDataSet)/float(len(dataSet)) #i号特征当前值出现的概率
            newEntropy += prob*calcShannonEnt(subDataSet) #以i号特征作为划分的熵加上当前值出现的概率与i号特征为当前值的数据集的熵的乘积
        infoGain = baseEntropy - newEntropy #以i号特征划分得到的信息增益 #即基准熵减去当前熵
        if infoGain>bestInfoGain: #如果当前信息增益大于当前最大信息增益
            bestInfoGain = infoGain #更新最大信息增益
            bestFeature = i #最佳划分方式特征编号更新为i
        return bestFeature
    
print(chooseBestFeatureToSplit(dataSet))

代码3运行结果:

0

观察数据表可知第一个特征(即0号特征)相比于第二个特征对是否为鱼类有更大的参考价值

4.创建树的函数代码

如果数据集已经处理了所有属性,但类标签依然不是唯一的,用投票方法解决该叶子节点的分类:
代码4.0:

classList = [1, 0, 1, 0, 1] #作为测试,出现次数多的为1,故投票结果应为1

def majorityCnt(classList):
    classCount = {} #存储各分类结果的票数
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=lambda item:item[1], reverse=True)
    #字典元素按键值倒序排序 #sorted(待排元素,排序关键字,是否倒序) #lambda item:item[1]取每个元素编号为1的位置,即value值
    return sortedClassCount[0][0] #输出排序后的第一个元素的key值 #即出现次数最多的标记值

print(majorityCnt(classList))

代码4.0运行结果:

1

代码4.1:

def createTree(dataSet, labels): #递归函数,数据集不断划分不断调用
    classList = [example[-1] for example in dataSet] #当前数据集所有的标记值
    if classList.count(classList[0])==len(classList): #递归结束条件之一:第一个标记值出现次数为列表长度,即所有类别完全相同
        return classList[0] #直接返回该标记值即可
    if len(dataSet[0])==1: #递归结束条件之二:数据集中数据向量只剩标记值一项,即所有特征都已用过但类别仍不完全相同 #若相同则之前已返回
        return majorityCnt(classList) #调用投票函数,返回投票结果
    bestFeat = chooseBestFeatureToSplit(dataSet) #最佳划分特征的编号
    bestFeatLabel = labels[bestFeat] #最佳划分特征
    myTree = {bestFeatLabel:{}} #第一个选中的特征作为根节点 #其后依然是树,故用字典嵌套
    del(labels[bestFeat]) #删除特征集中已用过的特征
    featValues = [example[bestFeat] for example in dataSet] #数据集中各个数据此特征值组成的列表
    uniqueVals = set(featValues) #集合化,去除相同值
    for value in uniqueVals:
        subLabels = labels[:] #特征集已删除了选中特征,把现在的特征集作为递归调用的特征集
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value), subLabels)
        #树中当前特征当前值对应的子树为递归调用数据集如此划分后创建的树
    return myTree

dataSet, labels = createDataSet()
print(dataSet)
print(labels)
print(createTree(dataSet, labels))

代码4.1运行结果:

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
['no surfacing', 'flippers']
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

5.使用文本注解绘制树节点

用Matplotlib注解功能绘制树形图:
代码5:

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth",fc="0.8") #决策节点的格式
leafNode = dict(boxstyle="round4",fc="0.8") #叶节点的格式
arrow_args = dict(arrowstyle="<-") #箭头格式

def plotNode(nodeTxt, centerPt, parentPt, nodeType): #节点文本,节点坐标,父节点坐标,节点类型
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', \
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args) #绘制带箭头的注解 
	#关于pyplot.annatate()可参考https://blog.csdn.net/u013457382/article/details/50956459                           

def createPlot():
    fig = plt.figure(1, facecolor='white') #创建一个新图形,白色
    fig.clf() #清空绘图区
    createPlot.ax1 = plt.subplot(111, frameon=False) #一行一列共一个图此时在绘制第一个图,不绘制边缘
    #关于plt.subplot()可参考https://blog.csdn.net/weixin_40490077/article/details/79523526?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-2&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-2
    plotNode('DecisionNode', (0.5,0.1), (0.1,0.5), decisionNode) #中文显示不出,所以用英文了
    plotNode('LeafNode', (0.8,0.1), (0.3,0.8), leafNode)
    plt.show()

createPlot()

代码5运行结果:
在这里插入图片描述

6.获取叶节点的数目和树的层数

需要知道有多少个叶节点来确定x轴长度,和树有多少层来确定y轴长度:
代码6:

def getNumLeafs(myTree): #获取叶节点数目
    numLeafs = 0 #叶节点数初始化为0
    firstStr = list(myTree.keys())[0] #第一个节点为树的第一个键值 #需要转化为列表才能按下标访问
    secondDict = myTree[firstStr] #第一个key对应的value为其子树
    for key in secondDict.keys(): #对子树的每个孩子节点
        if type(secondDict[key])==dict: #如果当前子节点仍有子树
            numLeafs += getNumLeafs(secondDict[key]) #对该子节点递归调用此函数
        else: #否则说明是叶节点
            numLeafs += 1 #叶节点数加1
    return numLeafs

def getTreeDepth(myTree): #获取树高
    maxDepth = 0 #最大树高初始化为0
    firstStr = list(myTree.keys())[0] #第一个节点为树的第一个键值 #需要转化为列表才能按下标访问
    secondDict = myTree[firstStr] #第一个key对应的value为其子树
    for key in secondDict.keys(): #对子树的每个孩子节点
        if type(secondDict[key])==dict: #如果当前子节点仍有子树
            thisDepth = 1 + getTreeDepth(secondDict[key]) 
        else: #否则说明是叶节点
            thisDepth = 1 #当前树高为1
        if thisDepth > maxDepth : maxDepth = thisDepth #如果当前树高大于最大树高则更新最大树高
    return maxDepth

dataSet, labels = createDataSet()
myTree = createTree(dataSet, labels)
print(getNumLeafs(myTree))
print(getTreeDepth(myTree))

代码6运行结果:

3
2

7.plotTree函数

代码7:

def plotMidText(cntrPt, parentPt, txtString): #在父子节点间填充文本信息
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] #横坐标中值
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] #纵坐标中值
    createPlot.ax1.text(xMid, yMid, txtString) #在中间位置添加文本
    
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree) #叶节点数
    depth = getTreeDepth(myTree) #树高
    firstStr = list(myTree.keys())[0] #当前树的根节点
    cntrPt = (plotTree.xOff + (1.0+float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys(): #对当前树的每个子树
        if type(secondDict[key])==dict: #如果其仍有子树
            plotTree(secondDict[key], cntrPt, str(key)) #递归调用此函数
        else: #否则为叶节点,直接输出
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #
    
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree)) #宽度为叶节点数
    plotTree.totalD = float(getTreeDepth(inTree)) #高度为树高
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
    plotTree(inTree, (0.5,1.0), '')
    plt.show()
    
dataSet, labels = createDataSet()
myTree = createTree(dataSet, labels)
createPlot(myTree)

代码7运行结果:
在这里插入图片描述

三.测试和存储分类器

8.使用决策树的分类函数

使用构造好的决策树以及标签向量对实际数据向量进行分类:
代码8:

def classify(inputTree, featLabels, testVec): #递归函数,从决策树根节点起不断向下在输入向量中找到对应特征,直到得出结果
    firstStr = list(inputTree.keys())[0] #当前树的根节点标签字符
    secondDict = inputTree[firstStr] #根节点的子树
    featIndex = featLabels.index(firstStr) #当前判断的特征在特征向量中的下标
    for key in secondDict.keys(): #对此特征下对应的各个分类方向
        if testVec[featIndex]==key: #找到测试向量对应的那个方向
            if type(secondDict[key])==dict: #如果下面还有分类
                classLabel = classify(secondDict[key], featLabels, testVec) #对其之后对应的分类继续递归调用此函数
            else:
                classLabel = secondDict[key] #若已到叶节点则判断结束,classLabel返回给上层调用
    return classLabel


dataSet, labels = createDataSet()
print(labels)
myTree = createTree(dataSet, labels)
createPlot(myTree)
dataSet, labels = createDataSet() #createTree会删除labels中已用过的特征,故建树后labels中只剩一个特征
print("[1,0]",classify(myTree,labels,[1,0]))
print("[1,1]",classify(myTree,labels,[1,1]))

代码8运行结果:

['no surfacing', 'flippers']

在这里插入图片描述

[1,0] no
[1,1] yes

9.使用pickle模块存储决策树

为节省时间,用pickle序列化对象并保存在磁盘上,在需要时读取:
代码9:

def storeTree(inputTree, filename):
    import pickle
    fw = open(filename,'wb') #必须用二进制方式打开,否则会报错
    pickle.dump(inputTree,fw) #pickle.dump(obj, file, [,protocol])序列化对象,把obj保存到file,protocol为序列化模式,默认为0
    fw.close()
    
def grabTree(filename):
    import pickle
    fr = open(filename,'rb') #同样以二进制方式打开
    return pickle.load(fr) #pickle.load(file)反序列化对象,将文件中的数据解析为一个python对象

storeTree(myTree,'classifierStorage.txt')
grabTree('classifierStorage.txt')

代码9运行结果:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

运行后会发现同目录下多了一个名为classifierStorage.txt的文件

10.使用决策树预测隐形眼镜类型

隐形眼镜数据集:

young	myope	no	reduced	no lenses
young	myope	no	normal	soft
young	myope	yes	reduced	no lenses
young	myope	yes	normal	hard
young	hyper	no	reduced	no lenses
young	hyper	no	normal	soft
young	hyper	yes	reduced	no lenses
young	hyper	yes	normal	hard
pre	myope	no	reduced	no lenses
pre	myope	no	normal	soft
pre	myope	yes	reduced	no lenses
pre	myope	yes	normal	hard
pre	hyper	no	reduced	no lenses
pre	hyper	no	normal	soft
pre	hyper	yes	reduced	no lenses
pre	hyper	yes	normal	no lenses
presbyopic	myope	no	reduced	no lenses
presbyopic	myope	no	normal	no lenses
presbyopic	myope	yes	reduced	no lenses
presbyopic	myope	yes	normal	hard
presbyopic	hyper	no	reduced	no lenses
presbyopic	hyper	no	normal	soft
presbyopic	hyper	yes	reduced	no lenses
presbyopic	hyper	yes	normal	no lenses

代码10:

fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
print(lensesTree)
createPlot(lensesTree)

代码10运行结果:

{'age': {'pre': {'prescript': {'hyper': {'astigmatic': {'no': {'tearRate': {'normal': 'soft', 'reduced': 'no lenses'}}, 'yes': 'no lenses'}}, 'myope': {'astigmatic': {'no': {'tearRate': {'normal': 'soft', 'reduced': 'no lenses'}}, 'yes': {'tearRate': {'normal': 'hard', 'reduced': 'no lenses'}}}}}}, 'presbyopic': {'prescript': {'hyper': {'astigmatic': {'no': {'tearRate': {'normal': 'soft', 'reduced': 'no lenses'}}, 'yes': 'no lenses'}}, 'myope': {'astigmatic': {'no': 'no lenses', 'yes': {'tearRate': {'normal': 'hard', 'reduced': 'no lenses'}}}}}}, 'young': {'tearRate': {'soft': 'soft', 'hard': 'hard', 'no lenses': 'no lenses'}}}}

在这里插入图片描述
不知道为什么和书上的结果不一样啊。。(:з」∠)

  • 5
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

hqy_240603

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值