Machine Learning in Action, Chapter 3: Implementing a Decision Tree

Create a new module, tree.py, and add the code below. Together these functions build a decision tree.

from math import log
import operator


def calcShannonEnt(dataSet):
    # Compute the Shannon entropy of the given data set
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel,0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
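
# Sanity check (worked by hand for the sample set defined below, which has
# 2 'yes' and 3 'no' labels):
#   calcShannonEnt = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.971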

def createDataSet():
    # Create a small sample data set
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
    # Return the rows whose feature at index axis equals value,
    # with that feature column removed
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
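
# For example, splitDataSet(dataSet, 0, 1) on the sample set above keeps the
# three rows whose first feature is 1 and strips that column:
#   [[1, 'yes'], [1, 'yes'], [0, 'no']]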

def chooseBestFeatureToSplit(dataSet):
    # Pick the feature with the highest information gain; returns its index
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = float(len(subDataSet))/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
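
# On the sample set above, splitting on feature 0 ('no surfacing') yields an
# information gain of about 0.420, versus about 0.171 for feature 1, so
# chooseBestFeatureToSplit returns index 0.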

def majorityCnt(classList):
    # When the features are exhausted but the classes are still mixed,
    # fall back to a majority vote: return the most frequent class
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote,0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    # Recursively build the decision tree; nested dicts encode the structure
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # every example has the same class: make a leaf
    if len(dataSet[0]) == 1:
        # no features left to split on, so take a majority vote
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    subLabels = labels[:bestFeat]
    subLabels.extend(labels[bestFeat+1:])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subDataSet = splitDataSet(dataSet, bestFeat, value)
        myTree[bestFeatLabel][value] = createTree(subDataSet, subLabels)
    return myTree



def classify(inputTree, featLabels, testVec):
    '''
    Walk the decision tree to classify one sample.
    :param inputTree: the decision tree (nested dicts)
    :param featLabels: the feature names, in column order
    :param testVec: the feature vector to classify
    :return: the predicted class label
    '''
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabel = None  # stays None if testVec holds a value the tree never saw
    for key in secondDict:
        if key == testVec[featIndex]:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

def storeTree(inputTree, filename):
    '''
    Serialize the decision tree and save it to a file with pickle
    :param inputTree:
    :param filename:
    :return:
    '''
    import pickle
    fw = open(filename,'wb')
    pickle.dump(inputTree,fw)
    fw.close()

def grabTree(filename):
    '''
    Load a pickled decision tree from a file back into memory
    :param filename:
    :return:
    '''
    import pickle
    fr = open(filename, 'rb')  # pickle files must be opened in binary mode
    return pickle.load(fr)

As the createDataSet function shows, the data set dataSet must be a list of lists, where the last element of each inner list is the class label and the preceding elements are the feature values.

labels holds the names of the corresponding feature columns.
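
Here is a minimal end-to-end sketch of how these functions fit together (run it next to tree.py; the file name 'classifierStorage.txt' is just an illustrative choice):

import tree

myDat, labels = tree.createDataSet()
myTree = tree.createTree(myDat, labels)
print(myTree)
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

# Classify a new sample: no surfacing = 1, flippers = 0
print(tree.classify(myTree, labels, [1, 0]))  # 'no'

# Persist the tree to disk and load it back
tree.storeTree(myTree, 'classifierStorage.txt')
print(tree.grabTree('classifierStorage.txt'))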


Below is the treePlot.py module, which draws the decision tree.

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4',fc='0.8')
arrow_args = dict(arrowstyle='<|-')



def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # Draw a node of the given style, with an arrow from parentPt to centerPt
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType,
                            arrowprops=arrow_args)


def getNumLeafs(myTree):
    # Count the leaf nodes, used to scale the plot horizontally
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict:
        if(type(secondDict[key]).__name__ == 'dict'):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    # Measure the number of decision levels, used to scale the plot vertically
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict:
        if(type(secondDict[key]).__name__ == 'dict'):
            thisDepth = 1+getTreeDepth((secondDict[key]))
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def retrieveTree(i):
    # Return a predefined tree, convenient for testing the plotting code
    listOfTree = [{'no surfacing':{0:'no', 1:{'flippers':{0:'no', 1:'yes'}}}},
                  {'no surfacing':{0:'no', 1:{'flippers':{0:{'head':{0:'no', 1:'yes'}},1:'no'}}}},
                  {'a1':{0:'b1', 1:{'b2':{0:{'c1':{0:'d1',1:'d2'}}, 1:'c2'}}, 2:'b3'}}]
    return listOfTree[i]

def createPlot(inTree):
    # Set up the figure and kick off the recursive plotting
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])  # hide the axis ticks
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

def plotMidText(cntrPt, parentPt, txtString):
    # Write the branch label (feature value) midway along the parent-child edge
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt=(0.5, 1.0), nodeTxt=''):
    # Recursively plot the tree: decision nodes centered over their leaves,
    # leaves spaced evenly along each level
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,\
              plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict:
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),\
                     cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

For a detailed explanation of the treePlot.py module, see: [link]
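
To try it out, a quick sketch (assuming tree.py and treePlot.py sit in the same directory):

import treePlot

myTree = treePlot.retrieveTree(0)
print(treePlot.getNumLeafs(myTree))   # 3
print(treePlot.getTreeDepth(myTree))  # 2
treePlot.createPlot(myTree)           # opens a window showing the tree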
