1.概念与目的
k近邻法(k-Nearest Neighbor,简称kNN)学习是一种常用的监督式学习方法。
给定测试样本,基于某种距离度量找出训练集中与其最靠近的k个训练样本,然后基于这k个“邻居”的信息来进行预测。通常,在分类任务中可使用“投票法”,即选择这k个样本中出现最多的类别标记作为预测结果;在回归任务中可以使用“平均法”,即将这k个样本的实值输出标记的平均值作为预测结果;还可以基于距离远近进行加权平均或加权投票,距离越近的样本权重越大。
**k近邻有个明显的不同之处:**它似乎没有显示的训练过程。事实上,它是“懒惰学习”(lazy learning)的著名代表,此类学习技术在训练阶段仅仅是把样本保存起来,训练时间开销为0,待收到测试样本后再进行处理,这种称为“急切学习”。
****KNN的三要素:****k值的选择,距离度量及分类决策规则.当k=1时称为最近邻算法.主要核心思想是“物以类聚”,看训练集中离该输入最近的实例多数属于什么类别。
2.模型
当训练集,距离度量,k值以及分类决策规则确定后,特征空间已经根据这些要素被划分为一些子空间,且子空间里每个点所属的类也已被确定.
3.策略
(1)距离:
特征空间中两个实例点的距离是相似程度的反映,k近邻算法一般使用欧氏距离,也可以使用更一般的Lp距离或Minkowski距离.设特征空间X是n维实数向量空间Rn,x(i),x(j)∈X,x=(x0,x1,x2,…,xn)T,x(i),x(j)的Lp距离定义为:
这里p⩾1。当p=2时,称为欧式距离(Euclidean distance),即
当p=1时,称为曼哈顿距离(Manhanttan distance),即
当p=∞时,它是各个坐标距离的最大值,即:
(2)k值:
k值较小时,整体模型变得复杂,容易发生过拟合.k值较大时,整体模型变得简单.在应用中k一般取较小的值,通过交叉验证法选取最优的k.
(3)分类决策规则
k近邻中的分类决策规则往往是多数表决,多数表决规则等价于经验风险最小化.
4.算法
目标:根据给定的距离度量,在训练集中找出与x最邻近的k个点,根据分类规则决定x的类别ykd树算法
(1)描述:
kd树就是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构,它是一个二叉树,表示对k维空间的一个划分.构造kd树相当于不断用垂直于坐标轴的超平面将k维空间划分,构造一列的k维超矩形区域,而kd树的每一个节点对应于一个k维超矩形区域。kd树更适用于训练实例数远大于空间维数时的k近邻搜索.
(2)构造平衡kd树算法:
可以通过如下递归实现:在超矩形区域上选择一个坐标轴和此坐标轴上的一个切分点,确定一个超平面,该超平面将当前超矩形区域切分为两个子区域.在子区域上重复切分直到子区域内没有实例时终止.通常依次选择坐标轴和选定坐标轴上的中位数点为切分点,这样可以得到平衡kd树.
(3)kd树的最近邻搜索:
从根节点出发,若目标点x当前维的坐标小于切分点的坐标则移动到左子结点,否则移动到右子结点,直到子结点为叶结点为止.以此叶结点为"当前最近点",递归地向上回退,在每个结点:(a)如果该结点比当前最近点距离目标点更近,则以该结点为"当前最近点"(b)"当前最近点"一定存在于该结点一个子结点对应的区域,检查该结点的另一子结点对应的区域是否与以目标点为球心,以目标点与"当前最近点"间的距离为半径的超球体相交.如果相交,移动到另一个子结点,如果不相交,向上回退.持续这个过程直到回退到根结点,最后的"当前最近点"即为最近邻点.
5.kd树的代码实现
(1)首先在构造kd树的时候需要寻找中位数,因此用快速排序来获取一个list中的中位数。
import matplotlib.pyplot as plt
import numpy as np
class QuickSort(object):
"Quick Sort to get medium number"
def __init__(self, low, high, array):
self._array = array
self._low = low
self._high = high
self._medium = (low+high+1)//2 # python3中的整除
def get_medium_num(self):
return self.quick_sort_for_medium(self._low, self._high,
self._medium, self._array)
def quick_sort_for_medium(self, low, high, medium, array): #用快速排序来获取中位数
if high == low:
return array[low] # find medium
if high > low:
index, partition = self.sort_partition(low, high, array);
#print array[low:index], partition, array[index+1:high+1]
if index == medium:
return partition
if index > medium:
return self.quick_sort_for_medium(low, index-1, medium, array)
else:
return self.quick_sort_for_medium(index+1, high, medium, array)
def quick_sort(self, low, high, array): #正常的快排
if high > low:
index, partition = self.sort_partition(low, high, array);
#print array[low:index], partition, array[index+1:high+1]
self.quick_sort(low, index-1, array)
self.quick_sort(index+1, high, array)
def sort_partition(self, low, high, array): # 用第一个数将数组里面的数分成两部分
index_i = low
index_j = high
partition = array[low]
while index_i < index_j:
while (index_i < index_j) and (array[index_j] >= partition):
index_j -= 1
if index_i < index_j:
array[index_i] = array[index_j]
index_i += 1
while (index_i < index_j) and (array[index_i] < partition):
index_i += 1
if index_i < index_j:
array[index_j] = array[index_i]
index_j -= 1
array[index_i] = partition
return index_i, partition
(2)构造kd树
class KDTree(object):
def __init__(self, input_x, input_y):
self._input_x = np.array(input_x)
self._input_y = np.array(input_y)
(data_num, axes_num) = np.shape(self._input_x)
self._data_num = data_num
self._axes_num = axes_num
self._nearest = None #用来存储最近的节点
return
def construct_kd_tree(self):
return self._construct_kd_tree(0, 0, self._input_x)
def _construct_kd_tree(self, depth, axes, data):
if not data.any():
return None
axes_data = data[:, axes].copy()
qs = QuickSort(0, axes_data.shape[0]-1, axes_data)
medium = qs.get_medium_num() #找到轴的中位数
data_list = []
left_data = []
right_data = []
data_range = range(np.shape(data)[0])
for i in data_range: # 跟中位数相比较
if data[i][axes] == medium: #相等
data_list.append(data[i])
elif data[i][axes] < medium:
left_data.append(data[i])
else:
right_data.append(data[i])
left_data = np.array(left_data)
right_data = np.array(right_data)
left = self._construct_kd_tree(depth+1, (axes+1)% self._axes_num, left_data)
right = self._construct_kd_tree(depth+1, (axes+1)% self._axes_num, right_data)
#[树的深度,轴,中位数,该节点的数据,左子树,右子树]
root = [depth, axes, medium, data_list, left, right]
return root
def print_kd_tree(self, root): #打印kd树
if root:
[depth, axes, medium, data_list, left, right] = root
print('{} {}'.format(' '*depth, data_list[0]))
if root[4]:
self.print_kd_tree(root[4])
if root[5]:
self.print_kd_tree(root[5])
//测试代码
input_x = [[2,3], [6,4], [9,6], [4,7], [8,1], [7,2]]
input_y = [1, 1, 1, 1, 1, 1]
kd = KDTree(input_x, input_y)
tree = kd.construct_kd_tree()
kd.print_kd_tree(tree)
#得到结果:
[7 2]
[6 4]
[2 3]
[4 7]
[9 6]
[8 1]
(3)搜索kd树
在类中继续添加如下函数,基本的思路是将路径上的节点依次入栈,再逐个出栈。
def _get_distance(self, x1, x2): #计算两个向量之间的距离
x = x1-x2
return np.sqrt(np.inner(x, x))
def _search_leaf(self, stack, tree, target): #以tree为根节点,一直搜索到叶节点,并添加到stack中
travel_tree = tree
while travel_tree:
[depth, axes, medium, data_list, left, right] = travel_tree
if target[axes] >= medium:
next_node = right
next_direction = 'right' # 记录被访问过的子树的方向
elif target[axes] < medium:
next_node = left
next_direction = 'left' # 记录被访问过的子树的方向
stack.append([travel_tree, next_direction]) #保存方向,用来记录哪个子树被访问过
travel_tree = next_node
def _check_nearest(self, current, target): # 判断当前节点跟目标的距离
d = self._get_distance(current, target)
if self._nearest:
[node, distance] = self._nearest
if d < distance:
self._nearest = [current, d]
else:
self._nearest = [current, d]
def search_kd_tree(self, tree, target): #搜索kd树
stack = []
self._search_leaf(stack, tree, target) # 一直搜索到叶节点,并将路径入栈
self._nearest = []
while stack:
[[depth, axes, medium, data_list, left, right], next_direction] = stack.pop() #出栈
[data] = data_list
self._check_nearest(data, target) #检查当前节点的距离
if left is None and right is None: #如果当前节点为叶节点,继续下一个循环
continue
[node, distance] = self._nearest
if abs(data[axes] - node[axes]) < distance: #<*> 当前节点的轴经过圆
if next_direction == 'right': # 判断哪个方向被访问过,转向相反方向
try_node = left
else:
try_node = right
self._search_leaf(stack, try_node, target) #往相反的方向搜索叶节点
print(self._nearest)
//测试代码
kd.search_kd_tree(tree, [7.1, 4.1])
> [array([6, 4]), 1.1045361017187258]
kd.search_kd_tree(tree, [9, 2])
> [array([8, 1]), 1.4142135623730951]
kd.search_kd_tree(tree, [6, 2])
> [array([7, 2]), 1.0]
(4)寻找k个最近节点
如果要寻找k个最近节点,则需要保存k个元素的数组,并在函数_check_nearest中与k个元素做比较,然后在标记<*>的地方跟k个元素的最大值比较。其他代码略。
def _check_nearest(self, current, target, k):
d = self._get_distance(current, target)
#print current, d
l = len(self._nearest)
if l < k:
self._nearest.append([current, d])
else:
farthest = self._get_farthest()[1]
if farthest > d:
# 将最远的节点移除
new_nearest = [i for i in self._nearest if i[1] [[array([7, 2]), 2.1023796041628633], [array([6, 4]), 1.1045361017187258]]
kd.search_kd_tree(tree, [9, 2], k=2)
> [[array([8, 1]), 1.4142135623730951], [array([7, 2]), 2.0]]
kd.search_kd_tree(tree, [6, 2], k=2)
> [[array([6, 4]), 2.0], [array([7, 2]), 1.0]]
参考博客解释:
https://blog.csdn.net/baimafujinji/article/details/52928203