Machine Learning in Action: Implementing a Decision Tree

2. Constructing the Decision Tree

2.1 Computing the Shannon Entropy of a Dataset
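
Shannon entropy measures how mixed a dataset is. For a dataset D whose classes occur with probabilities p_i, the entropy is H(D) = -Σ_i p_i · log2(p_i): it is 0 when every instance belongs to a single class and grows as the class distribution becomes more even. calcShannonEnt below computes exactly this quantity.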

from math import log


def calcShannonEnt(dataSet):  # compute the Shannon entropy of a dataset
    numEntries = len(dataSet)  # total number of instances in the dataset
    labelCounts = {}  # dictionary mapping each class label to its count
    for featVec in dataSet:
        currentLabel = featVec[-1]  # the class label is the last element of each feature vector
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries  # probability of each class
        shannonEnt -= prob * log(prob, 2)  # accumulate the entropy
    return shannonEnt

2.2 Creating the Dataset

def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

Interactive test:

import decision_tree
myDat, labels = decision_tree.createDataSet()
myDat
Out[10]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
decision_tree.calcShannonEnt(myDat)
Out[13]: 0.9709505944546686
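
This matches a hand calculation: with two 'yes' and three 'no' instances, H = -(2/5)·log2(2/5) - (3/5)·log2(3/5) ≈ 0.971.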

The more mixed the dataset, the higher the entropy. To test this, change one class label to a third value:

myDat
Out[33]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
myDat[0][-1] = 'maybe'
decision_tree.calcShannonEnt(myDat)
Out[35]: 1.3709505944546687
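
Again this checks out by hand: with class counts of 1 ('maybe'), 1 ('yes'), and 3 ('no'), H = -2·(1/5)·log2(1/5) - (3/5)·log2(3/5) ≈ 1.371, higher than before because a third class was introduced.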

2.3 Splitting the Dataset on a Given Feature

def splitDataSet(dataSet, axis, value):  # split the dataset on a feature (dataset to split, feature index, feature value to select)
    retDataSet = []  # build a new list so the original dataset is left untouched
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # take the features before axis
            reducedFeatVec.extend(featVec[axis+1:])  # merge in the features after axis
            retDataSet.append(reducedFeatVec)
    return retDataSet  # the instances matching value, with the splitting feature removed
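
The extend/append pairing matters here: extend merges the elements of another list into the receiver, while append would nest the whole list as a single element. A quick illustration:

a = [1, 2, 3]
a.append([4, 5])  # a is now [1, 2, 3, [4, 5]] -- the list is added as one nested element
b = [1, 2, 3]
b.extend([4, 5])  # b is now [1, 2, 3, 4, 5] -- the elements are merged in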

Interactive test:

importlib.reload(decision_tree)
Out[36]: <module 'decision_tree' from 'C:\\Users\\xuning\\PycharmProjects\\machine learning\\decision_tree\\decision_tree.py'>
myDat, labels = decision_tree.createDataSet()
myDat
Out[38]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
decision_tree.splitDataSet(myDat, 0, 1)
Out[39]: [[1, 'yes'], [1, 'yes'], [0, 'no']]
decision_tree.splitDataSet(myDat, 0, 0)
Out[40]: [[1, 'no'], [1, 'no']]

2.4 Choosing the Best Way to Split the Dataset


def chooseBestFeatureToSplit(dataSet):  # choose the best feature to split the dataset on
    numFeatures = len(dataSet[0]) - 1  # number of features (the last column is the class label)
    baseEntropy = calcShannonEnt(dataSet)  # entropy of the whole dataset
    bestInfoGain = 0.0  # best information gain seen so far
    bestFeature = -1
    for i in range(numFeatures):  # loop over the features
        featList = [example[i] for example in dataSet]  # all values of the i-th feature
        uniqueVals = set(featList)  # deduplicate
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)  # subset for this feature value
            prob = len(subDataSet)/float(len(dataSet))  # weight of the subset
            newEntropy += prob * calcShannonEnt(subDataSet)  # weighted entropy after the split
        infoGain = baseEntropy - newEntropy  # information gain
        if infoGain > bestInfoGain:  # keep the best gain found so far
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

Interactive test:

importlib.reload(decision_tree)
Out[48]: <module 'decision_tree' from 'C:\\Users\\xuning\\PycharmProjects\\machine learning\\decision_tree\\decision_tree.py'>
myDat, labels = decision_tree.createDataSet()
decision_tree.chooseBestFeatureToSplit(myDat)
Out[50]: 0
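
The result agrees with a hand computation. Splitting on feature 0 yields the subsets [[1, 'yes'], [1, 'yes'], [0, 'no']] (entropy ≈ 0.918) and [[1, 'no'], [1, 'no']] (entropy 0), for a weighted entropy of 3/5 · 0.918 ≈ 0.551 and a gain of 0.971 - 0.551 ≈ 0.420. Splitting on feature 1 leaves a weighted entropy of 4/5 · 1.0 = 0.8 and a gain of only ≈ 0.171, so feature 0 wins.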

2.5 Finding the Majority Class

import operator


def majorityCnt(classList):  # return the most frequent class; parameter: list of class names
    classCount = {}  # dictionary mapping each class to its count
    for vote in classList:  # count how often each class appears in classList
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedclassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)  # sort by frequency, descending
    return sortedclassCount[0][0]  # return the name of the most frequent class
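
For reference, the same result can be obtained more concisely with the standard library's collections.Counter; this is an equivalent alternative, not the book's code:

from collections import Counter


def majorityCnt(classList):
    # most_common(1) returns a one-element list holding the (class, count) pair
    return Counter(classList).most_common(1)[0][0]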

2.6 Building the Tree

def createTree(dataSet, labels):  # build the tree; parameters: dataset and the list of all feature labels
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):  # stop splitting when every class is identical
        return classList[0]
    if len(dataSet[0]) == 1:  # all features used up: return the majority class
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # pick the best feature
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}  # the dictionary myTree stores the whole tree
    del(labels[bestFeat])  # remove the label of the chosen feature
    featValues = [example[bestFeat] for example in dataSet]  # all values of the chosen feature
    uniqueVals = set(featValues)  # deduplicate
    for value in uniqueVals:
        subLabels = labels[:]  # copy of the remaining labels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
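
Note that createTree mutates the labels list it is given (del(labels[bestFeat])), which is why each recursive call passes the copy subLabels; if you still need the full label list after building a tree, pass in a copy such as labels[:].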

Interactive test:

importlib.reload(decision_tree)
myDat, labels = decision_tree.createDataSet()
myTree = decision_tree.createTree(myDat, labels)
myTree
Out[56]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
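
The nested dictionary reads directly as a tree: the root splits on 'no surfacing'; value 0 leads to the leaf 'no', while value 1 leads to a subtree that splits on 'flippers', whose values 0 and 1 end in the leaves 'no' and 'yes'.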

3. Plotting the Tree with Matplotlib Annotations

3.1 Initial Plot and Basic Structure

import matplotlib.pyplot as plt


decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # box and arrow styles for the nodes
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
plt.rcParams['font.sans-serif'] = ['SimHei']  # render the Chinese node labels correctly


def plotNode(nodeTxt, centerPt, parentPt, nodeType):  # the actual drawing routine
    createplot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)  # node text, arrow tip, text position; the coords arguments pick the coordinate system


def createplot():  # draw the overall figure
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createplot.ax1 = plt.subplot(111, frameon=False)
    plotNode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

Initial test: calling createplot() draws one decision node ('决策节点') and one leaf node ('叶节点'), each connected to its parent point by an annotation arrow. A matplotlib annotation guide (in Chinese): https://blog.csdn.net/wizardforcel/article/details/54782628

3.2 Counting Leaf Nodes and Measuring Tree Depth

def getNumLeafs(myTree):  # count the number of leaf nodes
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dictionary value means a subtree
            numLeafs += getNumLeafs(secondDict[key])  # recurse into the subtree
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):  # measure the depth of the tree
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dictionary value means a subtree
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def retrieveTree(i):  # return one of two pre-stored trees, for testing
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

Interactive test:

importlib.reload(treePlotter)
Out[55]: <module 'treePlotter' from 'C:\\Users\\xuning\\PycharmProjects\\machine learning\\decision_tree\\treePlotter.py'>
myTree = treePlotter.retrieveTree(1)
myTree
Out[57]: 
{'no surfacing': {0: 'no',
  1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
myTree = treePlotter.retrieveTree(0)
treePlotter.getNumLeafs(myTree)
Out[59]: 3
treePlotter.getTreeDepth(myTree)
Out[60]: 2
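
For the deeper tree retrieveTree(1), the same calls would return 4 leaves and a depth of 3, since the 'head' subtree adds one more level under 'flippers'.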

3.3 Combining the Pieces to Plot the Full Tree

def plotMidText(cntrPt, parentPt, txtString):  # place text on the edge between parent and child
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  # x coordinate of the label
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]  # y coordinate of the label
    createplot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # number of leaves in this subtree
    depth = getTreeDepth(myTree)  # depth of this subtree (computed but not used here)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)  # position of this node's text
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the edge from the parent
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # draw the decision node
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  # move down one level
    for key in secondDict.keys():  # walk the whole subtree
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  # move back up after this level


def createplot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()  # clear any previous figure
    axprops = dict(xticks=[], yticks=[])
    createplot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))  # total width of the tree, in leaves
    plotTree.totalD = float(getTreeDepth(inTree))  # total depth of the tree
    plotTree.xOff = -0.5/plotTree.totalW  # start half a slot to the left of the first leaf
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
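
The layout works entirely in axes-fraction coordinates: totalW and totalD normalize the x axis to one slot per leaf and the y axis to one band per level, xOff tracks the x position of the most recently placed leaf, and yOff is moved down one level before recursing and restored afterwards, so each subtree ends up centered above its own leaves.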

Interactive test:

importlib.reload(treePlotter)
Out[19]: <module 'treePlotter' from 'C:\\Users\\xuning\\PycharmProjects\\machine learning\\decision_tree\\treePlotter.py'>
myTree = treePlotter.retrieveTree(0)
treePlotter.createplot(myTree)

(Figure: the tree retrieveTree(0) as rendered by createplot.)

4. Testing the Algorithm

def classify(inputTree, featLabels, testVec):  # classify a test vector by walking the tree
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # translate the feature label into an index
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel
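
One caveat: if testVec[featIndex] matches none of the keys at a node, classLabel is never assigned and the function raises an UnboundLocalError; with the toy data every feature value is covered, so this does not arise here.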

Interactive test:

importlib.reload(decision_tree)
Out[16]: <module 'decision_tree' from 'C:\\Users\\xuning\\PycharmProjects\\machine learning\\decision_tree\\decision_tree.py'>
myDat, labels = decision_tree.createDataSet()
myTree = decision_tree.retrieveTree(0)
decision_tree.classify(myTree, labels, [1, 0])
Out[19]: 'no'
decision_tree.classify(myTree, labels, [1, 1])
Out[20]: 'yes'

4.1 Storing the Decision Tree

import pickle


def storeTree(inputTree, filename):  # serialize the tree to disk with pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)


def grabTree(filename):  # load a pickled tree back from disk
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

Interactive test:

importlib.reload(decision_tree)
Out[30]: <module 'decision_tree' from 'C:\\Users\\xuning\\PycharmProjects\\machine learning\\decision_tree\\decision_tree.py'>
decision_tree.grabTree('classifierStorage.txt')
Out[31]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
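
Note that storeTree(myTree, 'classifierStorage.txt') must have been called beforehand to create the file, and that despite the .txt extension the contents are binary pickle data, not readable text.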

4.2 Predicting Contact Lens Types with the Decision Tree
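
The lenses data ships with the book's chapter 3 code as lenses.txt: each tab-separated line holds four feature values (age, prescription type, astigmatic, tear production rate) followed by the recommended lens type.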

fr = open('lenses.txt')
fr
Out[36]: <_io.TextIOWrapper name='lenses.txt' mode='r' encoding='cp936'>
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = decision_tree.createTree(lenses, lensesLabels)
lensesTree
import treePlotter
Backend Qt5Agg is interactive backend. Turning interactive mode on.
treePlotter.createplot(lensesTree)

(Figure: the contact-lens decision tree as rendered by createplot.)
