Machine Learning-kDtree

学会用 matplotlib 画树图

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0]) / 2.0 +cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1]) / 2.0 +cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 /plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == "dict":
            numLeafs += getNumLeafs(secondDict[key])
        else:numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == "dict":
            thisDepth = 1 + getTreeDepth(secondDict[key])

        else: thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}, {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

分析一下这几个功能函数的作用:

retrieveTree(i) : 生成一个 dict,使用retrieveTree(0)才能拿到这个 dict


getTreeDepth(myTree):

  1. 拿到一个 dict 的第一个 key

  2. 拿到前面的 key 对应的 value,是另一个 dict (secondDict)

  3. 对 secondDict 的 key 进行循环

    1. 检查 secondDict 的 key 是否是 dict (是 dict 意味着还可以进入)

      • 如果可以进入,递归检测

      • 如果不可以进入,则深度为 1(抛弃)

    2. 拿到最大深度

  4. 返回最大深度

核心步骤是检测 key 的 type 是否为 dict 和递归


getNumLeafs(myTree):

getTreeDepth(myTree) 的不同之处在于:

  1. key 的 type 不是 dict , numLeafs 就 + 1,而不是直接抛弃

  2. 最终返回的是累加的结果,而不是最大值


createTree(inTree)这是确定了 tree 的各个参数,并且绘制了图,是主函数:

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
  1. plt.figure(), num 是当前图的编号
matplotlib.pyplot.figure(num=None, figsize=None, dpi=None, facecolor=None, edgecolor=None, frameon=True, FigureClass=<class 'matplotlib.figure.Figure'>, clear=False, **kwargs)[source]
  1. fig.clf() , clear the current figure.

  2. plt.subplot(),第一个参数 111, 表示横轴的 start number 是 1, 纵轴的 start number 是 1,subplot 的序号是 1。**kwargs 是 key word arguments, 这里指定了 x 轴和 y 轴上的数据标签 list。参见https://devdocs.io/matplotlib~3.1/_as_gen/matplotlib.pyplot.subplot

  3. 拿到宽度 Width( 总 leafs ),和 Depth ( 最大 Depth )

  4. 设置x 和 y 偏移量

  5. 调用 plotTree() 绘制 tree


在看 plotTree() 之前,我们先看一看它的两个子函数:plotMidText()plotNode()

plotMidTxt(cntrPt, parentPt, txtString):

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0]) / 2.0 +cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1]) / 2.0 +cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

xMid 算的是 x 轴的 parentPt 和 cntrPt 的中点坐标,同理算了 y 轴的中点坐标,然后调用 c把文字添加到相应坐标位置(父子 point 的连线中点)

plotNode(nodeTxt, centerPt, parentPt, nodeType):

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',xytext=centerPt, textcoords='axes fraction',va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

annotate 单词自身的意思是注释,本身是为 plot 添加注释,但是 built-in 的工具可以让你把文字画到 plot 里面

  • nodeTxt 就是将要显示的文字

  • xy = parentPt 表示将要注释的 point 的坐标

  • xycoords = ‘axes fraction’ 表示按照比例(而不是像素值)从轴(而不是整张图片)的左下角开始来绘点

  • xytext = centerPt 表示注释的文字的位置

  • textcoords=‘axes fraction’ 应当是文字的绘制方法,和坐标的类似

  • va=‘center’ 应当是 vertical align,类似地 ha 是 horizon align

  • bbox=nodeType,bbox 属性自身是方块(就是那个节点)的样式,是 dict 类型,而我们的两个 dict 分别是 decisionNode 和 lefaNode,这两个 dict 在最开始的时候便定义好了

  • arrow_args 同上,是箭头样式

这个 plotNode() 函数绘制的是一个箭头加上一个 Node,类似于这样:
在这里插入图片描述
到了这里可能大家对 centerPt 和 parentPt 的意义不太理解了,而且对于前面 xOff 以及 yOff 也不太理解,我也一样。

这个我们等会儿到调用它们的时候再看


plotTree(myTree, parentPt, nodeText)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 /plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

先看 cntrPt : 它是一个二维元组,它的两个值是当前 decisionNode 的位置:

先看 (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW)

plotTree.xOff = -0.5 / plotTree.totalW

plotTree.totalW = float(getNumLeafs(inTree))
在这里插入图片描述

  1. totalW 是总的叶节点个数,再上图里面, leafNode 的个数其实决定了整个图有多宽

  2. xOff 是偏移量,向左偏移 0.5 / leafNodesNum

    正常情况下我们会使用 1 / number 来均分 x 轴宽度,但是那样会使图像偏左(假如 3 个节点,那三个坐标分别是 1/3, 2/3, 3/3,起始点在 x 轴右边,因此需要加上一个向左的偏移量,移动多少呢? 不能直接又向左移动 1/3,因此移动一半,这样整个图像在 x 轴上才能位于图像中间)

  3. plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW 是当前 decisionNode 的 x 坐标:它位于它的子节点的中央位置

参考博客:https://www.cnblogs.com/fantasy01/p/4595902.html

首先由于整个画布根据叶子节点数和深度进行平均切分,并且x轴的总长度为1,即如同下图:

1、其中方形为非叶子节点的位置,@是叶子节点的位置,因此每份即上图的一个表格的长度应该为1/plotTree.totalW,但是叶子节点的位置应该为@所在位置,则在开始的时候plotTree.xOff的赋值为-0.5/plotTree.totalW,即意为开始x位置为第一个表格左边的半个表格距离位置,这样作的好处为:在以后确定@位置时候可以直接加整数倍的1/plotTree.totalW,

2、对于plotTree函数中的红色部分即如下:

cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)

plotTree.xOff即为最近绘制的一个叶子节点的x坐标,在确定当前节点位置时每次只需确定当前节点有几个叶子节点,因此其叶子节点所占的总距离就确定了即为float(numLeafs)/plotTree.totalW1(因为总长度为1),因此当前节点的位置即为其所有叶子节点所占距离的中间即一半为float(numLeafs)/2.0/plotTree.totalW1,但是由于开始plotTree.xOff赋值并非从0开始,而是左移了半个表格,因此还需加上半个表格距离即为1/2/plotTree.totalW1,则加起来便为(1.0 + float(numLeafs))/2.0/plotTree.totalW1,因此偏移量确定,则x位置变为plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW

整体来说:

  1. 先拿到所有的 leafNode 数,把整个图像宽度均分为这么多份

  2. 先画出此 decisionNode(坐标由子 node 数和 depth 确定)

  3. 对 Tree dict 的 keys 进行遍历

    1. 如果这个 key 的 value 的 type name 是 dict,递归进去

    2. 如果这个 key 的 value 的 type name 不是 dict, 画出它的这个子 node(宽度总是用)

  4. 恢复 yOff 值

另外还需要解释一下初始时的 plotTree(inTree, (0.5, 1.0), ‘’), 这是因为我们虚拟了一个父 node (顶级 Node 的父 node),它的父 node 和它自身(位置)重合,但是没有内容。

看了这么多绘制 tree 的内容,我们的核心仍然是 classify


回到 tree.py , 我们的主要任务就变成了怎样根据一堆数据生成一个 dict,然后供给 treePlotter 来绘图

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

这个函数用来计算熵值,熵值会在后面被用于判断哪个属性用来做分类是最合适的。

  1. 先拿到整个 dataSet,类型为 list 的长度

  2. 对 dataSet 中的每一项:

    1. 先拿到最后一项(应该是一项属性值)如果 labelCounts 这个 dict 里面没有这个属性,加上
  3. 相应的属性值 +1

  4. 对 labelCounts 中的每一个元素:

    1. 算出具有这个属性的元素被选中的可能性: probablity(xi)

    2. log2(p(xi)) 称为 information

    3. 熵值 Entropy 就是 information 的期望值

    4. 因此熵值的最终计算公式是
      在这里插入图片描述

  5. 最后返回熵值


def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

这只是简单的列表操作:

根据某个 axis 上的值,把所有的元素分为值是 value 的和值不是 value 的。


def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range (numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob *calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

(这里假定了数据的一些内容:最后一列是 label)

  1. 先拿到 features 的个数

  2. 计算基础 Entropy(整个 dataSet 的熵,没有被分类过的情况下)

  3. 循环进每一个 feature:

    1. 拿到此 feature 的所有不同值( set 特性)

    2. 循环进每一个 value:

      1. 根据这个 feature 的这个 value 进行分割

      2. 计算新的 entropy

      3. 算出这个 feature 的各个 value 的 entropy 的和

    3. 计算出这个 feature 的 information gain :所有 entropy 之差

总之只要 按照这个 feature 分割之后的 dataSet 的 entropy 之和最小,那么这个 feature 就是 bestFeature,最后返回的是 bestFeature 的 index


def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount: classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount

上面的 split 函数可能出现的一个问题是,当跑完所有的 value 之后发现还是有一些元素没法被分类出来(比如某个数据的某个 feature 的 value 缺失,那么它将无法被分类出来)

因此需要确定怎样算是分类结束,于是我们选择了只做二分,不做多分(每一个 feature 只判断一个 value)

这个 value 就是这个 feture 之下出现次数最多的那个


def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

这是真正创建了 Tree 的函数:

  1. 拿到 dataSet 的最后一列( label )

  2. 如果所有的 labels 都相同:

    • 直接返回这个 label
  3. 如果 dataSet 只有一个 feature:

    • 返回这个 feature 出现次数最多的那个 value
  4. 拿到 best feature 的 index

  5. 拿到上面的 index 对应的 label(这个 label 是参数中 labels 的)

  6. 删掉参数 labels 中的 best feature 项

  7. 拿到 dataSet 里 best feature 对应的所有 value 并且去重

  8. 拿到 value (这个 vakue 可能还是一个 dict )之后递归


def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else: classLabel = secondDict[key]
    return classLabel

最终又回到了我们的 classify 函数,这里直接用的是 Tree(实际上是深层 dict ) 来做 classify

  1. 先拿到 tree 的第一个 key

  2. 判断 testVec 的各个 feature 是否能被分类进 tree

  3. 递归

我们的 k-Dtree 算法大概就到这儿了,核心是创建一个 tree 出来

condDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else: classLabel = secondDict[key]
    return classLabel

最终又回到了我们的 classify 函数,这里直接用的是 Tree(实际上是深层 dict ) 来做 classify

  1. 先拿到 tree 的第一个 key

  2. 判断 testVec 的各个 feature 是否能被分类进 tree

  3. 递归

我们的 k-Dtree 算法大概就到这儿了,核心是创建一个 tree 出来

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值