点云 K 临近查找算法:kd 树

kd 树是一个二叉树,用于高效的查找某个点的 k 临近点,它的每一个节点记载了 特征坐标,切分轴,指向左右子树的指针。

1. 树的构建

首先是构建树的结点,左右指针使用列表来存放,这是为了后面计算方便。另外 __lt__ 方法用来重载小于号,用于在优先队列中进行比较。

class KDNode:
     def __init__(self, point, left=None, right=None, split_dim=None):
         self.point = point
         self.sons = [left, right]
         self.splitDim = split_dim

     def __lt__(self, other):
         if self.point[0] == other.point[0]:
             if self.point[1] == other.point[1]:
                 return self.point[2] < other.point[2]
             return self.point[1] < other.point[1]
         return self.point[0] < other.point[0]

然后是建树的步骤

  1. 选取一个维度 splitDim 进行切分
  2. 将所有的点按照 splitDim 的大小进行排序,并且选出中点 median 作为这个子树的根节点
  3. median 左侧的点放在左子树,右侧的点放在右子树,分别递归构建左右子树
def build(self, points=None, depth=0):
    """
    :param points: numpy
    """
    if points is None or points.shape[0] == 0:
        return None

    ndims = points.shape[1]
    splitDim = depth % ndims

    sortedIndexes = np.argsort(points[:, splitDim])
    points = points[sortedIndexes]

    mid = len(points) // 2
    median = points[mid]

    leftPoints = points[:mid]
    rightPoints = points[mid + 1:]

    left = self.build(leftPoints, depth + 1)
    right = self.build(rightPoints, depth + 1)

    kdn = self.KDNode(median, left, right, splitDim)
    if depth == 0:
        self.root = kdn
    return kdn

2. 查找

查找一个点 target 的 k 临近点坐标,步骤如下:

  1. 从根节点开始,计算根节点与 target 的距离,并将距离和该节点放入一个优先队列,这个队列用于存储当前与 target 最近的 k 个点,如果队列内元素数量大于 k,则 pop 出队首元素。
  2. 判断在当前的分割维度下,target 在根节点左侧还是右侧,根据这一点递归的查找 左 / 右 子树
  3. 在回溯阶段,判断优先队列内最远点(队首元素)到 target 的距离 dis1target 到分割线的距离 dis2 的大小,如果 dis1 大于等于 dis2,那么说明另一个子树上可能存在更近的点,于是递归搜索另一棵子树
def _search(self, node, target, heap):
    if node is None:
        return

    dist = self._cal_distance(node.point, target)
    heapq.heappush(heap, (-dist, node))

    if len(heap) > self.K:
        heapq.heappop(heap)

    splitDim = node.splitDim
    targetVal = target[splitDim]
    nodeVal = node.point[splitDim]

    choice = int(targetVal < nodeVal)
    self._search(node.sons[choice], target, heap)
    if abs(heap[0][0]) >= abs(nodeVal - targetVal):
        self._search(node.sons[1 ^ choice], target, heap)


def search_nearest(self, target):
    heap = []
    self._search(self.root, target, heap)
    points = [i.point for _, i in heap]
    return points

完整代码:

import heapq
import numpy as np


class KDTree:
    class KDNode:
        def __init__(self, point, left=None, right=None, split_dim=None):
            self.point = point
            self.sons = [left, right]
            self.splitDim = split_dim

        def __lt__(self, other):
            if self.point[0] == other.point[0]:
                if self.point[1] == other.point[1]:
                    return self.point[2] < other.point[2]
                return self.point[1] < other.point[1]
            return self.point[0] < other.point[0]

    def __init__(self, func, k=4):
        self.K = k
        self._cal_distance = func

    def build(self, points=None, depth=0):
        """

        :param points: numpy
        :param depth:
        :return:
        """
        if points is None or points.shape[0] == 0:
            return None

        ndims = points.shape[1]
        splitDim = depth % ndims

        sortedIndexes = np.argsort(points[:, splitDim])
        points = points[sortedIndexes]

        mid = len(points) // 2
        median = points[mid]

        leftPoints = points[:mid]
        rightPoints = points[mid + 1:]

        left = self.build(leftPoints, depth + 1)
        right = self.build(rightPoints, depth + 1)

        kdn = self.KDNode(median, left, right, splitDim)
        if depth == 0:
            self.root = kdn
        return kdn


    def _search(self, node, target, heap):
        if node is None:
            return

        dist = self._cal_distance(node.point, target)
        heapq.heappush(heap, (-dist, node))

        if len(heap) > self.K:
            heapq.heappop(heap)

        splitDim = node.splitDim
        targetVal = target[splitDim]
        nodeVal = node.point[splitDim]

        choice = int(targetVal < nodeVal)
        self._search(node.sons[choice], target, heap)
        if abs(heap[0][0]) >= abs(nodeVal - targetVal):
            self._search(node.sons[1 ^ choice], target, heap)


    def search_nearest(self, target):
        heap = []
        self._search(self.root, target, heap)
        points = [i.point for _, i in heap]
        return points
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

SP FA

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值