实战西瓜集 决策树及其注释

有些注释不够严谨,见谅

from math import log  #引入math的log函数
import operator #引入python的内部操作函数
import matplotlib.pyplot as plt #引入绘画库
'''
定义文本框和箭头格式
正常显示汉字
'''
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
plt.rcParams["font.sans-serif"] = ["FangSong"]

#计算给定数据集的香农熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet) #计算数据集的实例的总数
    labelCounts = {} #创建一个字典
    for featVec in dataSet: #计算数据的每行数据
        currentLabel = featVec[-1] #将类别赋给currentLabel
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 #若是当前键值不存在,则扩展字典将当前值加入字典
        labelCounts[currentLabel] += 1 #记录当前类别的次数
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2) #以2为底求对数,计算香农熵
    return shannonEnt #返回香农熵

#按照给定特征划分数据集
def splitDataSet(dataSet, axis, value): #带划分数据集、特征、特征的返回值
    retDataSet = []
    for featVec in dataSet: #用每个数据集进行划分
        if featVec[axis] == value: #符合数据的特征
            reducedFeatVec = featVec[:axis]  #将除了这一数据特征外全部添加到创建的数列中去
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet #返回数列

#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0])-1   #计算多少个数据特征
    baseEntropy = calcShannonEnt(dataSet) #计算香农熵
    bestInfoGain = 0.0; bestFeature = -1 
    for i in range(numFeatures):        #遍历所有特征
        featList = [example[i] for example in dataSet] # 将每个特征写入列表
        uniqueVals = set(featList)       #将列表转换为矩阵
        newEntropy = 0.0
        for value in uniqueVals: # 遍历特征
            subDataSet = splitDataSet(dataSet, i, value) #划分数据集
            prob = len(subDataSet)/float(len(dataSet)) 
            newEntropy += prob * calcShannonEnt(subDataSet)   #对唯一特征值得到的熵求和
        infoGain = baseEntropy - newEntropy     # 计算信息增益
        if (infoGain > bestInfoGain):       # 比较所有的信息增益得到最佳
            bestInfoGain = infoGain         # 将最佳的信息增益赋给bestinfogain
            bestFeature = i
    return bestFeature #返回最佳信息增益

#创建树的函数代码
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet] #遍历每行数据的每个参数
    if classList.count(classList[0]) == len(classList): 
        return classList[0]#当所有类都相等时停止拆分
    if len(dataSet[0]) == 1: #当只有一个特征时停止拆分
        return majorityCnt(classList) #开始另外的分类方法
    bestFeat = chooseBestFeatureToSplit(dataSet)  #开始计算最好的数据集划分方法
    bestFeatLabel = labels[bestFeat] #将最好的分类特征赋值
    myTree = {bestFeatLabel:{}}  #建立一个字典用来存储树的信息
    del(labels[bestFeat])  #删除建好的分类特征
    featValues = [example[bestFeat] for example in dataSet] #将所有属性放到列表中
    uniqueVals = set(featValues) #将存放所有属性的列表转换为字典
    for value in uniqueVals:  #遍历所有属性
        subLabels = labels[:]       #复制类标签,将其存储到列表变量subLabels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)#递归调用creaTree()函数,将结果返回给myTree
    return myTree

#多数表决测分类
def majorityCnt(classList):
    classCount={}  #创建空字典,用来存储类标签出现的频率
    for vote in classList:  #开始遍历列表的每一个参数
        if vote not in classCount.keys(): classCount[vote] = 0  #如果是不在字典里,就赋值0,否则加1
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)#操作键值排序字典
    return sortedClassCount[0][0] #返回次数最多的分类名称

def plotNode(nodeTxt, centerPt, parentPt, nodeType): #给图形添加文本注解(内容、被注释的坐标点、指定数据点的坐标系
    #注释内容的位置、注释内容的坐标系、垂直中心、水平中心、注释框的风格和颜色宽度、指定箭头的风格
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

#计算子节点的个数
def getNumLeafs(myTree):
    numLeafs = 0  #初始化叶子数
    firstStr = list(myTree.keys())[0] #获取节点类别标签
    secondDict = myTree[firstStr] #获取子节点myTree字典的第二层
    for key in secondDict.keys():  #遍历所有子节点
        if type(secondDict[key]).__name__=='dict':#若子节点为字典,则不是子节点
            numLeafs += getNumLeafs(secondDict[key]) #递归查找子节点
        else:   numLeafs +=1 #将叶子数加1
    return numLeafs  #返回叶子数

#计算遍历过程中遇到判断节点的个数,即深度
def getTreeDepth(myTree):
    maxDepth = 0  #初始化深度为0
    firstStr =list(myTree.keys())[0] #获取节点类别标签
    secondDict = myTree[firstStr] #获取子节点 myTree字典的第二层
    for key in secondDict.keys(): #遍历子节点键值
        if type(secondDict[key]).__name__=='dict':#如是字典类型,继续遍历
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1 #若不是字典类型,则为叶子节点
        if thisDepth > maxDepth: maxDepth = thisDepth  #判断是否更新层数
    return maxDepth  #返回层数


'''
输入预先储存的树信息,避免每次都要创建树
主要用于测试
输出:一个预定义的树结构
'''
#预先存储树信息
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]

#计算父节点和子节点
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]#计算父子节点连线中间坐标 计算标注位置
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)#填充文本

#计算树的宽和高
def plotTree(myTree, parentPt, nodeTxt):#
    numLeafs = getNumLeafs(myTree)  #获取决策树叶子节点的数目,决定了树的宽度
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]     #获取节点类别标签
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)#计算中心节点坐标
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)#绘制从parentPt指向cntrPt的箭头
    secondDict = myTree[firstStr]#获取子节点 myTree字典的第二层
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD#减少y偏移
    for key in secondDict.keys():#遍历所有的子节点
        if type(secondDict[key]).__name__=='dict':#判断子节点是否为字典类型,如果不是则是叶子节点   
            plotTree(secondDict[key],cntrPt,str(key))        #如果子节点为字典类型,则该节点也是一个判断节点,递归调用plotTree,继续寻找叶子节点
        else:   #如果子节点不为字典类型,则为叶子节点
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW#减少x偏移
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)#绘制中心节点指向子节点箭头(叶子节点)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))#填充中心节点指向子节点的信息(叶子节点)
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD#减少y偏移
#if you do get a dictonary you know it's a tree, and the first element will be another dict
    
#计算图树图形的全局尺寸
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white') #创建一个白色的画布
    fig.clf() #清除图像上的就图形
    axprops = dict(xticks=[], yticks=[]) #将标签转化为字典
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #创建一个子画板,不显示边框
    #createPlot.ax1 = plt.subplot(111, frameon=False) #显示标记
    plotTree.totalW = float(getNumLeafs(inTree)) #计算决策树的宽度、高度
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; #跟踪已绘制的节点位置
    plotTree(inTree, (0.5,1.0), '') #开始绘制树的宽和高
    plt.show() #将图像显示出来

#使用决策树的分类函数
def classify(inputTree,featLabels,testVec):  #树模型、特征集、待分类特征
    firstStr = inputTree.keys()[0] #取树模型的第一个键值,即第一个特征值
    secondDict = inputTree[firstStr]  #获取字典的第二层
    featIndex = featLabels.index(firstStr)  #通过index函数找到特征集中符合第一个特征的索引值,赋值
    key = testVec[featIndex]  #遍历子树中键
    valueOfFeat = secondDict[key]  #判断待分类特征的第一个特征值是否为树模型的一级节点的子节点
    if isinstance(valueOfFeat, dict):  #判断孩子节点是否为子节点类型,如果是,则不能确定出分类,需递归调用
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat #若不是字典类型,则可以判断出分类,此步是递归函数的必经之路
    return classLabel

if __name__=='__main__':
  fr=open('xigua.txt',encoding='utf-8',errors='ignore')#打开文件,设置utf-8编码格式,错误的处理方式为忽视
  lenses=[inst.strip().split('\t') for inst in fr.readlines()]#从样本集文件中读取所有的行,用换行符分开,去除每行行首和行末的空格,保存到列表变量lenses中
  lensesLabels=['色泽','根蒂','敲声','纹理','脐部','触感','好瓜']#定义样本集的特征集
  lensesTree=createTree(lenses,lensesLabels)#调用模块createTree的函数对样本集产生决策树
  createPlot(lensesTree)#创建出决策树的图形

结果:
显示结果
注意:数据格式的编写:好瓜属性一定不要写,或者属性都不写也行

色泽 根蒂 敲声 纹理 脐部 触感
青绿 蜷缩 浊响 清晰 凹陷 硬滑 是
乌黑 蜷缩 沉闷 清晰 凹陷 硬滑 是
乌黑 蜷缩 浊响 清晰 凹陷 硬滑 是
青绿 蜷缩 沉闷 清晰 凹陷 硬滑 是
浅白 蜷缩 浊响 清晰 凹陷 硬滑 是
青绿 稍蜷 浊响 清晰 稍凹 软粘 是
乌黑 稍蜷 浊响 稍糊 稍凹 软粘 是
乌黑 稍蜷 浊响 清晰 稍凹 硬滑 是
乌黑 稍蜷 沉闷 稍糊 稍凹 硬滑 否
青绿 硬挺 清脆 清晰 平坦 软粘 否
浅白 硬挺 清脆 模糊 平坦 硬滑 否
浅白 蜷缩 浊响 模糊 平坦 软粘 否
青绿 稍蜷 浊响 稍糊 凹陷 硬滑 否
浅白 稍蜷 沉闷 稍糊 凹陷 硬滑 否
乌黑 稍蜷 浊响 清晰 稍凹 软粘 否
浅白 蜷缩 浊响 模糊 平坦 硬滑 否
青绿 蜷缩 沉闷 稍糊 稍凹 硬滑 否

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值