效果说明:
- Input:输入Num个Dim维点的坐标,Points.size=(Num,Dim),输入一个目标点坐标Target、查找最近邻点数量K。
- Output: 求出距离Target最近的K个点的索引和距离。(具体坐标可由索引和Points列表获取)
- 环境要求: Python 3 with numpy and matplotlib
当Dim=2、Num=30、K=4时,绘制图如下:
输出:
candidate_index : [ 5 3 21 12 29 20]
candidate_distance : [0. 0.1107 0.1316 0.1701 0.2225 0.2656]
【注】这里以5号点作为目标点,它距离自己本身距离为0。
思路:
1、构建kdTree:通过递归构建一个二叉树,以当前空间维度的中位数点作为分割点,依次将空间分割,注意保存每个节点的坐标索引,以及由该节点划分出的左右节点序列和左右空间边界。
注意:这里的左右指的是每个维度的左右边界,默认:左小右大。
Node类参数说明:
这里没有将点的具体坐标信息赋予节点,而是保存节点对应的坐标索引,这样需要坐标值时根据索引调用坐标即可,也比较容易debug。
self.mid # 节点索引(中位数)
self.left # 节点左空间索引列表
self.right = right # 节点右空间索引列表
self.bound = bound # Dim * 2 # 当前节点所在空间范围(每个维度由左右边界控制)
self.flag = flag # 表示该节点对应的分割线应分割的维度索引(通过取模来控制变化)
self.lchild = lchild # 左子节点地址
self.rchild = rchild # 右子节点地址
self.par = par # 父节点地址
self.l_bound = l_bound # 节点左空间范围
self.r_bound = r_bound # 节点右空间范围
self.side = side # 当前节点是其父节点的左节点(0)或右节点(1)
2、确定初始节点(空间)
3、查找K近邻(具体详见参考书或与基础理论相关的博文)
# kd_Tree
# Edited By ocean_waver
import numpy as np
import matplotlib.pyplot as plt
class Node(object):
def __init__(self, mid, left, right, bound, flag, lchild=None, rchild=None, par=None,
l_bound=None, r_bound=None, side=-1):
self.mid = mid
self.left = left
self.right = right
self.bound = bound # Dim * 2
self.flag = flag
self.lchild = lchild
self.rchild = rchild
self.par = par
self.l_bound = l_bound
self.r_bound = r_bound
self.side = side
def find_median(a):
# s = np.sort(a)
arg_s = np.argsort(a)
idx_mid = arg_s[len(arg_s) // 2]
idx_left = np.array([arg_s[j] for j in range(0, len(arg_s) // 2)], dtype='int32')
idx_right = np.array([arg_s[j] for j in range(len(arg_s) // 2 + 1, np.size(a))], dtype='int32')
return idx_mid, idx_left, idx_right
def kd_tree_establish(root, points, dim):
# print(root.mid)
layer_flag = (root.flag + 1) % dim # 确定分割点对应的分割线的维度
if dim == 2:
static_pos = points[root.mid, root.flag]
if root.flag == 0:
x_line = np.linspace(static_pos,