对于代码这一块:
class Node:
def __init__(self, left, right, point_indices):
self.left = left
self.right = right
self.point_indices = point_indices
节点的数据结构为 左右节点 以及分割点。
# 构造树
# 按照某一个维度轮流切,切的位置取决于某个维度如(x)的排序后的中值
def make_kdtree(points,dim,i=0):
if len(points)>1:
points=points[np.argsort(points[:,i]),:]
i=(i+1)%dim
half=len(points)//2
return(make_kdtree(points[: half,:],dim,i),
make_kdtree(points[half+1 :,:],dim,i),
points[half,:])
elif len(points)==1:
return (None,None,points[0,:])
这里要注意建树过程中,建树中叶子节点的判定条件。我这里是最后只是一个点了,就是叶子节点。
建树完成后,可以执行最近点搜索:
def get_nearst(kd_node,point,dim,dist_func,return_distance=False,i=0,best=None):
if kd_node:
dist=dist_func(point,kd_node[2])
dx=kd_node[2][i]-point[i]
if not best:
best=[dist,kd_node[2]]
elif dist<best[0]:
best[0],best[1]=dist,kd_node[2]
i=(i+1)%dim
get_nearst(kd_node[dx<0],point,dim,dist_func,return_distance,i,best)
if dx**2<best[0]:
get_nearst(kd_node[dx>=0],point,dim,dist_func,return_distance,i,best)
return best if return_distance else best[1]
找最近的点的过程:计算该点到节点分割点的距离,以及该点到分割平面的距离(这个是判断是否搜索另外一个子空间节点的条件)
best这个数组就是存储节点到分割点距离,以及节点到分割平面的距离。
整个流程就是:首先计算当前节点的best ,搜索该节点在的子节点 ,判断该点到平面距离和点到分割点的距离,决定是否要搜索下一个子节点。再重复递归下去。
下面就是k 近邻的搜索了:
def get_knn(kd_node,point,k,dim,dist_func,return_distance=False,i=0,heap=None):
is_root= not heap
if is_root:
heap=[]
if kd_node:
dist=dist_func(point,np.array(kd_node[2]))
dx=kd_node[2][i]-point[i]
if len(heap)<k:
heapq.heappush(heap,(-dist,kd_node[2].tolist()))
elif dist<-heap[0][0]:
heapq.heappushpop(heap,(-dist,kd_node[2].tolist()))
i=(i+1)%dim
get_knn(kd_node[dx<0],point,k,dim,dist_func,return_distance,i,heap)
if dx**2<-heap[0][0] or len(heap)<k:
get_knn(kd_node[dx >= 0], point, k, dim, dist_func, return_distance, i, heap)
if is_root:
idx=np.argsort([-h[0] for h in heap])
neighors=[(-heap[n][0],np.array(heap[n][1])) for n in idx]
return neighors if return_distance else [n[1] for n in neighors]