机器学习之决策树详解2--02

决策树的code:

# coding=utf-8
from math import log
import operator

import math
import matplotlib.pyplot as plt


'''
    对于海洋生物的数据,进行决策树分类
'''
def createDataSet():

    '''
        第一列 不浮出水面是否可以生存 no surfacing
        第二列 是否有脚 flippers
        第三列 是否属于鱼
    '''

    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no']
    ]
    labels = ['no surfacing','flippers']
    return dataSet,labels

输出数据:

dataSet,labels = createDataSet()
print(dataSet)
print(labels)

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
['no surfacing', 'flippers']

计算给顶数据集的熵:

#计算给定数据集的熵
def calcShannoEnt(dataSet):
    numEntries = len(dataSet) #统计元素的个数
    labelsCounts = {}  #标签计数

    for featVec in dataSet:#每个标签对应的个数
        currentLabel = featVec[-1]
        if currentLabel not in labelsCounts.keys():
            labelsCounts[currentLabel] = 0
        labelsCounts[currentLabel] += 1

    shannonEnt = 0.0 #熵的变量
    for key in labelsCounts: #计算熵
        prob = labelsCounts[key]/numEntries
        shannonEnt -= prob*math.log(prob,2)

    return shannonEnt
shannonEnt = calcShannoEnt(dataSet)

输出结果: 0.9709505944546686


按照给定特征划分数据集:

def splitDateSet(dataSet, axis, value):

    retDateSet = []

    for featVec in dataSet:
        if featVec[axis] == value:
            reduceFeatVev = featVec[ : axis]   # 获取列表axis前面的元素
            reduceFeatVev.extend(featVec[axis + 1 : ]) #获取列表axis后面的元素
            retDateSet.append(reduceFeatVev)

    return retDateSet
print(splitDateSet(dataSet,0,1))
print(splitDateSet(dataSet, 0, 0))

输出结果:

    [[1, 'yes'], [1, 'yes'], [0, 'no']]

    [[1, 'no'], [1, 'no']]


选择最好的数据集划分方式:


def chooseBestFeatureToSplit(dataSet):

    numFeatrues = len(dataSet[0]) -1  #特征个数

    baseEntropy = calcShannoEnt(dataSet) #计算总熵

    baseInfoGain = 0.0 #信息增益

    baseFeature = -1 #最好特征存放变量

    for i in range(numFeatrues) : #特征循环的控制

        featList = [example[i] for example in dataSet]  #存在某一个特征的所有样本
        uniqueVals = set(featList) #每一个特征,含有不同value        newEntropy = 0.0

        for value in uniqueVals:  #公式的嵌套
            subDataSet = splitDateSet(dataSet, i, value)
            prob = len(subDataSet) / len(dataSet)
            newEntropy += prob*calcShannoEnt(subDataSet)
        infoGain = baseEntropy - newEntropy

        if (infoGain > baseInfoGain) :
            baseInfoGain = infoGain
            bestFeat = i
    return bestFeat
print(chooseBestFeatureToSplit(dataSet))

输出结果: 0


构造树:

 
def majorityCnt(classList):
    classCount = {}

   # print(classList)
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)

    return sortedClassCount[0][0]

def createTree(dataSet,labels):

    classList = [example[-1] for example in dataSet]   #获取

    if classList.count(classList[0]) == len(classList): #类别完全相同,停止划分
        return classList[0]

    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)  #选取最好特征
    bestFeatLabel = labels[bestFeat]  #最好特征的标签

    myTree = {bestFeatLabel:{}}

    del(labels[bestFeat]) #已经查找过的特征,进行删除

    featValues = [example[bestFeat] for example in dataSet]  #最好特征的value    uniqueVals = set(featValues) # 单个不同value

    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDateSet(dataSet, bestFeat, value), subLabels)
        #myTree 好好的理解
    return myTree

myTree = createTree(dataSet, labels)
print(myTree)

输出结果:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}


画出决策树:

def getNumLeafs(myTree): #获取叶子节点数目

    numLeafs = 0

    firstStr = list(myTree.keys())[0]

    '''
        TypeError: ‘dict_keys’ object does not support indexing
        这个问题是python版本的问题
            #如果使用的是python2
            firstStr = myTree.keys()[0]
            #LZ使用的是python3
            firstSides = list(myTree.keys())
            firstStr = firstSides[0]
        是看决策树代码出现的问题,python3如果运行
        firstStr = myTree.keys()[0]

        就会报这个错误,解决办法就是先转换成list,再把需要的索引提取出来。
    
    '''

    secondDict = myTree[firstStr]

    for key in secondDict.keys(): #测试节点的数据类型是否为字典,若为字典则,不是叶子节点,否则是叶子节点
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1

    return numLeafs

def getTreeDepth(myTree): #获取树的深度

    maxDepth = 0

    firstStr = list(myTree.keys())[0]

    secondDict = myTree[firstStr]

    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                           xytext=centerPt, textcoords='axes fraction',
                           va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


decisionNode = dict(boxstyle='sawtooth', fc='0.8')

leafNode = dict(boxstyle='round4', fc='0.8')

arrow_args = dict(arrowstyle='<-')

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]     #the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()



使用决策树的分类函数:

# 对决策树 查找对应位置的值,进行分类的查询
def classify(inputTree,featLabels,testVec):

    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat

    return classLabel

classify(myTree, labels, [1,0])

输出结果: no

使用pickle模块存储决策树:

#决策树的存储
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb+')
    pickle.dump(inputTree, fw)
    fw.close()

#打开存储的决策树
def grabTree(filename):
    import pickle
    fr = open(filename,'rb')
    return pickle.load(fr)

storeTree(myTree,'classifyfierStorage.txt')

print(grabTree('classifyfierStorage.txt'))

决策树的例子:

它包含很多患者的眼部状况的观察条件以及医生推荐的隐形眼镜类型,其中隐形眼镜类型包括:硬材质(hard)、软材质(soft)和不适合佩戴隐形眼镜(no lenses) , 数据来源于UCI数据库。

young   myope  no reduced    no lenses
young  myope  no normal soft
young  myope  yes    reduced    no lenses
young  myope  yes    normal hard
young  hyper  no reduced    no lenses
young  hyper  no normal soft
young  hyper  yes    reduced    no lenses
young  hyper  yes    normal hard
pre    myope  no reduced    no lenses
pre    myope  no normal soft
pre    myope  yes    reduced    no lenses
pre    myope  yes    normal hard
pre    hyper  no reduced    no lenses
pre    hyper  no normal soft
pre    hyper  yes    reduced    no lenses
pre    hyper  yes    normal no lenses
presbyopic myope  no reduced    no lenses
presbyopic myope  no normal no lenses
presbyopic myope  yes    reduced    no lenses
presbyopic myope  yes    normal hard
presbyopic hyper  no reduced    no lenses
presbyopic hyper  no normal soft
presbyopic hyper  yes    reduced    no lenses
presbyopic hyper  yes    normal no lenses


# coding=utf-8


import test



if __name__ == '__main__':

    fr = open('lenses.txt')
    lenses = [item.strip().split('\t') for item in fr.readlines()]
    #print(lenses)

    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']

    myTree = test.createTree(lenses, lensesLabels)

    print(myTree)

    test.createPlot(myTree)


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值