机器学习算法--(二)决策树--代码详解

 决策树

目录

 决策树

1.计算香农熵:​

 2.按照给定特征划分数据集

3.寻找最好的数据集划分方式

4.多数表决处理非唯一类别分组

5.创建树

6.绘制注解树节点

7.获取叶节点的数目和树的层数

8.绘制

9.完整代码


 

1.计算香农熵:

from math import log
"""
计算数据的香农熵
"""
def calcshannonEnt(dataSet):
    # 计算数据总数
    numEntries = len(dataSet)
    labelCounts = {}
    # 创建所有可能分类
    for featVec in dataSet:
        currentlabel = featVec[-1]
        if currentlabel not in labelCounts.keys():
            labelCounts[currentlabel] = 0
        labelCounts[currentlabel] += 1
    # 计算香农熵
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

def createdataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

# 测试计算香农熵
myDat, labels = createdataSet()
print(myDat)
print(labels)
print(calcshannonEnt(myDat))
print("------------------------------------------------------------")
# 混合数据越多,熵越高
myDat[0][-1] = 'maybe'
myDat[2][-1] = 'possible'
myDat[3][-1] = 'certain'
print(myDat)
print(calcshannonEnt(myDat))

 

 

 2.按照给定特征划分数据集

"""
按照给定特征划分数据集
"""

# 函数三个参数分别代表:待划分的数据集,划分数据集的特征,特征的返回值
def splitDataSet(dataSet, axis, value):
    reDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reduceFeatVec = featVec[: axis]
            reduceFeatVec.extend(featVec[axis+1:])
            reDataSet.append(reduceFeatVec)
    return reDataSet


myDat,labels = createdataSet()
print(myDat)
print("================================================")
print(splitDataSet(myDat, 0, 1))
print(splitDataSet(myDat, 0, 0))

3.寻找最好的数据集划分方式


"""
遍历整个数据集,循环计算香农熵和splitDataSet()函数,找到最好的划分方式
"""

def chooseBestFeatureToSplit(dataSet):
    # 求第一行有多少列的Feature
    numFeatures = len(dataSet[0]) - 1
    # 计算没有经过划分的数据的香农熵
    baseEntroy = calcshannonEnt(dataSet)
    # 最优的信息增益,最优的Feature编号
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        ## 创建唯一的分类标签列表,获取第i个的所有特征(信息元素纵列)

        # 将dataSet 中的数据先按行依次放入example中,然后取得example[i]元素,放入列表featList
        featList = [example[i] for example in dataSet]
        # 使用set集,排除featList中重复的标签,得到唯一分类的集合
        uniqueVals = set(featList)
        newEntropy = 0.0
        # 遍历档次uniqueVals中所有的标签value
        for value in uniqueVals:
            # 对第i个数据划分数据集,返回所有包含i的数据(去掉第i个特征)
            subDataSet = splitDataSet(dataSet, i, value)
            # 计算包含i的数据占总数据的百分比
            prob = len(subDataSet) / float(len(dataSet))
            # 计算新的香农熵,不断进行迭代,这个计算过程仅在包含指定特征标签子集中进行
            newEntropy += prob * calcshannonEnt(subDataSet)
        # 计算信息增益
        infoGain = baseEntroy - newEntropy
        if (infoGain > bestInfoGain):
            # 更新信息增益
            bestInfoGain = infoGain
            # 确定最优的增益的特征索引
            bestFeature = i
    # 返回最优增益的索引
    return bestFeature


myDat, labels = createdataSet()
print(chooseBestFeatureToSplit(myDat))
print(myDat)

4.多数表决处理非唯一类别分组


"""
遍历完所有的特征时,仍然不能将数据集划分成仅包含唯一类别的分组,采用多数表决法
"""

import operator
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

 

5.创建树

"""
创建树的函数代码
"""
def createTree(dataSet, labels):
    # 返回当前数据集下标签列所有值
    classList = [example[-1] for example in dataSet]
    # 当类别完全相同时则停止继续划分,直接返回该类的标签(决策树构造完成)
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 遍历完所有特征时,仍然不能将数据集划分成仅包含唯一类别的分组,返回出现次数最多的类别标签作为返回值
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # 获取最好分类特征索引
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # 获取该特征的名字
    bestFeatLabel = labels[bestFeat]


    # 这里直接使用字典变量来存储树信息,用于绘制树形图
    myTree = {bestFeatLabel: {}}
    # 删除已经在选取的特征
    del(labels[bestFeat])

    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # 复制所有的标签
        # 递归调用自身
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


myDat, labels = createdataSet()
myTree = createTree(myDat, labels)
print(myTree)

6.绘制注解树节点

"""
使用文本绘制注解树节点
"""

import matplotlib.pyplot as plt

# 决策点的属性,boxstyle是文本框类型,sawtooth是锯齿形,fc是边框粗细
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
# 决策树叶子节点的属性
leafNode = dict(boxstyle="round4", fc="0.8")
# 箭头的属性
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType, arrowprops=arrow_args)
    # nodeTxt为要显示的文本,centerPt为文本的中心点,parentPt为箭头指向文本的点,xy是箭头尖的坐标,xytext设置注释内容显示的中心位置
    # xycoords和textcoords是坐标xy 与 xytext 的说明(按轴坐标), 若textcoords=None, 则默认textcoords与xycoords相同,若都未设置,则默认为data
    # va/ha设置节点框中文字的位置,va为纵向取值为(u'top',u'bottom',u'center',u'baseline'),ha为纵向取值为(u'center',u'right',u'left')


def createPlot():
    # 创建一个画布,背景为白色
    fig = plt.figure(1, facecolor='white')
    fig.clf()  # 画布清空
    # ax1是函数createPlot的一个属性,这个可以在函数里面定义也可以在函数定义后加入也可以
    createPlot.ax1 = plt.subplot(111, frameon=True) # frameon表示是否绘制坐标轴矩形
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf Node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()


createPlot()

7.获取叶节点的数目和树的层数


"""
获取叶节点的数目和树的层数
"""
def getNumLeafs(myTree):
    # 初始化节点数
    numLeafs = 0
    firstside = list(myTree.keys())
    # 找到输入的第一个元素,第一个关键词划分数据类别的标签
    firstStr = firstside[0]
    secondDict = myTree[firstStr]
    # 测试数据是否为字典形式
    for key in secondDict.keys():
        # type判断子节点是否为字典类型
        if type(secondDict[key]).__name__ == 'dict':
            # 若子节点也是字典,则也是判断节点,需要递归获取num
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    # 返回整棵树的节点数
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstside = list(myTree.keys())
    firstStr = firstside[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth


"""
输出预先存储的树信息,避免每次测试都需要重新创建树
"""
def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {'head': {0: 'no', 1: 'yes'}}}}}
                   ]
    return listOfTrees[i]


myTree = retrieveTree(0)
print(getTreeDepth(myTree))
print(getNumLeafs(myTree))
print(retrieveTree(0))
print(retrieveTree(1))

 

8.绘制

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    # 计算树的宽度 totalW
    numLeafs = getNumLeafs(myTree)
    # 计算树的高度 totalD
    depth = getTreeDepth(myTree)
    firstside = list(myTree.keys())
    firstStr = firstside[0]  # 找到输入的第一个元素
    # 按照叶子节点个数划分x轴
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    # 标注节点属性
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # y方向上的摆放位置自下而上绘制,因此递减y值
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        # 判断是否为字典 不是则为叶子节点
        if type(secondDict[key]).__name__ == 'dict':
            # 递归继续向下找
            plotTree(secondDict[key], cntrPt, str(key))
        else:  # 为叶子节点
            # x方向计算节点坐标
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            # 绘制
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            # 添加文本信息
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # 下次重新调用时回恢复y
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

# 主函数
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


myTree = retrieveTree(0)
# createPlot(myTree)
myTree['no surfacing'][3] = 'maybe'
print(myTree)
createPlot(myTree)

9.完整代码

from math import log
"""
计算数据的香农熵
"""
def calcshannonEnt(dataSet):
    # 计算数据总数
    numEntries = len(dataSet)
    labelCounts = {}
    # 创建所有可能分类
    for featVec in dataSet:
        currentlabels = featVec[-1]
        if currentlabels not in labelCounts.keys():
            labelCounts[currentlabels] = 0
        labelCounts[currentlabels] += 1
    # 计算香农熵
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


# myDat,labels = createDataSet()
#
#
# myDat[0][-1] = 'maybe'
# myDat[2][-1] = 'possible'
# myDat[3][-1] = 'certain'
# print(myDat)
# print(labels)
# shannonEnt = calcshannonEnt(myDat)
# print(shannonEnt)


"""熵越高,则混合的数据越多"""


"""
按照给定特征划分数据集
"""

# 函数三个参数分别代表:待划分的数据集,划分数据集的特征, 特征的返回值
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reduceFeatVec = featVec[:axis]
            reduceFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reduceFeatVec)
    return retDataSet

# myDat, labels = creatDataSet()
# print(myDat)
# print(splitDataSet(myDat, 0, 1))
# print(splitDataSet(myDat, 0, 0))


"""
遍历整个数据集,循环计算香农熵和splitDataSet()函数,找到最好的划分方式
"""
def chooseBestFeatureToSplit(dataSet):
    # 求第一行有多少列的Feature,减去1,是因为最后一列是label列
    numFeatures = len(dataSet[0]) - 1
    # 计算没有经过划分的数据的香农熵
    baseEntropy = calcshannonEnt(dataSet)
    # 最优的信息增益值; 最优的Feature编号
    bestInfoGain = 0.0 ; bestFeature = -1
    for i in range(numFeatures):
        ## 创建唯一的分类标签列表,获取第i个的所有特征(信息元素纵列!)

        # 将dataSet 中的数据先按行依次放入example中,然后取得example中的example[i]元素,放入列表featList中
        featList = [example[i] for example in dataSet]
        # 使用set集, 排除featList中重复的标签,得到唯一分类的集合
        uniqueVals = set(featList)
        newEntropy = 0.0
        # 遍历当次uniqueVals中所有的标签value
        for value in uniqueVals:
            # 对第i个数据划分数据集,返回所有包含i的数据(已排除第i个特征)
            subDataSet = splitDataSet(dataSet, i, value)
            # 计算包含i的数据占总数据的百分比
            prob = len(subDataSet) / float(len(dataSet))
            # 计算新的香农熵,不断进行迭代,这个计算过程仅在包含指定特征标签子集中进行
            newEntropy += prob * calcshannonEnt(subDataSet)
        # 计算信息增益
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):
            # 更新信息增益
            bestInfoGain = infoGain
            # 确定最优增益的特征索引
            bestFeature = i
            # 更新最优增益
    # 返回最优增益的索引
    return bestFeature


# myDat, labels = creatDataSet()
# print(chooseBestFeatureToSplit(myDat))
# print(myDat)

"""
遍历完所有的特征时,仍然不能将数据集划分成仅包含唯一类别的分组,采用多数表决法
"""
import operator
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


"""
创建树的函数代码
"""
def createTree(dataSet, labels):
    # 返回当前数据集下标签列所有值
    classList = [example[-1] for example in dataSet]
    # 当类别完全相同时则停止继续划分,直接返回该类的标签(决策树构造完成)
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    if len(dataSet[0]) == 1:         # 遍历完所有的特征时,仍然不能将数据集划分成仅包含唯一类别的分组dataSet
        return majorityCnt(classList)  # 由于无法简单的返回唯一的类标签,这里就返回出现次数最多的类别作为返回值
    bestFeat = chooseBestFeatureToSplit(dataSet)  # 获取最好的分类特征索引
    bestFeatLabel = labels[bestFeat]  # 获取该特征的名字

    # 这里直接使用字典变量来存储树信息,这对于绘制树形图很重要
    myTree = {bestFeatLabel: {}}  # 当前数据集选取最好的特征存储在bestFeat中
    del(labels[bestFeat])  # 删除已经在选取的特征
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # 复制所有的标签
        # 递归调用自身
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

# myDat, labels = creatDataSet()
# myTree = createTree(myDat, labels)
# print(myTree)


"""
使用文本绘制注解树节点
"""

import matplotlib.pyplot as plt

# 决策点的属性,boxstyle是文本框类型,sawtooth是锯齿形,fc是边框粗细
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
# 决策树叶子节点的属性
leafNode = dict(boxstyle="round4", fc="0.8")
# 箭头的属性
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType, arrowprops=arrow_args)
    # nodeTxt为要显示的文本,centerPt为文本的中心点,parentPt为箭头指向文本的点,xy是箭头尖的坐标,xytext设置注释内容显示的中心位置
    # xycoords和textcoords是坐标xy 与 xytext 的说明(按轴坐标), 若textcoords=None, 则默认textcoords与xycoords相同,若都未设置,则默认为data
    # va/ha设置节点框中文字的位置,va为纵向取值为(u'top',u'bottom',u'center',u'baseline'),ha为纵向取值为(u'center',u'right',u'left')


# def createPlot():
#     # 创建一个画布,背景为白色
#     fig = plt.figure(1, facecolor='white')
#     fig.clf()  # 画布清空
#     # ax1是函数createPlot的一个属性,这个可以在函数里面定义也可以在函数定义后加入也可以
#     createPlot.ax1 = plt.subplot(111, frameon=True) # frameon表示是否绘制坐标轴矩形
#     plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#     plotNode('a leaf Node', (0.8, 0.1), (0.3, 0.8), leafNode)
#     plt.show()


# createPlot()

"""
获取叶节点的数目和树的层数
"""

def getNumLeafs(myTree):
    # 初始化节点数
    numLeafs = 0
    # python3替换注释的两行代码
    firstside = list(myTree.keys())
    firstStr = firstside[0]  # 找到输入的第一个元素,第一个关键词划分数据类别的标签
    secondDict = myTree[firstStr]
    # firstStr = myTree.keys()[]
    # secondDict = myTree[firstStr]
    for key in secondDict.keys():  # 测试数据是否为字典形式
        # type判断子节点是否为字典类型
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
            # 若子节点也为字典,则也是判断节点,需要递归获取num
        else:
            numLeafs += 1
    # 返回整棵树的节点数
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstside = list(myTree.keys())
    firstStr = firstside[0]
    secondDict = myTree[firstStr]
    # firstStr = myTree.keys()[0]
    # secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth


# 输出预先存储的树信息,避免每次测试都需要重新创建树
def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]

#
# print(retrieveTree(1))
# myTree = retrieveTree(0)
# print(getNumLeafs(myTree))
# print(getTreeDepth(myTree))
#

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    # 计算树的宽度 totalW
    numLeafs = getNumLeafs(myTree)
    # 计算树的高度 totalD
    depth = getTreeDepth(myTree)
    firstside = list(myTree.keys())
    firstStr = firstside[0]  # 找到输入的第一个元素
    # 按照叶子节点个数划分x轴
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    # 标注节点属性
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # y方向上的摆放位置自下而上绘制,因此递减y值
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        # 判断是否为字典 不是则为叶子节点
        if type(secondDict[key]).__name__ == 'dict':
            # 递归继续向下找
            plotTree(secondDict[key], cntrPt, str(key))
        else:  # 为叶子节点
            # x方向计算节点坐标
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            # 绘制
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            # 添加文本信息
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # 下次重新调用时回恢复y
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

# 主函数
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


myTree = retrieveTree(0)
# createPlot(myTree)
myTree['no surfacing'][3] = 'maybe'
print(myTree)
createPlot(myTree)

 

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值