Python 统计学习方法——kdTree实现K近邻搜索

效果说明:

  • Input:输入Num个Dim维点的坐标,Points.size=(Num,Dim),输入一个目标点坐标Target、查找最近邻点数量K。
  • Output: 求出距离Target最近的K个点的索引和距离。(具体坐标可由索引和Points列表获取)
  • 环境要求: Python 3 with numpy and matplotlib

当Dim=2、Num=30、K=4时,绘制图如下:
在这里插入图片描述

输出:
candidate_index : [ 5 3 21 12 29 20]
candidate_distance : [0. 0.1107 0.1316 0.1701 0.2225 0.2656]
【注】这里以5号点作为目标点,它距离自己本身距离为0。

思路:

1、构建kdTree:通过递归构建一个二叉树,以当前空间维度的中位数点作为分割点,依次将空间分割,注意保存每个节点的坐标索引,以及由该节点划分出的左右节点序列和左右空间边界。

注意:这里的左右指的是每个维度的左右边界,默认:左小右大。

Node类参数说明:
这里没有将点的具体坐标信息赋予节点,而是保存节点对应的坐标索引,这样需要坐标值时根据索引调用坐标即可,也比较容易debug。

self.mid 						# 节点索引(中位数)
self.left						# 节点左空间索引列表
self.right = right				# 节点右空间索引列表
self.bound = bound  # Dim * 2	# 当前节点所在空间范围(每个维度由左右边界控制)
self.flag = flag				# 表示该节点对应的分割线应分割的维度索引(通过取模来控制变化)
self.lchild = lchild			# 左子节点地址
self.rchild = rchild			# 右子节点地址
self.par = par					# 父节点地址
self.l_bound = l_bound			# 节点左空间范围
self.r_bound = r_bound			# 节点右空间范围
self.side = side				# 当前节点是其父节点的左节点(0)或右节点(1)
2、确定初始节点(空间)
3、查找K近邻(具体详见参考书或与基础理论相关的博文)
# kd_Tree
# Edited By ocean_waver
import numpy as np
import matplotlib.pyplot as plt


class Node(object):

    def __init__(self, mid, left, right, bound, flag, lchild=None, rchild=None, par=None,
                 l_bound=None, r_bound=None, side=-1):
        self.mid = mid
        self.left = left
        self.right = right
        self.bound = bound  # Dim * 2
        self.flag = flag
        self.lchild = lchild
        self.rchild = rchild
        self.par = par
        self.l_bound = l_bound
        self.r_bound = r_bound
        self.side = side


def find_median(a):
    # s = np.sort(a)
    arg_s = np.argsort(a)
    idx_mid = arg_s[len(arg_s) // 2]
    idx_left = np.array([arg_s[j] for j in range(0, len(arg_s) // 2)], dtype='int32')
    idx_right = np.array([arg_s[j] for j in range(len(arg_s) // 2 + 1, np.size(a))], dtype='int32')

    return idx_mid, idx_left, idx_right


def kd_tree_establish(root, points, dim):
    # print(root.mid)
    layer_flag = (root.flag + 1) % dim    # 确定分割点对应的分割线的维度

    if dim == 2:
        static_pos = points[root.mid, root.flag]
        if root.flag == 0:
            x_line = np.linspace(static_pos,
  • 6
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
K-D Tree(K-Dimensional Tree)是一种用于处理 k 维空间数据的数据结构。它可以用于搜索近邻点、范围搜索等问题。下面是一个简单的 Python 实现。 首先,我们需要一个节点类来表示树中的节点: ```python class Node: def __init__(self, point=None, axis=None, left=None, right=None): self.point = point # 节点代表的点 self.axis = axis # 划分轴 self.left = left # 左子树 self.right = right # 右子树 ``` 接下来,我们需要一个 `build_kdtree` 函数来构建 K-D Tree: ```python def build_kdtree(points, depth=0): n = len(points) if n <= 0: return None axis = depth % k # 选择轴 points.sort(key=lambda point: point[axis]) mid = n // 2 return Node( point=points[mid], axis=axis, left=build_kdtree(points[:mid], depth+1), right=build_kdtree(points[mid+1:], depth+1) ) ``` 这个函数会递归地将点集分成两部分,构建 K-D Tree。 接下来,我们需要实现两个重要的子函数:`nearest_neighbor` 和 `range_search`。 `nearest_neighbor` 函数用于查找 K-D Tree 中最近的点: ```python def nearest_neighbor(root, point): best = None best_distance = float("inf") def search(node, depth=0): nonlocal best, best_distance if node is None: return distance = euclidean_distance(point, node.point) if distance < best_distance: best = node.point best_distance = distance axis = depth % k if point[axis] < node.point[axis]: search(node.left, depth+1) if point[axis] + best_distance >= node.point[axis]: search(node.right, depth+1) else: search(node.right, depth+1) if point[axis] - best_distance <= node.point[axis]: search(node.left, depth+1) search(root) return best ``` 这个函数会从根节点开始递归地搜索 K-D Tree,找到与目标点最近的点。 最后,我们需要实现 `range_search` 函数来查找 K-D Tree 中给定范围内的所有点: ```python def range_search(root, lower, upper): result = [] def search(node, depth=0): if node is None: return if lower <= node.point <= upper: result.append(node.point) axis = depth % k if lower[axis] <= node.point[axis]: search(node.left, depth+1) if upper[axis] >= node.point[axis]: search(node.right, depth+1) search(root) return result ``` 这个函数会从根节点开始递归地搜索 K-D Tree,找到所有满足条件的点。 完整的代码如下所示: ```python import math k = 2 class Node: def __init__(self, point=None, axis=None, left=None, right=None): self.point = point # 节点代表的点 self.axis = axis # 划分轴 self.left = left # 左子树 self.right = right # 右子树 def euclidean_distance(p, q): return math.sqrt(sum((p[i] - q[i])**2 for i in range(k))) def build_kdtree(points, depth=0): n = len(points) if n <= 0: return None axis = depth % k # 选择轴 points.sort(key=lambda point: point[axis]) mid = n // 2 return Node( point=points[mid], axis=axis, left=build_kdtree(points[:mid], depth+1), right=build_kdtree(points[mid+1:], depth+1) ) def nearest_neighbor(root, point): best = None best_distance = float("inf") def search(node, depth=0): nonlocal best, best_distance if node is None: return distance = euclidean_distance(point, node.point) if distance < best_distance: best = node.point best_distance = distance axis = depth % k if point[axis] < node.point[axis]: search(node.left, depth+1) if point[axis] + best_distance >= node.point[axis]: search(node.right, depth+1) else: search(node.right, depth+1) if point[axis] - best_distance <= node.point[axis]: search(node.left, depth+1) search(root) return best def range_search(root, lower, upper): result = [] def search(node, depth=0): if node is None: return if lower <= node.point <= upper: result.append(node.point) axis = depth % k if lower[axis] <= node.point[axis]: search(node.left, depth+1) if upper[axis] >= node.point[axis]: search(node.right, depth+1) search(root) return result ``` 这是一个简单的 K-D Tree 的 Python 实现,可以用来解决搜索近邻点、范围搜索等问题。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值