统计学习方法---k近邻法

本文对应《统计学习方法》第3章,用数十行代码实现KNN的kd树构建与搜索算法,并用matplotlib可视化了动画观赏。

k近邻算法

给定一个训练数据集,对新的输入实例,在训练数据集中找到跟它最近的k个实例,根据这k个实例的类判断它自己的类(一般采用多数表决的方法)。

k近邻模型

模型有3个要素——距离度量方法、k值的选择和分类决策规则。

模型

当3要素确定的时候,对任何实例(训练或输入),它所属的类都是确定的,相当于将特征空间分为一些子空间。

距离度量

对n维实数向量空间Rn,经常用Lp距离或曼哈顿Minkowski距离。

Lp距离定义如下:

当p=2时,称为欧氏距离:

当p=1时,称为曼哈顿距离:

当p=∞,它是各个坐标距离的最大值,即:

用图表示如下:

k值的选择

k较小,容易被噪声影响,发生过拟合。

k较大,较远的训练实例也会对预测起作用,容易发生错误。

分类决策规则

使用0-1损失函数衡量,那么误分类率是:

Nk是近邻集合,要使左边最小,右边的必须最大,所以多数表决=经验最小化。

k近邻法的实现:kd树

算法核心在于怎么快速搜索k个近邻出来,朴素做法是线性扫描,不可取,这里介绍的方法是kd树。

构造kd树

对数据集T中的子集S初始化S=T,取当前节点node=root取维数的序数i=0,对S递归执行:

找出S的第i维的中位数对应的点,通过该点,且垂直于第i维坐标轴做一个超平面。该点加入node的子节点。该超平面将空间分为两个部分,对这两个部分分别重复此操作(S=S',++i,node=current),直到不可再分。

例子

Python代码

短短几行即可搞定:

  1. T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
  2.  
  3. class node:
  4.     def __init__(self, point):
  5.         self.left = None
  6.         self.right = None
  7.         self.point = point
  8.         pass
  9.     
  10. def median(lst):
  11.     m = len(lst) / 2
  12.     return lst[m], m
  13.  
  14. def build_kdtree(data, d):
  15.     data = sorted(data, key=lambda x: x[d])
  16.     p, m = median(data)
  17.     tree = node(p)
  18.  
  19.     del data[m]
  20.     print data, p
  21.  
  22.     if m > 0: tree.left = build_kdtree(data[:m], not d)
  23.     if len(data) > 1: tree.right = build_kdtree(data[m:], not d)
  24.     return tree
  25.  
  26. kd_tree = build_kdtree(T, 0)
  27. print kd_tree


可视化

可视化的话则要费点功夫保存中间结果,并恰当地展示出来

  1. # -*- coding:utf-8 -*-
  2. # Filename: kdtree.py
  3. # Authorhankcs
  4. # Date: 2015/2/4 15:01
  5. import copy
  6. import itertools
  7. from matplotlib import pyplot as plt
  8. from matplotlib.patches import Rectangle
  9. from matplotlib import animation
  10.  
  11. T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
  12.  
  13.  
  14. def draw_point(data):
  15.     X, Y = [], []
  16.     for p in data:
  17.         X.append(p[0])
  18.         Y.append(p[1])
  19.     plt.plot(X, Y, 'bo')
  20.  
  21.  
  22. def draw_line(xy_list):
  23.     for xy in xy_list:
  24.         x, y = xy
  25.         plt.plot(x, y, 'g', lw=2)
  26.  
  27.  
  28. def draw_square(square_list):
  29.     currentAxis = plt.gca()
  30.     colors = itertools.cycle(["r", "b", "g", "c", "m", "y", '#EB70AA', '#0099FF'])
  31.     for square in square_list:
  32.         currentAxis.add_patch(
  33.             Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1],
  34.                       color=next(colors)))
  35.  
  36.  
  37. def median(lst):
  38.     m = len(lst) / 2
  39.     return lst[m], m
  40.  
  41.  
  42. history_quare = []
  43.  
  44. def build_kdtree(data, d, square):
  45.     history_quare.append(square)
  46.     data = sorted(data, key=lambda x: x[d])
  47.     p, m = median(data)
  48.  
  49.     del data[m]
  50.     print data, p
  51.  
  52.     if m >= 0:
  53.         sub_square = copy.deepcopy(square)
  54.         if d == 0:
  55.             sub_square[1][0] = p[0]
  56.         else:
  57.             sub_square[1][1] = p[1]
  58.         history_quare.append(sub_square)
  59.         if m > 0: build_kdtree(data[:m], not d, sub_square)
  60.     if len(data) > 1:
  61.         sub_square = copy.deepcopy(square)
  62.         if d == 0:
  63.             sub_square[0][0] = p[0]
  64.         else:
  65.             sub_square[0][1] = p[1]
  66.         build_kdtree(data[m:], not d, sub_square)
  67.  
  68.  
  69. build_kdtree(T, 0, [[0, 0], [10, 10]])
  70. print history_quare
  71.  
  72.  
  73. # draw an animation to show how it works, the data comes from history
  74. # first set up the figure, the axis, and the plot element we want to animate
  75. fig = plt.figure()
  76. ax = plt.axes(xlim=(0, 2), ylim=(-2, 2))
  77. line, = ax.plot([], [], 'g', lw=2)
  78. label = ax.text([], [], '')
  79.  
  80. # initialization function: plot the background of each frame
  81. def init():
  82.     plt.axis([0, 10, 0, 10])
  83.     plt.grid(True)
  84.     plt.xlabel('x_1')
  85.     plt.ylabel('x_2')
  86.     plt.title('build kd tree (www.hankcs.com)')
  87.     draw_point(T)
  88.  
  89.  
  90. currentAxis = plt.gca()
  91. colors = itertools.cycle(["#FF6633", "g", "#3366FF", "c", "m", "y", '#EB70AA', '#0099FF', '#66FFFF'])
  92.  
  93. # animation function.  this is called sequentially
  94. def animate(i):
  95.     square = history_quare[i]
  96.     currentAxis.add_patch(
  97.         Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1],
  98.                   color=next(colors)))
  99.     return
  100.  
  101. # call the animator.  blit=true means only re-draw the parts that have changed.
  102. anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(history_quare), interval=1000, repeat=False,
  103.                                blit=False)
  104. plt.show()
  105. anim.save('kdtree_build.gif', fps=2, writer='imagemagick')

搜索kd树

上面的代码其实并没有搜索kd树,现在来实现搜索。

搜索跟二叉树一样来,是一个递归的过程。先找到目标点的插入位置,然后往上走,逐步用自己到目标点的距离画个超球体,用超球体圈住的点来更新最近邻(或k最近邻)。以最近邻为例,实现如下(本实现由于测试数据简单,没有做超球体与超立体相交的逻辑):

  1. # -*- coding:utf-8 -*-
  2. # Filename: search_kdtree.py
  3. # Authorhankcs
  4. # Date: 2015/2/4 15:01
  5.  
  6. T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
  7.  
  8.  
  9. class node:
  10.     def __init__(self, point):
  11.         self.left = None
  12.         self.right = None
  13.         self.point = point
  14.         self.parent = None
  15.         pass
  16.  
  17.     def set_left(self, left):
  18.         if left == None: pass
  19.         left.parent = self
  20.         self.left = left
  21.  
  22.     def set_right(self, right):
  23.         if right == None: pass
  24.         right.parent = self
  25.         self.right = right
  26.  
  27.  
  28. def median(lst):
  29.     m = len(lst) / 2
  30.     return lst[m], m
  31.  
  32.  
  33. def build_kdtree(data, d):
  34.     data = sorted(data, key=lambda x: x[d])
  35.     p, m = median(data)
  36.     tree = node(p)
  37.  
  38.     del data[m]
  39.  
  40.     if m > 0: tree.set_left(build_kdtree(data[:m], not d))
  41.     if len(data) > 1: tree.set_right(build_kdtree(data[m:], not d))
  42.     return tree
  43.  
  44.  
  45. def distance(a, b):
  46.     print a, b
  47.     return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5
  48.  
  49.  
  50. def search_kdtree(tree, d, target):
  51.     if target[d] < tree.point[d]:
  52.         if tree.left != None:
  53.             return search_kdtree(tree.left, not d, target)
  54.     else:
  55.         if tree.right != None:
  56.             return search_kdtree(tree.right, not d, target)
  57.  
  58.     def update_best(t, best):
  59.         if t == None: return
  60.         t = t.point
  61.         d = distance(t, target)
  62.         if d < best[1]:
  63.             best[1] = d
  64.             best[0] = t
  65.  
  66.     best = [tree.point, 100000.0]
  67.     while (tree.parent != None):
  68.         update_best(tree.parent.left, best)
  69.         update_best(tree.parent.right, best)
  70.         tree = tree.parent
  71.     return best[0]
  72.  
  73.  
  74. kd_tree = build_kdtree(T, 0)
  75. print search_kdtree(kd_tree, 0, [9, 4])

去掉注释和空白,大概数十行,Python真不愧是可运行的伪码。

输出:

  1. [8, 1] [9, 4]
  2. [5, 4] [9, 4]
  3. [9, 6] [9, 4]
  4. [9, 6]

可见对于点[9, 4],在n=6的数据集中,kdtree算法一共只进行了3次计算。



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值