根据模型的数学原理进行简单的代码自我复现以及使用测试,仅作自我学习用。模型原理此处不作过多赘述,仅罗列自己将要使用到的部分公式。
如文中或代码有错误或是不足之处,还望能不吝指正。
某些机器学习模型,如KNN,中,需要在n维空间上计算距离,找到训练样本中距离自身最近的那一个点。如果直接计算距离,就需要O(n²)的时间复杂度。故而需要引入KD树作为索引,以此搜索最近距离的点。
KD树的大致原理,就是在每一层中,根据剩余数据集中方差最大的特征排序,选取正中间的点作为节点值,将左右两边的值分别构建左右子树。
在搜索时,
1 首先从根节点开始,根据当前节点的分割特征判断像左还是向右移,直至达到叶子节点。
2 将叶子节点作为“最近点”,并从叶子节点开始向前回溯,计算到此节点距离是否更小。是的话替代
3 在回溯过程的同时,还需要与当前节点的父节点进行比较:如果点到父节点对应所在的(超)平面(也就是父节点分割依据的那个特征所在平面)距离小于到当前节点的2点的距离, 那么就代表目标点其兄弟节点的距离有可能更短,应当从其兄弟节点处重新执行1~3步。
这里其实很好理解,因为目标点到平面的距离是垂直的最短距离,如果点到当前节点的距离比这个距离小,那么在平面上的其他节点也会小于这个距离。但是点到当前节点的距离比垂直距离更大时,那么兄弟节点就有可能成为那个“距离更小的节点”。
import numpy as np
from collections import deque
class Node:
def __init__(self,value=None,split=None,left=None,right=None,father=None):
self.value = value
self.split = split
self.left = left
self.right = right
self.father = father
class KDTree:
def __init__(self,x=None):
if x is not None:
self.root = self.buildtree(x)
else:
self.root = Node()
def get_median(self,sub_x):
x = list(sub_x)
length = len(x)
x_order = sorted(x)
return x_order[length//2],x.index(x_order[length//2])
def buildtree(self,x):
if len(x)== 0:
return None
#寻找方差最大的那个方向
max_std = 0
max_idx = 0
for i in range(x.shape[1]):
std = np.std(x[:,i])
if std>max_std:
max_idx = i
max_std = std
#找到中点
v,v_idx=self.get_median(x[:,max_idx])
#根据中点值分割
cur = Node(value=x[v_idx,:],split=max_idx)
left_idx = []
right_idx = []
for i in range(len(x)):
if x[i,max_idx]>v:
right_idx.append(i)
elif x[i,max_idx]<v or (x[i,max_idx]==v and i != v_idx):
left_idx.append(i)
cur.left = self.buildtree(x[left_idx,:])
if cur.left is not None:
cur.left.father = cur
cur.right = self.buildtree(x[right_idx,:])
if cur.right is not None:
cur.right.father = cur
return cur
def dist(self,point1,point2):
if hasattr(point1,'value'):
point1 = point1.value
if hasattr(point2,'value'):
point2 = point2.value
if len(point1) != len(point2):
raise ValueError("2点维度不同,不可计算距离")
return (sum([(point1[i]-point2[i])**2 for i in range(len(point1))]))**(1/2)
def brother(self,node):
if node.father is None:
return None
else:
if node.father.left == node:
return node.father.right
else:
return node.father.left
def get_leaf(self,x,node):
#找到叶子节点
while node.left is not None or node.right is not None:
if node.left is None:
return node.right
elif node.right is None:
return node.left
else:
if x[node.split] < node.value[node.split]:
node = node.left
else:
node = node.right
return node
def search_nearest(self,x):
distance = float("inf")
nearest_node = self.get_leaf(x,self.root)
que = [(self.root,nearest_node)]
que = deque(que)
while que:
root,cur = que.popleft()
while cur is not root:
dist = self.dist(x,cur.value)
if dist<distance:
distance = dist
nearest_node = cur
if self.brother(cur) is not None:
father_split = cur.father.split
new_dist = abs(x[father_split]-cur.father.value[father_split])
if new_dist<distance:
nearest_node = self.get_leaf(x,self.brother(cur))
que.append((self.brother(cur),nearest_node))
cur = cur.father
return nearest_node
而对于“在目标点的周围搜索K个最邻近的点”这一问题,应该将逻辑替换为“先保存k个节点,等到遇到距离更小的节点再替换保存的距离最大的节点”。很可惜我只找到理论部分,而sklearn的代码是pyd,我也没有找到反汇编(或是反编译?我个人缺乏此处的知识),自己写了部分代码,没有经过大量实验,故而只能作为参考,并不能作为真正的使用代码。
def search_nearest_k(self,x,k):
if k == 0:
return None
last_node = self.get_leaf(x,self.root)
que = [(self.root,last_node.father)]
que = deque(que)
distance = self.dist(x,last_node)
selected_nodes = [(last_node,distance)]
while que:
root,cur = que.popleft()
while cur is not root:
dist = self.dist(x,cur.value)
if len(selected_nodes)<k:
if dist>=distance:
selected_nodes.append((cur,dist))
else:
selected_nodes = [(cur,dist)]+selected_nodes
elif dist<selected_nodes[-1][1]:
#遇到距离更小的点,替换原来距离最大的点
selected_nodes.pop()
selected_nodes.append((cur,dist))
selected_nodes.sort(key = lambda x:x[1])
if self.brother(cur) is not None:
father_split = cur.father.split
new_dist = abs(x[father_split]-cur.father.value[father_split])
if new_dist<selected_nodes[-1][1] or len(selected_nodes)<k:
last_node = self.get_leaf(x,self.brother(cur))
que.append((self.brother(cur),last_node))
cur = cur.father
return selected_nodes
使用numpy随机生成数据进行测试
尽管从图中看起来成功找到了最近的5个点,但是在没有经过大批量的数据测试,故而仅供参考。