K近邻算法没有显示的学习过程,由三个基本要素——距离度量,k值的选择和分类决策规则决定。
距离选择可以是欧式距离,Lp距离或Minkowski距离
k值的选择过大会使模型变得简单;过小会导致过拟合,通常取一个比较小的数值采用交叉验证法来取最优。
分类决策规则往往是多数表决(等价于经验风险最小化)。
kd树最近邻搜索
建树过程比较简单,详见文末代码,这里直接从搜索开始。
# 用kd树的最近邻搜索算法,树已构造好
def search(self, node, aim, depth=0):
if node is not None:
n = len(aim) # aim是目标点,计算维度
axis = depth % n
if aim[axis] < node.data[axis]:
self.search(node.lchild, aim, depth+1)
else:
self.search(node.rchild, aim, depth+1)
dis = self.dist(aim, node.data) # 欧式距离
if self.nearest is None or self.nearestDis > dis:
self.nearest = node.data # 保存当前最近点
self.nearestDis = dis # 当前最近距离
# 算法3.3(3)(b) 判断是否需要去另一子结点搜索
if abs(node.data[axis] - aim[axis]) <= self.nearestDis:
if aim[axis] < node.data[axis]:
self.search(node.rchild, aim, depth + 1)
else:
self.search(node.lchild, aim, depth + 1)
其中nearest和nearestDis在初始化类时定义
def __init__(self):
self.kdTree = None
self.nearest = None
self.nearestDis = 0
算法的实现学习了这位大佬:k近邻 kd树的python实现
kd树实现k近邻算法
在上述算法的基础上进行改进:用列表来保存k个最近的点,初始化函数变为
def __init__(self):
self.kdTree = None
self.nearest = []
self.nearestDis = []
算法框架不变,先找到包含目标点的区域然后开始回退。
当列表中的点不足k个时,将当前结点加入nearest列表,同时将当前结点与目标点的距离加入nearestDis列表;另外,当前结点与目标点距离小于nearestDis列表中的最大距离时,对nearest和nearestDis列表进行更新,替换此最远点和对应的距离。
检查是否需要去当前结点的另一子结点搜索,只需将之前的判断半径变为nearest中最远点到目标点的距离就可以。
# 用kd树的k近邻搜索算法,k缺省为1
def search(self, node, aim, k=1, depth=0):
if node is not None:
n = len(aim)
axis = depth % n
if aim[axis] < node.data[axis]:
self.search(node.lchild, aim, k, depth+1)
else:
self.search(node.rchild, aim, k, depth+1)
dis = self.dist(aim, node.data) # 欧式距离
if len(self.nearest) < k: # 若不够k个,加入k近邻列表
self.nearest.append(node.data)
self.nearestDis.append(dis)
elif max(self.nearestDis) > dis: # 当小于k近邻列表中最大值时替换此最大值
maxIndex = self.nearestDis.index(max(self.nearestDis))
self.nearest[maxIndex] = node.data
self.nearestDis[maxIndex] = dis
# 判断是否需要去另一子结点搜索
if abs(node.data[axis] - aim[axis]) <= max(self.nearestDis): # 搜索半径
if aim[axis] < node.data[axis]:
self.search(node.rchild, aim, k, depth + 1)
else:
self.search(node.lchild, aim, k, depth + 1)
k缺省为1;当k大于数据集中所有点时,不报错,输出所有点。
对书上实例运行如下:
data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
x = [5, 3]
kdtree = KdTree() # 定义一个kd树实例
tree = kdtree.create(data) # 构造树
# kdtree.preOrder(tree)
kdtree.search(tree, x, 3) # 搜索
print(kdtree.nearest)
print(kdtree.nearestDis)
结果:
[[2, 3], [5, 4], [7, 2]]
[3.0, 1.0, 2.23606797749979]