机器学习实战Ch03--决策树

# 决策树
优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特诊数据
缺点:可能会产生过度匹配问题
使用数据类型:数值型和标称型
专家系统中,经常使用决策树
## trees.py
'''
优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特诊数据
缺点:可能会产生过度匹配问题
使用数据类型:数值型和标称型
专家系统中,经常使用决策树
'''
from math import log
import operator

def createDataSet():
    # 数据集中两个特征'no surfacing','flippers', 数据的两个类标签'yes','no
    #dataSet是个list
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing','flippers']
    return dataSet, labels

#计算给定数据集的熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)   #计算数据集中实例的总数
    labelCounts = {}            #创建空字典
    for featVec in dataSet:     #提取数据集每一行的特征向量
        currentLabel = featVec[-1]  #获取特征向量最后一列的标签
        # 检测字典的关键字key中是否存在该标签,如果不存在keys()关键字,将当前标签/0键值对存入字典中,并赋值为0
        #print(labelCounts.keys())
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        #print(labelCounts)
        labelCounts[currentLabel] += 1  #否则将当前标签对应的键值加1
        #print("%s="%currentLabel,labelCounts[currentLabel])
    shannonEnt = 0.0    #初始化熵为0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries   #计算各值出现的频率
        shannonEnt -= prob * log(prob,2)    #以2为底求对数再乘以出现的频率,即信息期望值
        #print("%s="%labelCounts[key],shannonEnt)
    return shannonEnt

#按照给定特征划分数据集:得到熵之后,还需划分数据集,以便判断当前是否正确地划分了数据集
#三个输入参数分别为:带划分的数据集,划分数据集的特征,需要返回的特征得值
#挑选出dataSet中axis位置值为value的剩余部分
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:  #筛选出dataSet中axis位置值为value
            #列表的索引中冒号的作用,a[1: ]表示该列表中的第1个元素到最后一个元素,而a[ : n]表示从第0歌元素到第n个元素(不包括n)
            reducedFeatVec = featVec[:axis] #取出特定位置前面部分并赋值给reducedFeatVec
            #print(featVec[axis+1:])
            #print(reducedFeatVec)
            reducedFeatVec.extend(featVec[axis+1:])     #取出特定位置后面部分并赋值给reducedFeatVec
            retDataSet.append(reducedFeatVec)
            #print(retDataSet)
    return retDataSet

#选择最好的数据集划分方式:选取特征,划分数据集,计算得出最好的划分数据集的特征
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1      #计算特征数量,即每一列表元素具有的列数,再减去最后一列为标签,故需减去1
    baseEntropy = calcShannonEnt(dataSet)       #计算信息熵,此处值为0.9709505944546686,此值将与划分之后的数据集计算的信息熵进行比较
    bestInfoGain = 0.0;bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]      #创建标签列表
        #print(featList)
        uniqueVals = set(featList)       #确定某一特征下所有可能的取值,set集合类型中的每个值互不相同
        #print(uniqueVals)
        newEntropy = 0.0
        for value in uniqueVals:        #计算每种划分方式的信息熵
            subDataSet = splitDataSet(dataSet, i, value)        #抽取该特征的每个取值下其他特征的值组成新的子数据集
            prob = len(subDataSet)/float(len(dataSet))      #计算该特征下的每一个取值对应的概率(或者说所占的比重)
            newEntropy += prob * calcShannonEnt(subDataSet)     #计算该特征下每一个取值的子数据集的信息熵,并求和
        infoGain = baseEntropy - newEntropy     #计算每个特征的信息增益
        #print("第%d个特征是的取值是%s,对应的信息增益值是%f"%((i+1),uniqueVals,infoGain))
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
            #print("第%d个特征的信息增益最大,所以选择它作为划分的依据,其特征的取值为%s,对应的信息增益值是%f"%((i+1),uniqueVals,infoGain))
    return bestFeature

#递归构建决策树,返回出现次数最多的分类名称
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

#创建树,参数为数据集和标签列表
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]       #提取dataset中的最后一列——种类标签
    #print(classList)
    if classList.count(classList[0]) == len(classList):    #计算classlist[0]出现的次数,如果相等,说明都是属于一类,不用继续往下划分
        return classList[0]     #递归结束的第一个条件是所有的类标签完全相同,则直接返回该类标签
    #print(dataSet[0])
    if len(dataSet[0]) == 1: #看还剩下多少个属性,如果只有一个属性,但是类别标签有多个,就直接用majoritycnt()进行整理,选取类别最多的作为返回值
        return majorityCnt(classList)   #递归结束的第二个条件是使用完了所有的特征,仍然不能将数据集划分成仅包含唯一类别的分组,则返回出现次数最多的类别
    bestFeat = chooseBestFeatureToSplit(dataSet)    #选取信息增益最大的特征作为下一次分类的依据
    bestFeatLabel = labels[bestFeat]     #选取特征对应的标签
    #print(bestFeatLabel)
    myTree = {bestFeatLabel:{}}  #创建tree字典,下一个特征位于第二个大括号内,循环递归
    del(labels[bestFeat])   #删除使用过的特征
    featValues = [example[bestFeat] for example in dataSet]     #特征值对应的该栏数据
    #print(featValues)
    uniqueVals = set(featValues)    #找到featvalues所包含的所有元素,去重复
    for value in uniqueVals:
        subLabels = labels[:]        #将使用过的标签删除更新后,赋值给新的列表,进行迭代
        #print(subLabels)
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat,value),subLabels) #循环递归生成树
    return myTree                            

#测试算法,使用决策树执行分类
def classify(inputTree,featLabels,testVec):
    firstStr = list(inputTree.keys())[0]    #找到树的第一个分类特征,或者说根节点'no surfacing'
    #print(firstStr)
    secondDict = inputTree[firstStr]    #从树中得到该分类特征的分支,有0和1
    #print(secondDict)
    featIndex = featLabels.index(firstStr)  #根据分类特征的索引找到对应的标称型数据值,'no surfacing'对应的索引为0
    #print(featIndex)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict): 
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat
    return classLabel

#决策树的存储,使用pickle序列化对象,可在磁盘中保存对象。
def storeTree(inputTree,filename):
    import pickle
    fw = open(filename,'wb')    #二进制写入'wb'
    pickle.dump(inputTree,fw)   #pickle的dump函数将决策树写入文件中
    fw.close()
    
def grabTree(filename):
    import pickle
    fr = open(filename,'rb')    #对应于二进制方式写入数据,'rb'采用二进制形式读出数据
    return pickle.load(fr)
    

## trees_main.py

import trees
from imp import reload
import treePlotter

myDat,labels=trees.createDataSet()
#print(myDat)
#print(labels)
#print(trees.calcShannonEnt(myDat))

#熵越高,混合的数据就越多,如果我们在数据集中添加更多的分类,会导致熵结果增大
#myDat[1][-1]='maybe'#更改list中某一元素的值(除yes和no外的值),即为添加更多的分类,中括号中为对应元素行列的位置
#print(myDat)
#print(trees.calcShannonEnt(myDat))  #分类变多,熵增大

#append()和extend()两类方法的区别
a=[1,2,3]
b=[4,5,6]
a.append(b)
#print(a)#[1, 2, 3, [4, 5, 6]]
a.extend(b)
#print(a)#[1, 2, 3, [4, 5, 6], 4, 5, 6]

#按照给定特征划分数据集
#print(myDat)
#print(trees.splitDataSet(myDat,0,1))
#print(trees.splitDataSet(myDat,0,0))

#选择最好的数据集划分方式
#print(myDat)
#print(trees.chooseBestFeatureToSplit(myDat))

#创建树,参数为数据集和标签列表
myTree=trees.createTree(myDat,labels)
#print(myTree)

myDat,labels=trees.createDataSet()
myTree1=treePlotter.retrieveTree(0)
#print(myTree1)
#print(trees.classify(myTree1,labels,[1,0]))
#print(trees.classify(myTree,labels,[1,1]))

#决策树的存储
trees.storeTree(myTree,'classifierStorage.txt')
#print(trees.grabTree('classifierStorage.txt'))

#使用决策树预测隐形眼镜类型
fr=open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]  #将文本数据的每一个数据行按照tab键分割,并依次存入lenses
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']   # 创建并存入特征标签列表
lensesTree = trees.createTree(lenses, lensesLabels)   # 根据继续文件得到的数据集和特征标签列表创建决策树
print(lensesTree)
treePlotter.createPlot(lensesTree)

## treePlotter.py

'''
python中使用Matplotlib注解绘制树形图
'''
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细,pad指的是外边框锯齿形(圆形等)的大小
leafNode = dict(boxstyle="round4", fc="0.8")    #定义决策树的叶子结点的描述属性,round4表示圆形
arrow_args = dict(arrowstyle="<-")  #定义箭头属性

#annotate是关于一个数据点的文本
#nodeTxt为要显示的文本,centerPt为文本的中心点,箭头所在的点,parentPt为指向文本的点
#annotate的作用是添加注释,nodetxt是注释的内容
#nodetype指的是输入的节点(边框)的形状
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

#def createPlot():
#   fig = plt.figure(1, facecolor='white')
#    fig.clf()
#    createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
#    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#    plt.show()


#构造注解树,需要知道叶节点的个数,以便可以正确确定x轴的长度;要知道树的层数,可以确定y轴的高度。
def getNumLeafs(myTree):    #计算叶子节点的个数
    numLeafs = 0
    firstStr = list(myTree.keys())[0]  #获得myTree的第一个键值,即第一个特征,分割的标签
    #print(firstStr)
    secondDict = myTree[firstStr]   #根据键值得到对应的值,即根据第一个特征分类的结果
    #print(secondDict)
    for key in secondDict.keys():   #获取第二个小字典中的key
        if type(secondDict[key]).__name__=='dict':
            #判断是否小字典中是否还包含新的字典(即新的分支)
            numLeafs += getNumLeafs(secondDict[key])    #包含的话进行递归从而继续循环获得新的分支所包含的叶节点的数量
        else:   numLeafs +=1    #不包含的话就停止迭代并把现在的小字典加一表示这边有一个分支
    return numLeafs

def getTreeDepth(myTree):   #计算判断节点的个数
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

#预先存储树信息
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]

#作用是计算tree的中间位置,cntrPt起始位置,parentPt终止位置,txtString文本标签信息
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  #cntrPt起点坐标,子节点坐标,parentPt结束坐标,父节点坐标
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]  #找到x和y的中间位置
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)   #计算子节点的坐标
    plotMidText(cntrPt, parentPt, nodeTxt)      #绘制线上的文字
    plotNode(firstStr, cntrPt, parentPt, decisionNode)      #绘制节点
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD     #每绘制一次图,将y的坐标减少1.0/plottree.totald,间接保证y坐标上深度的
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD



def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')  #类似于Matlab的figure,定义一个画布,背景为白色
    fig.clf()   # 把画布清空
    axprops = dict(xticks=[], yticks=[])    #subplot定义了一个绘图
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    # createPlot.ax1为全局变量,绘制图像的句柄,111表示figure中的图有1行1列,即1个,最后的1代表第一个图,frameon表示是否绘制坐标轴矩形
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()

## treePlotter_main.py

import  treePlotter


#treePlotter.createPlot()
#print(treePlotter.retrieveTree(1))
myTree=treePlotter.retrieveTree(0)
#print(treePlotter.getNumLeafs(myTree))
#print(treePlotter.getTreeDepth(myTree))
myTree['no surfacing'][3]='maybe'
treePlotter.createPlot(myTree)


  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
完整版:https://download.csdn.net/download/qq_27595745/89522468 【课程大纲】 1-1 什么是java 1-2 认识java语言 1-3 java平台的体系结构 1-4 java SE环境安装和配置 2-1 java程序简介 2-2 计算机中的程序 2-3 java程序 2-4 java类库组织结构和文档 2-5 java虚拟机简介 2-6 java的垃圾回收器 2-7 java上机练习 3-1 java语言基础入门 3-2 数据的分类 3-3 标识符、关键字和常量 3-4 运算符 3-5 表达式 3-6 顺序结构和选择结构 3-7 循环语句 3-8 跳转语句 3-9 MyEclipse工具介绍 3-10 java基础知识章节练习 4-1 一维数组 4-2 数组应用 4-3 多维数组 4-4 排序算法 4-5 增强for循环 4-6 数组和排序算法章节练习 5-0 抽象和封装 5-1 面向过程的设计思想 5-2 面向对象的设计思想 5-3 抽象 5-4 封装 5-5 属性 5-6 方法的定义 5-7 this关键字 5-8 javaBean 5-9 包 package 5-10 抽象和封装章节练习 6-0 继承和多态 6-1 继承 6-2 object类 6-3 多态 6-4 访问修饰符 6-5 static修饰符 6-6 final修饰符 6-7 abstract修饰符 6-8 接口 6-9 继承和多态 章节练习 7-1 面向对象的分析与设计简介 7-2 对象模型建立 7-3 类之间的关系 7-4 软件的可维护与复用设计原则 7-5 面向对象的设计与分析 章节练习 8-1 内部类与包装器 8-2 对象包装器 8-3 装箱和拆箱 8-4 练习题 9-1 常用类介绍 9-2 StringBuffer和String Builder类 9-3 Rintime类的使用 9-4 日期类简介 9-5 java程序国际化的实现 9-6 Random类和Math类 9-7 枚举 9-8 练习题 10-1 java异常处理 10-2 认识异常 10-3 使用try和catch捕获异常 10-4 使用throw和throws引发异常 10-5 finally关键字 10-6 getMessage和printStackTrace方法 10-7 异常分类 10-8 自定义异常类 10-9 练习题 11-1 Java集合框架和泛型机制 11-2 Collection接口 11-3 Set接口实现类 11-4 List接口实现类 11-5 Map接口 11-6 Collections类 11-7 泛型概述 11-8 练习题 12-1 多线程 12-2 线程的生命周期 12-3 线程的调度和优先级 12-4 线程的同步 12-5 集合类的同步问题 12-6 用Timer类调度任务 12-7 练习题 13-1 Java IO 13-2 Java IO原理 13-3 流类的结构 13-4 文件流 13-5 缓冲流 13-6 转换流 13-7 数据流 13-8 打印流 13-9 对象流 13-10 随机存取文件流 13-11 zip文件流 13-12 练习题 14-1 图形用户界面设计 14-2 事件处理机制 14-3 AWT常用组件 14-4 swing简介 14-5 可视化开发swing组件 14-6 声音的播放和处理 14-7 2D图形的绘制 14-8 练习题 15-1 反射 15-2 使用Java反射机制 15-3 反射与动态代理 15-4 练习题 16-1 Java标注 16-2 JDK内置的基本标注类型 16-3 自定义标注类型 16-4 对标注进行标注 16-5 利用反射获取标注信息 16-6 练习题 17-1 顶目实战1-单机版五子棋游戏 17-2 总体设计 17-3 代码实现 17-4 程序的运行与发布 17-5 手动生成可执行JAR文件 17-6 练习题 18-1 Java数据库编程 18-2 JDBC类和接口 18-3 JDBC操作SQL 18-4 JDBC基本示例 18-5 JDBC应用示例 18-6 练习题 19-1 。。。
完整版:https://download.csdn.net/download/qq_27595745/89522468 【课程大纲】 1-1 什么是java 1-2 认识java语言 1-3 java平台的体系结构 1-4 java SE环境安装和配置 2-1 java程序简介 2-2 计算机中的程序 2-3 java程序 2-4 java类库组织结构和文档 2-5 java虚拟机简介 2-6 java的垃圾回收器 2-7 java上机练习 3-1 java语言基础入门 3-2 数据的分类 3-3 标识符、关键字和常量 3-4 运算符 3-5 表达式 3-6 顺序结构和选择结构 3-7 循环语句 3-8 跳转语句 3-9 MyEclipse工具介绍 3-10 java基础知识章节练习 4-1 一维数组 4-2 数组应用 4-3 多维数组 4-4 排序算法 4-5 增强for循环 4-6 数组和排序算法章节练习 5-0 抽象和封装 5-1 面向过程的设计思想 5-2 面向对象的设计思想 5-3 抽象 5-4 封装 5-5 属性 5-6 方法的定义 5-7 this关键字 5-8 javaBean 5-9 包 package 5-10 抽象和封装章节练习 6-0 继承和多态 6-1 继承 6-2 object类 6-3 多态 6-4 访问修饰符 6-5 static修饰符 6-6 final修饰符 6-7 abstract修饰符 6-8 接口 6-9 继承和多态 章节练习 7-1 面向对象的分析与设计简介 7-2 对象模型建立 7-3 类之间的关系 7-4 软件的可维护与复用设计原则 7-5 面向对象的设计与分析 章节练习 8-1 内部类与包装器 8-2 对象包装器 8-3 装箱和拆箱 8-4 练习题 9-1 常用类介绍 9-2 StringBuffer和String Builder类 9-3 Rintime类的使用 9-4 日期类简介 9-5 java程序国际化的实现 9-6 Random类和Math类 9-7 枚举 9-8 练习题 10-1 java异常处理 10-2 认识异常 10-3 使用try和catch捕获异常 10-4 使用throw和throws引发异常 10-5 finally关键字 10-6 getMessage和printStackTrace方法 10-7 异常分类 10-8 自定义异常类 10-9 练习题 11-1 Java集合框架和泛型机制 11-2 Collection接口 11-3 Set接口实现类 11-4 List接口实现类 11-5 Map接口 11-6 Collections类 11-7 泛型概述 11-8 练习题 12-1 多线程 12-2 线程的生命周期 12-3 线程的调度和优先级 12-4 线程的同步 12-5 集合类的同步问题 12-6 用Timer类调度任务 12-7 练习题 13-1 Java IO 13-2 Java IO原理 13-3 流类的结构 13-4 文件流 13-5 缓冲流 13-6 转换流 13-7 数据流 13-8 打印流 13-9 对象流 13-10 随机存取文件流 13-11 zip文件流 13-12 练习题 14-1 图形用户界面设计 14-2 事件处理机制 14-3 AWT常用组件 14-4 swing简介 14-5 可视化开发swing组件 14-6 声音的播放和处理 14-7 2D图形的绘制 14-8 练习题 15-1 反射 15-2 使用Java反射机制 15-3 反射与动态代理 15-4 练习题 16-1 Java标注 16-2 JDK内置的基本标注类型 16-3 自定义标注类型 16-4 对标注进行标注 16-5 利用反射获取标注信息 16-6 练习题 17-1 顶目实战1-单机版五子棋游戏 17-2 总体设计 17-3 代码实现 17-4 程序的运行与发布 17-5 手动生成可执行JAR文件 17-6 练习题 18-1 Java数据库编程 18-2 JDBC类和接口 18-3 JDBC操作SQL 18-4 JDBC基本示例 18-5 JDBC应用示例 18-6 练习题 19-1 。。。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值