第3章 k近邻法
3.1 k k k 近邻法介绍
k k k 近邻法是一种基本分类与回归方法. k k k 近邻法的输入为实例的特征向量,对应于特征空间的点;输出为实例的类别,可以取多类. 统计学习方法一书只讨论分类问题中的 k k k 近邻法.
3.1.1 k k k 近邻算法
k
k
k 近邻法的特殊情况是
k
=
1
k=1
k=1的情形,称为最近邻算法,对于输入的实例点(特征向量)
x
x
x,最近邻法将训练数据集中于
x
x
x 最邻近点的类作为
x
x
x 的类.
3.1.2 k k k 近邻模型
k
k
k 近邻法使用的模型实际上对应于特征空间的划分,模型由三个基本要素决定:距离度量、
k
k
k 值的选择和分类决策规则,接下来会详细展开.
(1) 距离度量
特征空间中两个实例点的距离是两个实例点相似程度的反映, k k k 近邻模型的特征空间一般是 n n n 维实数向量Rn,使用的距离是欧式距离,但也可以是其他距离. L 2 L_2 L2 距离的公式如下: L 2 ( x i , x j ) = ( ∑ l = 1 n ∣ x i ( l ) − x j ( l ) ∣ ) 1 2 L_2(x_i,x_j)=\displaystyle(\sum_{l=1}^{n}\mid x_i^{(l)}-x_j^{(l)}\mid)^\frac{1}{2} L2(xi,xj)=(l=1∑n∣xi(l)−xj(l)∣)21
(2) k k k 值的选择
k k k 值的选择会对 k k k 近邻法的结果产生重大影响. 如果选择较小的 k k k 值,就相当于用较小的邻域中的训练实例进行预测,"学习"的近似误差会减小,只有与输入实例较近的 (相似的) 训练实例才会对预测结果起作用,但估计误差会增大,预测结果会对近邻的实例点非常敏感;如果选择较大的 k k k 值,则两种误差的结果正相反;如果 k = N k=N k=N,那么无论输入实例是什么,都将简单地预测它属于在训练实例中最多的类.
在应用中, k k k 值一般取一个比较小的数值,通常采用交叉验证法来选取最优的值.
(3) 分类决策规则
k
k
k 近邻法中的分类决策规则往往是多数表决,即由输入实例的
k
k
k 个邻近的训练实例中的多数类决定输入实例的类.
3.1.3 k k k 近邻法的实现: k d kd kd 树
实现 k k k 近邻法时,主要考虑的问题是如何对训练数据进行快速 k k k 近邻搜索,这点在特征空间的维数大及训练数据容量大时尤其必要. 为了提高 k k k 近邻搜索的效率,可以考虑使用特殊的结构存储训练数据, k d kd kd 树方法就是其中一种.
(1) 构造 k d kd kd 树
k
d
kd
kd 树是一种对
k
k
k 维空间中的实例点进行存储以便对其进行快速检索的树形数据结构,它是二叉树,表示对
k
k
k 维空间的一个划分. 构造
k
d
kd
kd 树相当于不断地用垂直于坐标轴的超平面将
k
k
k 维空间划分,构成一系列的
k
k
k 维超矩形区域. 相信直接用算法解释会更加直观,此处列出书中定义的构造平衡
k
d
kd
kd 树算法
(2) 搜索
k
d
kd
kd 树
所谓搜索 k d kd kd 树就是利用 k d kd kd 树进行 k k k 近邻搜索,这样就可以省去对大部分数据点的搜索,从而减少搜索的计算量,接下来就以最近邻为例加以叙述并列出算法.
给定一个目标点,搜索其最近邻. 首先找到包含目标点的叶结点;然后从该叶结点出发,依次回退到父结点;不断查找与目标点最邻近的结点,当确定不可能存在更近的结点时终止. 这里停止搜索的条件是:如果父结点的另一子结点的超矩形区域与超球体不相交,或不存在比当前最近点最近的点,则停止搜索. 最后列出书中算法:
3.2 k k k 近邻法 Python 实现
之前说过,文章的重点是放在算法的编程语言实现上,所以接下来是重头戏,话不多说,直接上代码,我尽量多做注释使代码简单易懂.
3.2.1 例3.1–不同 p p p 值下寻找目标点的最近邻点
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
def compute_distance(point1,point2,p):
#point1与point2的维度一致,p是大于0的整数
dist = 0
point_shape = point2.shape
for i in range(point_shape[0]):
dist +=(abs(point1[i]-point2[i]))**p
dist = dist**(1/p)
return dist
class Solution(object):
def find_nearest(self,vector,point,p):
'''
vector中的元素与point长度一致
input:数据集vector,确定距离种类的p值与目标点point
return:目标点的最近邻点
'''
vector.remove(point)
vector = np.array(vector)
point = np.array(point)
distance = compute_distance(vector[0],point,p)
print("点 %s 与 目标点 %s 之间的Lp距离为:%s" % (str(vector[0]),str(point),str(distance)))
index = [0]
for i in range(1,len(vector)):
tmp = compute_distance(vector[i],point,p)
print("点 %s 与 目标点 %s 之间的Lp距离为:%s" % (str(vector[i]),str(point),str(tmp)))
if tmp <= distance:
distance = tmp
index.append(i)
print("目标点 %s 的最近邻点为:%s" % (str(point),str(vector[index[-1]])))
return distance,vector[index[-1]]
if __name__ == '__main__':
x = [[1,1],[5,1],[4,4]]
point = [1,1]
p = 1
solution = Solution()
distance,goalpoint = solution.find_nearest(x,point,p)
p=1时的结果如下:
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/44872548ad38e91f1cd5b0f002923aa8.png)
3.2.2 例3.2–构造平衡 k d kd kd 树
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
class QuickSort(object):
# 首先在构造kd树的时需要寻找中位数,因此用快速排序来获取一个list中的中位数
def __init__(self, low, high, array):
self._array = array
self._low = low
self._high = high
self._medium = (low+high+1)//2 # python3中的整除
def get_medium_num(self):
return self.quick_sort_for_medium(self._low, self._high,
self._medium, self._array)
def quick_sort_for_medium(self, low, high, medium, array): #用快速排序来获取中位数
if high == low:
return array[low] # find medium
if high > low:
index, partition = self.sort_partition(low, high, array);
#print array[low:index], partition, array[index+1:high+1]
if index == medium:
return partition
if index > medium:
return self.quick_sort_for_medium(low, index-1, medium, array)
else:
return self.quick_sort_for_medium(index+1, high, medium, array)
def quick_sort(self, low, high, array): #正常的快排
if high > low:
index, partition = self.sort_partition(low, high, array);
#print array[low:index], partition, array[index+1:high+1]
self.quick_sort(low, index-1, array)
self.quick_sort(index+1, high, array)
def sort_partition(self, low, high, array): # 用第一个数将数组里面的数分成两部分
index_i = low
index_j = high
partition = array[low]
while index_i < index_j:
while (index_i < index_j) and (array[index_j] >= partition):
index_j -= 1
if index_i < index_j:
array[index_i] = array[index_j]
index_i += 1
while (index_i < index_j) and (array[index_i] < partition):
index_i += 1
if index_i < index_j:
array[index_j] = array[index_i]
index_j -= 1
array[index_i] = partition
return index_i, partition
class KDTree(object):
def __init__(self, input_x):
self._input_x = np.array(input_x)
(data_num, axes_num) = np.shape(self._input_x)
self._data_num = data_num
self._axes_num = axes_num
self._nearest = None #用来存储最近的节点
return
def construct_kd_tree(self):
return self._construct_kd_tree(0, 0, self._input_x)
def _construct_kd_tree(self, depth, axes, data):
if not data.any():
return None
axes_data = data[:, axes].copy()
qs = QuickSort(0, axes_data.shape[0]-1, axes_data)
medium = qs.get_medium_num() #找到轴的中位数
data_list = []
left_data = []
right_data = []
data_range = range(np.shape(data)[0])
for i in data_range: # 跟中位数相比较
if data[i][axes] == medium: #相等
data_list.append(data[i])
elif data[i][axes] < medium:
left_data.append(data[i])
else:
right_data.append(data[i])
left_data = np.array(left_data)
right_data = np.array(right_data)
left = self._construct_kd_tree(depth+1, (axes+1)% self._axes_num, left_data)
right = self._construct_kd_tree(depth+1, (axes+1)% self._axes_num, right_data)
#[树的深度,轴,中位数,该节点的数据,左子树,右子树]
root = [depth, axes, medium, data_list, left, right]
return root
def print_kd_tree(self, root): #打印kd树
if root:
[depth, axes, medium, data_list, left, right] = root
print('{} {}'.format(' '*depth, data_list[0]))
if root[4]:
self.print_kd_tree(root[4])
if root[5]:
self.print_kd_tree(root[5])
if __name__ == '__main__':
input_x = [[2,3], [5,4], [9,6], [4,7], [8,1], [7,2]]
kd = KDTree(input_x)
tree = kd.construct_kd_tree()
kd.print_kd_tree(tree)
得到的运行结果如下, k d kd kd 树的层次也很清晰.
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/1b225713393c462376dd0d2c86109220.png)
3.2.3 例3.3–利用搜索 k d kd kd 树求目标点的最近邻
因为例题将点抽象化了,故该例使用了例3.2的数据,并使用了sklearn包中带的KDTree函数.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Circle
from sklearn.neighbors import KDTree
np.random.seed(0)
# 随机产生150个二维数据
points = np.array([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]])
tree = KDTree(points)
point = np.array([4,5])
# k近邻法搜索,k=1时为最近邻
dists, indices = tree.query([point], k=1)
# q指定半径搜索
r = 2
indices = tree.query_radius([point], r)
fig = plt.figure()
ax = fig.add_subplot(111, aspect='equal')
ax.add_patch(Circle(point, r, color='g', fill=False))
X, Y = [p[0] for p in points], [p[1] for p in points]
plt.scatter(X, Y)
plt.scatter([point[0]], [point[1]], c='r')
plt.show()
搜索结果如下图所示: