机器学习实战 —— 决策树(完整代码)


声明: 此笔记是学习《机器学习实战》 —— Peter Harrington 上的实例并结合西瓜书上的理论知识来完成,使用Python3 ,会与书上一些地方不一样。

机器学习实战—— 决策树

Coding: Jungle


样本集合:$D$

第 $k$ 类样本所占比例:$p_k$

属性 $a$ 对样本 $D$ 进行划分产生的分支节点个数:$V$

信息熵:$Ent(D) = -\sum_{k=1}^{|y|} p_k \log_2 p_k$

信息增益:$Gain(D,a) = Ent(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|} Ent(D^v)$

数据集

| 编号 | 不浮出水面是否可以生存 | 是否有脚蹼 | 是否属于鱼类 |
| ---- | ---------------------- | ---------- | ------------ |
| 1    | 是                     | 是         | maybe        |
| 2    | 是                     | 是         | 是           |
| 3    | 是                     | 否         | 否           |
| 4    | 否                     | 是         | 否           |
| 5    | 否                     | 是         | 否           |
1. 计算给定数据集的熵
#trees.py
from math import log
def calShannonEnt(dataSet):
    """Return the Shannon entropy of a data set.

    Each sample is a list whose last element is its class label;
    the entropy is -sum(p_k * log2(p_k)) over the label distribution.
    """
    total = len(dataSet)
    # Tally how many samples carry each class label.
    tally = {}
    for sample in dataSet:
        label = sample[-1]
        tally[label] = tally.get(label, 0) + 1
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy
2. 构建数据集
def creatDataSet():
    """Return the toy "is it a fish" data set and its feature names.

    Features: can survive without surfacing (0/1), has flippers (0/1);
    the last column of each sample is the class label.
    """
    samples = [
        [1, 1, 'maybe'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    featureNames = ['no surfacing', 'flippers']
    return samples, featureNames
# Demo: build the toy data set and report its Shannon entropy.
myData,labels = creatDataSet()
print("数据集:{}\n 标签:{}".format(myData,labels))
print("该数据集下的香农熵为:{}".format(calShannonEnt(myData)))

数据集:[[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
 标签:['no surfacing', 'flippers']
该数据集下的香农熵为:1.3709505944546687

相同数据量下,减少属性值类型及特征值,对比熵的变化

# Relabel the first sample 'maybe' -> 'yes': fewer distinct labels lowers the entropy.
myData[0][-1] = 'yes'
print("数据为:{}\n 该数据集下的香农熵为:{}".format(myData,calShannonEnt(myData)))
数据为:[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
 该数据集下的香农熵为:0.9709505944546686
3. 划分数据集
# 根据属性及其属性值划分数据集
def splitDataSet(dataSet, axis, value):
    """Select the samples whose feature at *axis* equals *value*.

    dataSet : list of samples (each a list, class label last)
    axis    : index of the feature to split on
    value   : feature value to keep
    Returns the matching samples with the *axis* column removed.
    """
    return [sample[:axis] + sample[axis + 1:]
            for sample in dataSet
            if sample[axis] == value]
# Split on feature 0 ("can survive without surfacing") and show both branches.
print("划分前的数据集:{}\n \n按照“离开水是否能生存”为划分属性,得到下一层待划分的结果为:\n{}--------{}".format(myData,splitDataSet(myData,0,0),splitDataSet(myData,0,1)))
划分前的数据集:[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
 
按照“离开水是否能生存”为划分属性,得到下一层待划分的结果为:
[[1, 'no'], [1, 'no']]--------[[1, 'yes'], [1, 'yes'], [0, 'no']]
# 选择最好的数据集划分方式,及根绝信息增益选择划分属性
def chooseBestFeatureToSplit(dataSet):
    """Pick the feature index with the largest information gain.

    Gain(D, a) = Ent(D) - sum_v |D^v|/|D| * Ent(D^v).
    Returns -1 when no feature yields a positive gain.
    """
    featureCount = len(dataSet[0]) - 1  # last column is the class label
    baseEntropy = calShannonEnt(dataSet)
    bestGain, bestFeature = 0, -1
    for featIdx in range(featureCount):
        # Weighted entropy after splitting on this feature.
        values = {sample[featIdx] for sample in dataSet}
        splitEntropy = 0.0
        for val in values:
            subset = splitDataSet(dataSet, featIdx, val)
            weight = len(subset) / float(len(dataSet))
            splitEntropy += weight * calShannonEnt(subset)
        gain = baseEntropy - splitEntropy
        if gain > bestGain:
            bestGain, bestFeature = gain, featIdx
    return bestFeature
chooseBestFeatureToSplit(myData)
0

递归构建决策树

# 找到出现次数最多的分类名称
import operator


def majorityCnt(classList):
    """Return the class label that occurs most often in classList.

    Bug fix: in the original, the sort and return statements were
    indented inside the counting loop, so the function returned after
    tallying only the first vote (e.g. ['yes','no','no'] -> 'yes').
    """
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # Sort by count, descending; take the most frequent label.
    sortedClassCount = sorted(
        classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
# 创建树的函数
def creatTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    Recursion stops when all samples share one class, or when only the
    class column remains (majority vote).
    NOTE: mutates *labels* by deleting the chosen feature name, as in
    the original book code.
    """
    classList = [sample[-1] for sample in dataSet]
    # All samples agree on the class: return a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No features left: fall back to the majority class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    tree = {bestFeatLabel: {}}
    del labels[bestFeat]
    for value in set(sample[bestFeat] for sample in dataSet):
        # Copy so sibling branches see an intact label list.
        subLabels = labels[:]
        tree[bestFeatLabel][value] = creatTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return tree
creatTree(myData,labels)
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

画出决策树

# treePlotter.py
import matplotlib.pyplot as plt
from pylab import*

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one annotated node with an arrow from parentPt to centerPt."""
    # SimHei so Chinese node labels render correctly.
    mpl.rcParams['font.sans-serif'] = ['SimHei']
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt, xycoords='axes fraction',
        xytext=centerPt, textcoords='axes fraction',
        va="center", ha="center",
        bbox=nodeType, arrowprops=arrow_args)
    





def createPlot():
    """Demo: draw one decision node and one leaf node."""
    fig = plt.figure(111, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode(U'Decision Node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode(U'Leaf Node', (0.8, 0.1), (0.3, 0.8), leafNode)
    # Bug fix: plt.show was referenced without being called, so no
    # window ever appeared outside of notebook environments.
    plt.show()
createPlot()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2QZJ7Upl-1596631929903)(output_19_0.png)]

# treePlotter.py
# 计算叶节点的个数

def getNumLeaves(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    A value that is a dict is a subtree; anything else is a leaf.
    Fix: the root key was extracted with eval(str(myTree.keys())...),
    which is fragile and unsafe; next(iter(...)) reads it directly.
    """
    numLeafs = 0
    firstStr = next(iter(myTree))  # the single root key (split feature)
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeaves(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs
# 计算树的深度


def getTreeDepth(myTree):
    """Return the depth (number of split levels) of a nested-dict tree.

    Fix: the root key was extracted with eval(str(myTree.keys())...),
    which is fragile and unsafe; next(iter(...)) reads it directly.
    """
    maxDepth = 0
    firstStr = next(iter(myTree))  # the single root key (split feature)
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

#测试深度计算和叶节点记述函数
def retrieveTree(i):
    """Return one of two canned trees for testing the plot helpers."""
    cannedTrees = (
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    )
    return cannedTrees[i]

# Exercise the counting helpers on the second canned tree.
mytree =  retrieveTree(1)
getNumLeaves(mytree)
getTreeDepth(mytree)
3
# treePlotter.py
def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString at the midpoint of the parent-child edge."""
    midX = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    midY = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(midX, midY, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively plot a nested-dict decision tree below parentPt.

    Bug fix: the horizontal extent of a subtree is its LEAF count, so
    the original call to getTreeDepth here was wrong — deep, narrow
    trees were laid out with the wrong width.  Also replaces the
    eval(str(dict.keys())...) root-key extraction with next(iter(...)).
    Uses function attributes (totalW/totalD/xOff/yOff) set by createPlot.
    """
    numLeafs = getNumLeaves(myTree)
    firstStr = next(iter(myTree))  # root split feature
    # Center this node above the span of its leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) /
              2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # descend a level
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))  # subtree
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # come back up

def createPlot(inTree):
    """Lay out and draw the whole decision tree *inTree*."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # Global layout state consumed by plotTree.
    plotTree.totalW = float(getNumLeaves(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), "")
    # Bug fix: plt.show was referenced without being called, so no
    # window ever appeared outside of notebook environments.
    plt.show()
myTree = retrieveTree(0)
createPlot(myTree)

在这里插入图片描述

测试功能

#trees.py
def classify(inputTree, featLbabels, testVec):
    """Classify a sample by walking a nested-dict decision tree.

    inputTree   : tree built by creatTree
    featLbabels : feature names, parallel to testVec
    testVec     : feature values of the sample to classify
    Returns the predicted class label.

    Bug fix: the root key was read from the global ``myTree`` instead of
    the ``inputTree`` argument, so classifying any other tree used the
    wrong root.  Also replaces the eval(str(dict.keys())...) extraction
    with next(iter(...)).
    Raises if testVec's value for the root feature matches no branch
    (classLabel would be unbound), matching the original behaviour.
    """
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLbabels.index(firstStr)
    for key in secondDict:
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                # Descend into the matching subtree.
                classLabel = classify(secondDict[key], featLbabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel
#来测试一下
# Classify the sample [0, 0] with the first canned tree; expected 'no'.
myDat1,labels1 = creatDataSet()
mytree1 = retrieveTree(0)
classify(mytree1,labels1,[0,0])
'no'

西瓜书西瓜的决策树构建

数据集

编号色泽根蒂敲声纹理脐部触感瓜型
1青绿蜷缩浊响清晰凹陷硬滑好瓜
2乌黑蜷缩沉闷清晰凹陷硬滑好瓜
3乌黑蜷缩浊响清晰凹陷硬滑好瓜
4青绿蜷缩沉闷清晰凹陷硬滑好瓜
5浅白蜷缩浊响清晰凹陷硬滑好瓜
6青绿稍蜷浊响清晰稍凹软粘好瓜
7乌黑稍蜷浊响稍糊稍凹软粘好瓜
8乌黑稍蜷浊响清晰稍凹硬滑好瓜
9乌黑稍蜷沉闷稍糊稍凹硬滑坏瓜
10青绿硬挺清脆清晰平坦软粘坏瓜
11浅白硬挺清脆模糊平坦硬滑坏瓜
12浅白蜷缩浊响模糊平坦软粘坏瓜
13青绿稍蜷浊响稍糊凹陷硬滑坏瓜
14浅白稍蜷沉闷稍糊凹陷硬滑坏瓜
15乌黑稍蜷浊响清晰稍凹软粘坏瓜
16浅白蜷缩浊响模糊平坦硬滑坏瓜
17青绿蜷缩沉闷稍糊稍凹硬滑坏瓜
#DT_ID3_pumpkin .py
def createDatePumpKin():
    """Return the watermelon data set 2.0 (西瓜书) and its feature names.

    Rows 1-8 are positive samples ('好瓜'), rows 9-17 negative ('坏瓜');
    the last column of each row is the class label.
    """
    rows = [
        ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],  # 1
        ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],  # 2
        ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],  # 3
        ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],  # 4
        ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],  # 5
        ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'],  # 6
        ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'],  # 7
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'],  # 8
        ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],  # 9
        ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'],  # 10
        ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'],  # 11
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'],  # 12
        ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'],  # 13
        ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'],  # 14
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜'],  # 15
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜'],  # 16
        ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],  # 17
    ]
    featureNames = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']
    return rows, featureNames


#前面计算香农熵的函数在某些环境下会因 log 函数第二个参数的类型问题报错,这里改用 math.log2 重写一版;具体原因需查看底层实现
from math import log2
def calShannonEnt(dataSet):
    """Return the Shannon entropy of the class-label column of dataSet.

    Each sample's label is its last element.  Uses math.log2 rather
    than the two-argument log(prob, 2) form.
    """
    total = len(dataSet)
    # Count occurrences of each class label.
    tally = {}
    for sample in dataSet:
        label = sample[-1]
        tally[label] = tally.get(label, 0) + 1
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * log2(p)
    return entropy
myData_P, labels_P = createDatePumpKin()
# Bug fix: the original passed myData_P as the first format argument,
# so the "entropy" printed was actually the whole data set.
print("该数据集下的香农熵为:{}".format(calShannonEnt(myData_P)))

 该数据集下的香农熵为:0.9975025463691153
print("按照“色泽”为划分属性,得到下一层待划分的结果为:\n-------->{}\n-------->{}\n--------->{}".format(myData_P,splitDataSet(myData_P,0,'浅白'),splitDataSet(myData_P,0,'青绿'),splitDataSet(myData_P,0,'乌黑')))
按照“色泽”为划分属性,得到下一层待划分的结果为:
-------->[['蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'], ['硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'], ['蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'], ['稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'], ['蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜']]
-------->[['蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'], ['蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'], ['稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'], ['硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'], ['稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'], ['蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜']]
--------->[['蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'], ['蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'], ['稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'], ['稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'], ['稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'], ['稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜']]
# Build the ID3 tree for the watermelon data and draw it.
myTree_P=creatTree(myData_P,labels_P)
print(myTree_P)
createPlot(myTree_P)
{'纹理': {'模糊': '坏瓜', '稍糊': {'触感': {'硬滑': '坏瓜', '软粘': '好瓜'}}, '清晰': {'根蒂': {'稍蜷': {'色泽': {'青绿': '好瓜', '乌黑': {'触感': {'硬滑': '好瓜', '软粘': '坏瓜'}}}}, '硬挺': '坏瓜', '蜷缩': '好瓜'}}}}

dt!

  • 3
    点赞
  • 38
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值