机器学习实战(二):决策树

决策树

算法简介
组成:长方形为判断模块,椭圆形为终止模块,
优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征值数据。
缺点:可能会产生过的匹配问题
适用数据类型:数据型和标称型。
算法实现
决策树的一般流程

  1. 收集数据:可使用任何方法。
  2. 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。
  3. 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预测。
  4. 训练算法:构造树的数据结构。
  5. 测试算法:使用经验树计算错误率。
  6. 使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。

利用递归算法创建决策树
创建分支伪代码函数createBranch():

检测数据集中的每个子项是否属于同于分类:
	If so return 类标签
	Else 
		寻找划分数据集的最好特征
		划分数据集
		创建分支节点
			For每个划分的子集
				调用函数createBranch并增加返回结果到分支节点中
		Return 分支节点

数据划分
问题:如何采用量化的方法判断如何划分数据
原则:将无序的数据变得更加有序
判断方法:按照获取最大信息增益的方法划分数据(还可以采用基尼不纯度划分)
信息增益:划分数据集之前之后信息发生的变化
熵:集合信息的度量方式,计算方法如下所示:
符号 的信息定义: 在这里插入图片描述

为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值:
在这里插入图片描述

数据划分算法
一些决策树采用二分法进行数据划分,此处采用id3算法。
ID3算法是决策树的一种,它是基于奥卡姆剃刀原理的,即用尽量用较少的东西做更多的事。ID3算法,即Iterative Dichotomiser 3,迭代二叉树3代,是Ross Quinlan发明的一种决策树算法,这个算法的基础就是上面提到的奥卡姆剃刀原理,越是小型的决策树越优于大的决策树,尽管如此,也不总是生成最小的树型结构,而是一个启发式算法。在信息论中,期望信息越小,那么信息增益就越大,从而纯度就越高。ID3算法的核心思想就是以信息
增益来度量属性的选择,选择分裂后信息增益最大的属性进行分裂。该算法采用自顶向下的贪婪搜索遍历可能的决策空间。
进一步了解id3算法
代码实现
代码分为构造决策树、绘制决策树两个部分
构造决策树
创建名为trees.py的文件

from math import log
import operator

"""
函数calcShannonEnt(dataSet):
参数:dataSet:待处理的数据集
功能:计算给定数据集的香农熵
"""
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)#计算数据集中实例总数
    labelCounts = {}
    #遍历数据集,若当前键值不存在,则扩展字典并将当前键值加入字典,并统计各类出现次数
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries #计算p(xi)
        shannonEnt -= prob * log(prob,2)#计算香农熵
    return shannonEnt

"""
函数splitDataSet(dataSet, axis, value):
参数:dataSet:带划分的数据集
     axis:划分数据集的特征
     value:需要返回的特征的值
功能:按照给定的特征值划分数据集
"""
def splitDataSet(dataSet, axis, value):
    retDataSet = []#创建新的list对象,避免函数对原数据集的修改
    #将符合特征的数据抽取出来
    for featVec in dataSet:
        if featVec[axis] == value:
            reduceFeatVec = featVec[:axis]
            reduceFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reduceFeatVec)
    return retDataSet

"""
函数chooseBestFeatureToSplit(dataSet):
参数:dataSet:待处理的数据集
功能:选择最好的数据集划分方式
在函数中调用的数据有一定的要求:
1.数据必须是一种由列表元素组成的列表,且所有的列表元素都要具有相同的数据长度;
2.数据的最后一列或每个实例最后一个元素是当前实例的类别标签。
"""
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1#计算数据特征总数,例:【属性A,属性B...属性N,类别】共N+1项,有N个特征
    baseEntropy = calcShannonEnt(dataSet)#数据集原始香农熵
    bestInfoGain = 0.0
    BestFeature = -1
    #遍历数据集中的所有特征
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]#将数据集中所有第i个特征值的可能值写入featList
        uniqueVals = set(featList)#选取不重复的特征值集合,set是python中得到列表中唯一元素值的最快方法
        newEntropy = 0.0
        #遍历不重复的特征值集合,对每个特征值划分一次数据集计算信息增益
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        #比较所有的特征中的信息增益,返回最好划分的索引值
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            BestFeature = i
    return BestFeature

"""
函数majorityCnt(classList):
参数:classList:类标签列表
功能:选出类标签中最多的一类
"""
def majorityCnt(classList):
    classCount = {}
    #记录各类出现的次数
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    #对各类出现次数排序
    sortedClassCount = sorted(classCount.items(), 
                              key = operator.itemgetter(1), 
                              reverse=True)
    return sortedClassCount[0][0]

"""
函数createTree(dataSet, labels):
参数:dataSet:数据集
     labels:标签列表
功能:建立决策树
"""
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    #递归条件1:如果类别完全相同则停止继续划分
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    #递归条件2:遍历完所有特征值时返回出现次数最多的类别
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    #开始创建树
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}#此处采用嵌套词典的存储树
    #得到列表包含的所有属性值
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]#为了不改变原始列表内容,复制类标签,将其存储在新列表变量subLabels中
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

"""
函数classify(inputTree, featLabels, testVec):
参数:inputTree:数据集
     featLabels:特征标签列表
     testVec:测试向量
功能:使用决策树分类
"""
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else: classLabel = secondDict[key]
    return classLabel

"""
用于存储和读取决策树
"""
def storeTree(inputTree, filename):
    import pickle 
    #fw = open(filename, 'w')
    fw = open(filename, mode = 'wb')
    pickle.dump(inputTree, fw)
    fw.close()
    
def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)

绘制决策树
创建名为treePlotter.py的文件

import matplotlib.pyplot as plt

#定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle = "<-")

#绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.axl.annotate(nodeTxt, xy = parentPt, xycoords = 'axes fraction', 
                            xytext = centerPt, textcoords = 'axes fraction', 
                            va = "center", ha = "center", bbox = nodeType, 
                            arrowprops = arrow_args)
"""
#此为简单的示例函数,后面会拓展为完整的绘制函数
def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.axl = plt.subplot(111, frameon=False)
    plotNode("a decision node", (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode("a leaf node", (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
"""

"""
函数getNumLeafs,getTreeDepth
参数:myTree:已生成的决策树
功能:获取叶节点的数目和树的层数
"""
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]#必须用list将myTree.keys()类型转换,否则为dict_key类型,无法作为list使用
    secondDict = myTree[firstStr]
    #遍历树的所有子节点
    for key in secondDict.keys():
        #若子节点类型为字典,则该节点为判断节点,需递归调用getNumLeafs
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else: numLeafs += 1
    return numLeafs
    
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else: thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth


#绘制决策树
def plotMidText(cntrPt, parentPt, txtString):
    #在父子节点间填充文本信息
    xMid = (parentPt[0] - cntrPt[0])/ 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/ 2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)
    
def plotTree(myTree, parentPt, nodeTxt):
    #计算宽和高
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    #标记子节点属性值
    plotMidText(cntrPt, parentPt, nodeTxt)
    #减少y的偏移
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/ plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
         
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.axl = plt.subplot(111, frameon = False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))#存储树的宽度
    plotTree.totalD = float(getTreeDepth(inTree))#存储树的深度
    plotTree.xOff = -0.5/plotTree.totalW;
    plotTree.yOff = 1.0;
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
    

测试代码
源码、数据集下载地址

import trees
import treePlotter

#用于测试的dataSet, labels的
def createDataSet():
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no']
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

#用于测试的决策树
def retrieveTree(i):
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
        ]
    return listOfTrees[i]
"""
#测试函数calcShannonEnt
myDat, labels = createDataSet()
print(myDat)
x1 = trees.calcShannonEnt(myDat)
print(x1)#如果你的程序正确,x1=0.9709505944546686
#数据混合越多,熵越高,增加一个分类,查看熵的变化
myDat[0][-1]='maybe'
x1 = trees.calcShannonEnt(myDat)
print(x1)#如果你的程序正确,x1=1.3709505944546687
"""

"""
#方法append,extend的区别
a = [1, 2, 3]
b = [4, 5, 6]
a.append(b)#b作为一个元素加入列表a中,应得到[1, 2, 3, [4, 5, 6]]
print(a)
a = [1, 2, 3]#若不重写a,a将为[1, 2, 3, [4, 5, 6]]
a.extend(b)#得到一个包含a和b所有元素的列表,[1, 2, 3, 4, 5, 6]
print(a)
"""

"""
#测试函数splitDataSet
myDat, labels = createDataSet()
print(myDat)#[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
x1 = trees.splitDataSet(myDat, 0, 1)
x2 = trees.splitDataSet(myDat, 0, 0)
print(x1)#[[1, 'yes'], [1, 'yes'], [0, 'no']]
print(x2)#[[1, 'no'], [1, 'no']]
"""

"""
#测试函数chooseBestFeatureToSplit
myDat, labels = createDataSet()
x1 = trees.chooseBestFeatureToSplit(myDat)
print(x1)#返回值应为:0
"""

"""
#测试函数createTree
myDat, labels = createDataSet()
myTree = trees.createTree(myDat, labels)
print(myTree)
#报错:IndexError: list index out of range,代码编写有问题括号不匹配
"""

"""
#测试函数createPlot
treePlotter.createPlot()
"""

"""
#测试函数函数getNumLeafs,getTreeDepth
myTree = retrieveTree(0)
x1 = treePlotter.getNumLeafs(myTree)
x2 = treePlotter.getTreeDepth(myTree)
print(x1)#结果为3
print(x2)#结果为2
"""

"""
#测试函数createPlot
myTree = retrieveTree(0)
#treePlotter.createPlot(myTree)
myTree ['no surfacing'][3] = 'maybe'
treePlotter.createPlot(myTree)
"""

"""
#测试函数classify
myDat, labels = createDataSet()
myTree = retrieveTree(0)
print(labels)
print(myTree)
x1 = trees.classify(myTree, labels, [1,0])
x2 = trees.classify(myTree, labels, [1,1])
print(x1)#应为:no
print(x2)#应为:yes
"""

"""
#测试函数storeTre,grabTree
myTree = retrieveTree(0)
trees.storeTree(myTree, 'classifierStorage.txt')
x = trees.grabTree('classifierStorage.txt')
print(x)
"""

#示例:使用决策树预测隐形眼镜类型
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = trees.createTree(lenses, lensesLabels)
print(lensesTree)
treePlotter.createPlot(lensesTree)

示例:使用决策树预测隐形眼镜类型结果为
在这里插入图片描述
总结

  1. 决策树存在过度匹配的问题,为了解决这个问题需要对决策树进行修剪,相关问题会在后续学习过程中进行学习;
  2. 我使用的构造算法为ID3算法,无法直接处理数据型,还有其他决策树构造算法,如:C4.5和CART,之后会去进一步了解这些算法。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值