本文对应《统计学习方法》第3章,用数十行代码实现KNN的kd树构建与搜索算法,并用matplotlib可视化了动画观赏。
k近邻算法
给定一个训练数据集,对新的输入实例,在训练数据集中找到跟它最近的k个实例,根据这k个实例的类判断它自己的类(一般采用多数表决的方法)。
k近邻模型
模型有3个要素——距离度量方法、k值的选择和分类决策规则。
模型
当3要素确定的时候,对任何实例(训练或输入),它所属的类都是确定的,相当于将特征空间分为一些子空间。
距离度量
对n维实数向量空间Rn,经常用Lp距离或曼哈顿Minkowski距离。
Lp距离定义如下:
当p=2时,称为欧氏距离:
当p=1时,称为曼哈顿距离:
当p=∞,它是各个坐标距离的最大值,即:
用图表示如下:
k值的选择
k较小,容易被噪声影响,发生过拟合。
k较大,较远的训练实例也会对预测起作用,容易发生错误。
分类决策规则
使用0-1损失函数衡量,那么误分类率是:
Nk是近邻集合,要使左边最小,右边的必须最大,所以多数表决=经验最小化。
k近邻法的实现:kd树
算法核心在于怎么快速搜索k个近邻出来,朴素做法是线性扫描,不可取,这里介绍的方法是kd树。
构造kd树
对数据集T中的子集S初始化S=T,取当前节点node=root取维数的序数i=0,对S递归执行:
找出S的第i维的中位数对应的点,通过该点,且垂直于第i维坐标轴做一个超平面。该点加入node的子节点。该超平面将空间分为两个部分,对这两个部分分别重复此操作(S=S',++i,node=current),直到不可再分。
例子
Python代码
短短几行即可搞定:
- T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
- class node:
- def __init__(self, point):
- self.left = None
- self.right = None
- self.point = point
- pass
- def median(lst):
- m = len(lst) / 2
- return lst[m], m
- def build_kdtree(data, d):
- data = sorted(data, key=lambda x: x[d])
- p, m = median(data)
- tree = node(p)
- del data[m]
- print data, p
- if m > 0: tree.left = build_kdtree(data[:m], not d)
- if len(data) > 1: tree.right = build_kdtree(data[m:], not d)
- return tree
- kd_tree = build_kdtree(T, 0)
- print kd_tree
可视化
可视化的话则要费点功夫保存中间结果,并恰当地展示出来
- # -*- coding:utf-8 -*-
- # Filename: kdtree.py
- # Author:hankcs
- # Date: 2015/2/4 15:01
- import copy
- import itertools
- from matplotlib import pyplot as plt
- from matplotlib.patches import Rectangle
- from matplotlib import animation
- T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
- def draw_point(data):
- X, Y = [], []
- for p in data:
- X.append(p[0])
- Y.append(p[1])
- plt.plot(X, Y, 'bo')
- def draw_line(xy_list):
- for xy in xy_list:
- x, y = xy
- plt.plot(x, y, 'g', lw=2)
- def draw_square(square_list):
- currentAxis = plt.gca()
- colors = itertools.cycle(["r", "b", "g", "c", "m", "y", '#EB70AA', '#0099FF'])
- for square in square_list:
- currentAxis.add_patch(
- Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1],
- color=next(colors)))
- def median(lst):
- m = len(lst) / 2
- return lst[m], m
- history_quare = []
- def build_kdtree(data, d, square):
- history_quare.append(square)
- data = sorted(data, key=lambda x: x[d])
- p, m = median(data)
- del data[m]
- print data, p
- if m >= 0:
- sub_square = copy.deepcopy(square)
- if d == 0:
- sub_square[1][0] = p[0]
- else:
- sub_square[1][1] = p[1]
- history_quare.append(sub_square)
- if m > 0: build_kdtree(data[:m], not d, sub_square)
- if len(data) > 1:
- sub_square = copy.deepcopy(square)
- if d == 0:
- sub_square[0][0] = p[0]
- else:
- sub_square[0][1] = p[1]
- build_kdtree(data[m:], not d, sub_square)
- build_kdtree(T, 0, [[0, 0], [10, 10]])
- print history_quare
- # draw an animation to show how it works, the data comes from history
- # first set up the figure, the axis, and the plot element we want to animate
- fig = plt.figure()
- ax = plt.axes(xlim=(0, 2), ylim=(-2, 2))
- line, = ax.plot([], [], 'g', lw=2)
- label = ax.text([], [], '')
- # initialization function: plot the background of each frame
- def init():
- plt.axis([0, 10, 0, 10])
- plt.grid(True)
- plt.xlabel('x_1')
- plt.ylabel('x_2')
- plt.title('build kd tree (www.hankcs.com)')
- draw_point(T)
- currentAxis = plt.gca()
- colors = itertools.cycle(["#FF6633", "g", "#3366FF", "c", "m", "y", '#EB70AA', '#0099FF', '#66FFFF'])
- # animation function. this is called sequentially
- def animate(i):
- square = history_quare[i]
- currentAxis.add_patch(
- Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1],
- color=next(colors)))
- return
- # call the animator. blit=true means only re-draw the parts that have changed.
- anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(history_quare), interval=1000, repeat=False,
- blit=False)
- plt.show()
- anim.save('kdtree_build.gif', fps=2, writer='imagemagick')
搜索kd树
上面的代码其实并没有搜索kd树,现在来实现搜索。
搜索跟二叉树一样来,是一个递归的过程。先找到目标点的插入位置,然后往上走,逐步用自己到目标点的距离画个超球体,用超球体圈住的点来更新最近邻(或k最近邻)。以最近邻为例,实现如下(本实现由于测试数据简单,没有做超球体与超立体相交的逻辑):
- # -*- coding:utf-8 -*-
- # Filename: search_kdtree.py
- # Author:hankcs
- # Date: 2015/2/4 15:01
- T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
- class node:
- def __init__(self, point):
- self.left = None
- self.right = None
- self.point = point
- self.parent = None
- pass
- def set_left(self, left):
- if left == None: pass
- left.parent = self
- self.left = left
- def set_right(self, right):
- if right == None: pass
- right.parent = self
- self.right = right
- def median(lst):
- m = len(lst) / 2
- return lst[m], m
- def build_kdtree(data, d):
- data = sorted(data, key=lambda x: x[d])
- p, m = median(data)
- tree = node(p)
- del data[m]
- if m > 0: tree.set_left(build_kdtree(data[:m], not d))
- if len(data) > 1: tree.set_right(build_kdtree(data[m:], not d))
- return tree
- def distance(a, b):
- print a, b
- return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5
- def search_kdtree(tree, d, target):
- if target[d] < tree.point[d]:
- if tree.left != None:
- return search_kdtree(tree.left, not d, target)
- else:
- if tree.right != None:
- return search_kdtree(tree.right, not d, target)
- def update_best(t, best):
- if t == None: return
- t = t.point
- d = distance(t, target)
- if d < best[1]:
- best[1] = d
- best[0] = t
- best = [tree.point, 100000.0]
- while (tree.parent != None):
- update_best(tree.parent.left, best)
- update_best(tree.parent.right, best)
- tree = tree.parent
- return best[0]
- kd_tree = build_kdtree(T, 0)
- print search_kdtree(kd_tree, 0, [9, 4])
去掉注释和空白,大概数十行,Python真不愧是可运行的伪码。
输出:
- [8, 1] [9, 4]
- [5, 4] [9, 4]
- [9, 6] [9, 4]
- [9, 6]
可见对于点[9, 4],在n=6的数据集中,kdtree算法一共只进行了3次计算。