k近邻法之kd树

k近邻算法

给定一个训练数据集,对新的输入实例,在训练数据集中找到跟它最近的k个实例,根据这k个实例的类判断它自己的类(一般采用多数表决的方法)。

这里写图片描述


k近邻模型

模型有3个要素——距离度量方法、k值的选择和分类决策规则。

模型

当3要素确定的时候,对任何实例(训练或输入),它所属的类都是确定的,相当于将特征空间分为一些子空间。

这里写图片描述

距离度量

对n维实数向量空间 R n R^n Rn,经常用 L p L_p Lp距离或曼哈顿Minkowski距离。

L p L_p Lp距离定义如下:

这里写图片描述

当p=2时,称为欧氏距离:

这里写图片描述

当p=1时,称为曼哈顿距离:

这里写图片描述

当p=∞,它是各个坐标距离的最大值,即:

这里写图片描述

用图表示如下:

这里写图片描述

k值的选择

k较小,容易被噪声影响,发生过拟合。

k较大,较远的训练实例也会对预测起作用,容易发生错误。

分类决策规则

使用0-1损失函数衡量,那么误分类率是:

这里写图片描述

Nk是近邻集合,要使左边最小,右边的

这里写图片描述

必须最大,所以多数表决=经验最小化。


k近邻法的实现:kd树

算法核心在于怎么快速搜索k个近邻出来,朴素做法是线性扫描,不可取,这里介绍的方法是kd树。

构造kd树

对数据集T中的子集S初始化S=T,取当前节点node=root取维数的序数i=0,对S递归执行:

找出S的第i维的中位数对应的点,通过该点,且垂直于第i维坐标轴做一个超平面。该点加入node的子节点。该超平面将空间分为两个部分,对这两个部分分别重复此操作(S=S’,++i,node=current),直到不可再分。

例子

这里写图片描述

Python代码

短短几行即可搞定:

  1. T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
  2.  
  3. class node:
  4.     def __init__(self, point):
  5.         self.left = None
  6.         self.right = None
  7.         self.point = point
  8.         pass
  9.     
  10. def median(lst):
  11.     m = len(lst) / 2
  12.     return lst[m], m
  13.  
  14. def build_kdtree(data, d):
  15.     data = sorted(data, key=lambda x: x[d])
  16.     p, m = median(data)
  17.     tree = node(p)
  18.  
  19.     del data[m]
  20.     print data, p
  21.  
  22.     if m > 0:tree.left = build_kdtree(data[:m], not d)
  23.     if len(data) > 1:tree.right = build_kdtree(data[m:], not d)
  24.     return tree
  25.  
  26. kd_tree =build_kdtree(T, 0)
  27. print kd_tree
    可视化

    可视化的话则要费点功夫保存中间结果,并恰当地展示出来

    1. # -*-coding:utf-8 -*-
    2. #Filename: kdtree.py
    3. # Author: hankcs
    4. # Date:2015/2/4 15:01
    5. import copy
    6. import itertools
    7. from matplotlib import pyplot as plt
    8. from matplotlib.patches import Rectangle
    9. from matplotlib import animation
    10.  
    11. T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
    12.  
    13.  
    14. def draw_point(data):
    15.     X, Y = [], []
    16.     for p in data:
    17.         X.append(p[0])
    18.         Y.append(p[1])
    19.     plt.plot(X, Y, 'bo')
    20.  
    21.  
    22. def draw_line(xy_list):
    23.     for xy in xy_list:
    24.         x, y = xy
    25.         plt.plot(x, y, 'g', lw=2)
    26.  
    27.  
    28. def draw_square(square_list):
    29.     currentAxis = plt.gca()
    30.     colors = itertools.cycle(["r", "b", "g", "c", "m", "y", '#EB70AA', '#0099FF'])
    31.     for square in square_list:
    32.         currentAxis.add_patch(
    33.             Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1],
    34.                       color=next(colors)))
    35.  
    36.  
    37. def median(lst):
    38.     m = len(lst) / 2
    39.     return lst[m], m
    40.  
    41.  
    42. history_quare= []
    43.  
    44.  
    45. def build_kdtree(data, d, square):
    46.     history_quare.append(square)
    47.     data = sorted(data, key=lambda x: x[d])
    48.     p, m = median(data)
    49.  
    50.     del data[m]
    51.     print data, p
    52.  
    53.     if m >= 0:
    54.         sub_square = copy.deepcopy(square)
    55.         if d == 0:
    56.             sub_square[1][0] = p[0]
    57.         else:
    58.             sub_square[1][1] = p[1]
    59.         history_quare.append(sub_square)
    60.         if m > 0:build_kdtree(data[:m], not d,sub_square)
    61.     if len(data) > 1:
    62.         sub_square = copy.deepcopy(square)
    63.         if d == 0:
    64.             sub_square[0][0] = p[0]
    65.         else:
    66.             sub_square[0][1] = p[1]
    67.         build_kdtree(data[m:], not d,sub_square)
    68.  
    69.  
    70. build_kdtree(T, 0, [[0, 0], [10, 10]])
    71. print history_quare
    72.  
    73.  
    74. # drawan animation to show how it works, the data comes from history
    75. # firstset up the figure, the axis, and the plot element we want to animate
    76. fig =plt.figure()
    77. ax =plt.axes(xlim=(0, 2), ylim=(-2, 2))
    78. line, = ax.plot([], [], 'g', lw=2)
    79. label = ax.text([], [], '')
    80.  
    81. #initialization function: plot the background of each frame
    82. def init():
    83.     plt.axis([0, 10, 0, 10])
    84.     plt.grid(True)
    85.     plt.xlabel('x_1')
    86.     plt.ylabel('x_2')
    87.     plt.title('build kd tree (www.hankcs.com)')
    88.     draw_point(T)
    89.  
    90.  
    91. currentAxis = plt.gca()
    92. colors =itertools.cycle(["#FF6633", "g", "#3366FF", "c", "m", "y", '#EB70AA', '#0099FF', '#66FFFF'])
    93.  
    94. #animation function.  this is called sequentially
    95. def animate(i):
    96.     square = history_quare[i]
    97.     currentAxis.add_patch(
    98.         Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1], color=next(colors)))
    99.     return
    100.  
    101. # callthe animator.  blit=true means only re-draw the parts that havechanged.
    102. anim = animation.FuncAnimation(fig, animate, init_func = init, frames = len(history_quare), interval = 1000, repeat = False, blit = False)
    103. plt.show()
    104. anim.save('kdtree_build.gif', fps=2, writer='imagemagick')

      这里写图片描述

      搜索kd树

      上面的代码其实并没有搜索kd树,现在来实现搜索。

      搜索跟二叉树一样来,是一个递归的过程。先找到目标点的插入位置,然后往上走,逐步用自己到目标点的距离画个超球体,用超球体圈住的点来更新最近邻(或k最近邻)。以最近邻为例,实现如下(本实现由于测试数据简单,没有做超球体与超立体相交的逻辑):

      1. # -*-coding:utf-8 -*-
      2. #Filename: search_kdtree.py
      3. # Author: hankcs
      4. # Date:2015/2/4 15:01
      5.  
      6. T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
      7.  
      8.  
      9. class node:
      10.     def __init__(self, point):
      11.         self.left = None
      12.         self.right = None
      13.         self.point = point
      14.         self.parent = None
      15.         pass
      16.  
      17.     def set_left(self, left):
      18.         if left == None: pass
      19.         left.parent = self
      20.         self.left = left
      21.  
      22.     def set_right(self, right):
      23.         if right == None: pass
      24.         right.parent = self
      25.         self.right = right
      26.  
      27.  
      28. def median(lst):
      29.     m = len(lst) / 2
      30.     return lst[m], m
      31.  
      32.  
      33. def build_kdtree(data, d):
      34.     data = sorted(data, key=lambda x: x[d])
      35.     p, m = median(data)
      36.     tree = node(p)
      37.  
      38.     del data[m]
      39.  
      40.     if m > 0:tree.set_left(build_kdtree(data[:m], not d))
      41.     if len(data) > 1:tree.set_right(build_kdtree(data[m:], not d))
      42.     return tree
      43.  
      44.  
      45. def distance(a, b):
      46.     print a, b
      47.     return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5
      48.  
      49.  
      50. def search_kdtree(tree, d, target):
      51.     if target[d] <tree.point[d]:
      52.         if tree.left != None:
      53.             return search_kdtree(tree.left, not d, target)
      54.     else:
      55.         if tree.right != None:
      56.             return search_kdtree(tree.right, not d, target)
      57.  
      58.     def update_best(t, best):
      59.         if t == None: return
      60.         t = t.point
      61.         d = distance(t, target)
      62.         if d < best[1]:
      63.             best[1] = d
      64.             best[0] = t
      65.  
      66.  
      67.     best = [tree.point, 100000.0]
      68.     while (tree.parent != None):
      69.         update_best(tree.parent.left, best)
      70.         update_best(tree.parent.right, best)
      71.         tree = tree.parent
      72.     return best[0]
      73.  
      74.  
      75. kd_tree =build_kdtree(T, 0)
      76. print search_kdtree(kd_tree, 0, [9, 4])

        去掉注释和空白,大概数十行,Python真不愧是可运行的伪码。

        输出:

        1. [8, 1] [9, 4]
        2. [5, 4] [9, 4]
        3. [9, 6] [9, 4]
        4. [9, 6]

          可见对于点[9, 4],在n=6的数据集中,kdtree算法一共只进行了3次计算。

          原文作者:hankcs
          原文地址:http://www.hankcs.com/ml/k-nearest-neighbor-method.html

          评论
          添加红包

          请填写红包祝福语或标题

          红包个数最小为10个

          红包金额最低5元

          当前余额3.43前往充值 >
          需支付:10.00
          成就一亿技术人!
          领取后你会自动成为博主和红包主的粉丝 规则
          hope_wisdom
          发出的红包
          实付
          使用余额支付
          点击重新获取
          扫码支付
          钱包余额 0

          抵扣说明:

          1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
          2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

          余额充值