3D Point Cloud Course (2): The Nearest Neighbor Problem

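Contents

1 Nearest Neighbor (NN) Problem

1.1 Overview of the methods

1.2 BST (binary search tree)

1.3 KD-tree

1.3.1 Building a KD-tree from the 2D points a to i

1.3.2 Speeding up KD-tree construction

1.3.3 KNN search with a KD-tree

1.4 Octree

1.4.1 Steps for building an octree

1.4.2 KNN search in an octree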

1 Nearest Neighbor (NN) Problem

BST: one-dimensional partitioning

KD-tree: partitioning in any number of dimensions

Octree: three-dimensional partitioning

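1.1 Overview of the methods

 Difficulties of the nearest-neighbor problem for point clouds: 1. the points are irregular (no grid structure)  2. the data is three-dimensional  3. the data volume is large

Core ideas of the algorithms: 1. partition the space   2. skip parts of the space   3. stop the search early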
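For reference, the baseline that these structures try to beat is brute force: compute the distance from the query to every point and keep the k smallest, which costs O(N) per query. The sketch below is a minimal NumPy version; the function name `brute_force_knn` is hypothetical, though the same distance computation appears, commented out, in the octree `main()` later in this post.

```python
import numpy as np


def brute_force_knn(db: np.ndarray, query: np.ndarray, k: int):
    """Return (indices, distances) of the k points in db nearest to query.

    Hypothetical helper: it checks every point, so each query costs O(N).
    """
    # Euclidean distance from the query to every row of db
    dists = np.linalg.norm(db - query[np.newaxis, :], axis=1)
    # argpartition selects the k smallest in O(N); then only those k are sorted
    idx = np.argpartition(dists, k - 1)[:k]
    idx = idx[np.argsort(dists[idx])]
    return idx, dists[idx]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    db = rng.random((1000, 3))
    idx, d = brute_force_knn(db, np.array([0.5, 0.5, 0.5]), k=8)
    print(idx, d)
```

Partitioning, skipping, and early stopping all aim to avoid touching most of those N points.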

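1.2 BST (binary search tree)

Smaller keys go to the left, larger keys to the right.

  •  Binary tree code

  • Depth and complexity: the search cost grows with the depth of the tree, so keep the tree as balanced as possible (about log2(n) levels when balanced, but up to n levels in the worst case, e.g. when the keys are inserted in sorted order)

  •  Traversal methods

  • Nearest-neighbor search: KNN search

 The most important quantity is the worst distance, i.e. the distance to the current k-th nearest neighbor. It is dynamic: it shrinks as closer points are found, and it decides which subtrees can be skipped.

  • Radius nearest-neighbor (RNN) search: the same procedure, except that the worst distance is fixed at the search radius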
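The BST code itself is not reproduced in this post, so here is a minimal sketch under my own naming (`BSTNode`, `insert`, `inorder`, `worst_dist`, `knn_search`, `radius_search` are all hypothetical). It shows the points listed above: smaller keys go left, an inorder traversal visits keys in sorted order, KNN search keeps a dynamic worst distance that decides whether the far subtree is visited, and radius search uses the same pruning with the worst distance fixed at the radius.

```python
import bisect
import math
from typing import List, Optional, Tuple


class BSTNode:
    def __init__(self, key: float, index: int):
        self.key = key                          # 1-D value stored at this node
        self.index = index                      # index of the point in the original data
        self.left: Optional["BSTNode"] = None   # keys smaller than self.key
        self.right: Optional["BSTNode"] = None  # keys greater than or equal to self.key


def insert(root: Optional[BSTNode], key: float, index: int) -> BSTNode:
    """Smaller keys go left, larger (or equal) keys go right."""
    if root is None:
        return BSTNode(key, index)
    if key < root.key:
        root.left = insert(root.left, key, index)
    else:
        root.right = insert(root.right, key, index)
    return root


def inorder(root: Optional[BSTNode], out: List[float]) -> List[float]:
    """Inorder traversal (left, node, right) visits the keys in sorted order."""
    if root is not None:
        inorder(root.left, out)
        out.append(root.key)
        inorder(root.right, out)
    return out


def worst_dist(result: List[Tuple[float, int]], k: int) -> float:
    """Distance to the current k-th neighbor; infinite until k neighbors are known."""
    return result[-1][0] if len(result) == k else math.inf


def knn_search(root: Optional[BSTNode], query: float, k: int,
               result: List[Tuple[float, int]]) -> List[Tuple[float, int]]:
    """KNN search; result is a sorted list of (distance, index) with at most k entries."""
    if root is None:
        return result
    dist = abs(root.key - query)
    if dist < worst_dist(result, k):
        bisect.insort(result, (dist, root.index))
        if len(result) > k:
            result.pop()                        # drop the farthest: the worst distance shrinks
    # search the side that contains the query first
    if query < root.key:
        near, far = root.left, root.right
    else:
        near, far = root.right, root.left
    knn_search(near, query, k, result)
    # every key on the far side is at least |root.key - query| away from the query
    if abs(root.key - query) < worst_dist(result, k):
        knn_search(far, query, k, result)
    return result


def radius_search(root: Optional[BSTNode], query: float, radius: float,
                  out: List[int]) -> List[int]:
    """Radius search: same pruning, but the worst distance is fixed at radius."""
    if root is None:
        return out
    if abs(root.key - query) <= radius:
        out.append(root.index)
    if query - radius < root.key:               # the left subtree may hold keys in range
        radius_search(root.left, query, radius, out)
    if query + radius >= root.key:              # the right subtree may hold keys in range
        radius_search(root.right, query, radius, out)
    return out


if __name__ == "__main__":
    tree = None
    for i, key in enumerate([5.0, 2.0, 8.0, 1.0, 4.0, 7.0, 9.0]):
        tree = insert(tree, key, i)
    print(inorder(tree, []))
    print(knn_search(tree, 6.2, 2, []))
    print(radius_search(tree, 4.5, 1.0, []))
```

Because the check on the far side uses the worst distance after the near side has been searched, a good early candidate lets whole subtrees be skipped.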

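1.3 KD-tree

A KD-tree applies the binary-search-tree idea to each dimension in turn: every level splits the space along one axis.

1.3.1 Building a KD-tree from the 2D points a to i

1. Pick a direction (axis) for the first split; the choice is arbitrary.

2. Split along the dimensions in turn until leaf_size = 1, i.e. until each region contains a single point. At each level, the points are sorted along the current axis to find the median, which is used as the split point.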
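The two steps can be sketched as a compact recursive builder. The figure with the positions of a to i is not reproduced, so the coordinates below are hypothetical. The sketch starts with the x axis, alternates x and y level by level, sorts each region along the current axis, and splits halfway between the two middle points (the same rule as the course code further down), stopping at leaf_size = 1.

```python
from typing import Dict, List, Tuple, Union

Point = Tuple[float, float]
Tree = Union[List[str], dict]   # a leaf is a list of point names, an inner node is a dict

# Hypothetical coordinates for the points a to i (the original figure is not reproduced)
POINTS: Dict[str, Point] = {
    "a": (1.0, 7.0), "b": (2.0, 3.0), "c": (3.0, 8.0),
    "d": (4.0, 1.0), "e": (5.0, 5.0), "f": (6.0, 9.0),
    "g": (7.0, 2.0), "h": (8.0, 6.0), "i": (9.0, 4.0),
}


def build_kdtree(names: List[str], pts: Dict[str, Point], axis: int = 0,
                 leaf_size: int = 1) -> Tree:
    """Split at the median along x, y, x, y, ... until a region holds <= leaf_size points."""
    if len(names) <= leaf_size:
        return names
    # sort this region's points along the current axis (the per-level sort)
    ordered = sorted(names, key=lambda n: pts[n][axis])
    mid = (len(ordered) + 1) // 2               # the left half gets ceil(n / 2) points
    split = (pts[ordered[mid - 1]][axis] + pts[ordered[mid]][axis]) / 2
    next_axis = (axis + 1) % 2                  # round robin over the two dimensions
    return {
        "axis": "xy"[axis],
        "split": split,
        "left": build_kdtree(ordered[:mid], pts, next_axis, leaf_size),
        "right": build_kdtree(ordered[mid:], pts, next_axis, leaf_size),
    }


def collect_leaves(tree: Tree) -> List[List[str]]:
    """Return the leaves from left to right."""
    if isinstance(tree, list):
        return [tree]
    return collect_leaves(tree["left"]) + collect_leaves(tree["right"])


if __name__ == "__main__":
    tree = build_kdtree(list(POINTS), POINTS)
    print(tree["axis"], tree["split"])          # the first split is along x
    print(collect_leaves(tree))
```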

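1.3.2 Speeding up KD-tree construction

 If the KD-tree does not have to be perfectly balanced, the per-level sort can be avoided: 1. find the median from a random subset of the points only; 2. use the mean instead of the median.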
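Both shortcuts can be seen as alternative ways of choosing the split value along one axis. The helper below is illustrative, not part of the course code; `choose_split_value`, `method` and `sample_size` are hypothetical names.

```python
import numpy as np


def choose_split_value(values: np.ndarray, method: str = "median",
                       sample_size: int = 100, rng=None) -> float:
    """Pick a split value for one axis (hypothetical helper).

    method="median":         exact median of all values (balanced split)
    method="sampled_median": median of a random subset of at most sample_size values
    method="mean":           arithmetic mean; no sorting, but the split may be unbalanced
    """
    if method == "median":
        return float(np.median(values))
    if method == "sampled_median":
        rng = np.random.default_rng() if rng is None else rng
        n = min(sample_size, values.shape[0])
        # sample without replacement, then take the median of the sample only
        sample = rng.choice(values, size=n, replace=False)
        return float(np.median(sample))
    if method == "mean":
        return float(np.mean(values))
    raise ValueError("unknown method: %s" % method)


if __name__ == "__main__":
    v = np.array([1.0, 2.0, 3.0, 4.0, 100.0])
    for m in ("median", "sampled_median", "mean"):
        print(m, choose_split_value(v, m, sample_size=3, rng=np.random.default_rng(0)))
```

The skewed example shows the trade-off: the mean (22.0) puts four of the five values on one side, so the tree gets deeper on that side, while a sampled median usually stays close to the true median.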

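1.3.3 KNN search with a KD-tree

The core idea of KNN search with a KD-tree:
for each region, use the worst distance to decide whether the region needs to be searched at all, which saves search effort.

A region must be searched if either of these conditions holds:
1. the query point lies inside the region (distance 0);
2. the distance from the query point to the region's boundary is smaller than the worst distance.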
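The two conditions amount to comparing the distance from the query to the region with the worst distance. Here is a minimal sketch, assuming each region is an axis-aligned box given by its lower and upper corners (`box_min`, `box_max`, `distance_to_box` and `need_to_search` are hypothetical names). The course code below uses a cheaper bound, the distance to the splitting plane only, `|query[axis] - value|`. That bound never exceeds the true distance to the region on the other side, so it may visit a few extra regions but never skips a necessary one.

```python
import numpy as np


def distance_to_box(query: np.ndarray, box_min: np.ndarray, box_max: np.ndarray) -> float:
    """Euclidean distance from query to the closest point of an axis-aligned box (0 if inside)."""
    # clamping the query into the box gives the closest point of the box
    closest = np.clip(query, box_min, box_max)
    return float(np.linalg.norm(query - closest))


def need_to_search(query: np.ndarray, box_min: np.ndarray, box_max: np.ndarray,
                   worst_dist: float) -> bool:
    """Condition 1: query inside the box (distance 0). Condition 2: box closer than worst_dist."""
    d = distance_to_box(query, box_min, box_max)
    return d == 0.0 or d < worst_dist


if __name__ == "__main__":
    box_lo, box_hi = np.array([0.0, 0.0]), np.array([1.0, 1.0])
    print(need_to_search(np.array([1.5, 0.5]), box_lo, box_hi, worst_dist=0.3))
```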

# KD-tree implementation, including construction and search
 
import random
import math
import numpy as np
 
from result_set import KNNResultSet, RadiusNNResultSet
 
# Node class: the basic building block of the tree
class Node:
    def __init__(self, axis, value, left, right, point_indices):
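        self.axis = axis                        # the dimension (axis) this node splits along
        self.value = value                      # split value along that axis; None for a leaf
        self.left = left                        # left child
        self.right = right                      # right child
        self.point_indices = point_indices      # indices of the points in this node's region (before splitting)

    # a node is a leaf if it has no split value
    def is_leaf(self):
        return self.value is None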
 
    def __str__(self):
        output = ''
        output += 'axis %d, ' % self.axis
        if self.value is None:
            output += 'split value: leaf, '
        else:
            output += 'split value: %.2f, ' % self.value
        output += 'point_indices: '
        output += str(self.point_indices.tolist())
        return output
 
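# Purpose: before building the tree, sort by value and reorder the keys in the same way
# Input:
#     key: keys (point indices)
#     value: values to sort by
# Output:
#     key_sorted: keys in ascending order of value
#     value_sorted: values in ascending order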
def sort_key_by_vale(key, value):
    assert key.shape == value.shape
    assert len(key.shape) == 1
    sorted_idx = np.argsort(value)
    key_sorted = key[sorted_idx]
    value_sorted = value[sorted_idx]
    return key_sorted, value_sorted
 
 
def axis_round_robin(axis, dim):
    if axis == dim-1:
        return 0
    else:
        return axis + 1
 
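# Purpose: build the tree recursively
# Input:
#     root: root node of the (sub)tree
#     db: point cloud data
#     point_indices: indices of the points in this region
#     axis: scalar, dimension to split along
#     leaf_size: scalar, maximum number of points in a leaf (a leaf may hold more than 1 point)
# Output:
#     root: the built (sub)tree
def kdtree_recursive_build(root, db, point_indices, axis, leaf_size):
    # create the node for this region; it stays a leaf (value None) unless it is split below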
    if root is None:
        root = Node(axis, None, None, None, point_indices)
 
    # determine whether to split into left and right
    if len(point_indices) > leaf_size:  # split only if the region holds more than leaf_size points
        # point indices sorted by their coordinate along the current axis
        point_indices_sorted, _ = sort_key_by_vale(point_indices, db[point_indices, axis])  # M

        # Assignment 1
        # Begin solution
        # position of the last point of the left half (ceil rounds up)
        middle_left_idx = math.ceil(point_indices_sorted.shape[0] / 2) - 1
        middle_left_point_idx = point_indices_sorted[middle_left_idx]    # its index in db
        middle_left_point_value = db[middle_left_point_idx, axis]        # its coordinate along axis
        # first point of the right half
        middle_right_idx = middle_left_idx + 1
        middle_right_point_idx = point_indices_sorted[middle_right_idx]
        middle_right_point_value = db[middle_right_point_idx, axis]

        # split halfway between the two middle points
        root.value = (middle_left_point_value + middle_right_point_value) * 0.5
        # recurse on both halves with the next axis; a half is split again if it still holds more than leaf_size points
        root.left = kdtree_recursive_build(root.left, db, point_indices_sorted[0:middle_right_idx],
                                           axis_round_robin(axis, dim=db.shape[1]), leaf_size)

        root.right = kdtree_recursive_build(root.right, db, point_indices_sorted[middle_right_idx:],
                                            axis_round_robin(axis, dim=db.shape[1]), leaf_size)
        # End solution
    return root
 
 
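# Purpose: traverse a KD-tree depth-first and print its leaves
# Input:
#     root: KD-tree
#     depth: current depth (a one-element list, so it can be updated in place)
#     max_depth: maximum depth reached so far (same convention)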
def traverse_kdtree(root: Node, depth, max_depth):
    depth[0] += 1
    if max_depth[0] < depth[0]:
        max_depth[0] = depth[0]
 
    if root.is_leaf():
        print(root)
    else:
        traverse_kdtree(root.left, depth, max_depth)
        traverse_kdtree(root.right, depth, max_depth)
 
    depth[0] -= 1
 
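# Purpose: build a KD-tree (the public entry point that calls kdtree_recursive_build)
# Input:
#     db_np: raw data, an N x dim array
#     leaf_size: scalar
# Output:
#     root: the built KD-tree
def kdtree_construction(db_np, leaf_size):
    N, dim = db_np.shape[0], db_np.shape[1]  # number of points and dimension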
 
    # build kd_tree recursively
    root = None
    root = kdtree_recursive_build(root,
                                  db_np,
                                  np.arange(N),
                                  axis=0,
                                  leaf_size=leaf_size)
    return root
 
 
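# Purpose: KNN search with a KD-tree, i.e. find the k nearest neighbors
# Input:
#     root: KD-tree
#     db: raw data
#     result_set: collects the search results
#     query: the query point
# Output:
#     always returns False; the neighbors are accumulated in result_set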
def kdtree_knn_search(root: Node, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    if root is None:
        return False
 
    # # leaf_points.shape: (124668, 3)
    # leaf_points = db[root.point_indices, :]
    # print("leaf_points.shape:",leaf_points.shape)
    # # query is a 1-D array, query.shape: (3,)
    # print("query.shape:", query.shape)
    # # add a leading axis, giving query.shape: (1, 3)
    # print("query.shape:",np.expand_dims(query, 0).shape)

    # at a leaf, compare the query with every point stored in it
    if root.is_leaf():
        # compare the contents of a leaf
        leaf_points = db[root.point_indices, :]
        # print("query:",query)
        # print("leaf_points:", leaf_points)
        # axis=1: compute the norm of each row (one distance per point); the default is the L2 norm
        diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for i in range(diff.shape[0]):
            result_set.add_point(diff[i], root.point_indices[i])
        return False
 
    # Assignment 2
    # Hint: implement the search recursively as well
    # Begin solution
    # Always search the side that contains the query first; the other side only needs to be
    # searched if the splitting plane is closer than the current worst distance
    if query[root.axis] <= root.value:
        kdtree_knn_search(root.left, db, result_set, query)
        if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
            kdtree_knn_search(root.right, db, result_set, query)
    else:
        kdtree_knn_search(root.right, db, result_set, query)
        if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
            kdtree_knn_search(root.left, db, result_set, query)
    # End solution
 
    return False
 
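# Purpose: radius search with a KD-tree, i.e. find all neighbors within radius
# Input:
#     root: KD-tree
#     db: raw data
#     result_set: collects the search results
#     query: the query point
# Output:
#     always returns False; the neighbors are accumulated in result_set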
def kdtree_radius_search(root: Node, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    if root is None:
        return False
 
    if root.is_leaf():
        # compare the contents of a leaf
        leaf_points = db[root.point_indices, :]
        diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for i in range(diff.shape[0]):
            result_set.add_point(diff[i], root.point_indices[i])
        return False
    
    # Assignment 3
    # Hint: implement the search recursively
    # Begin solution
    # same traversal as the KNN search; the worst distance is the fixed search radius
    if query[root.axis] <= root.value:
        kdtree_radius_search(root.left, db, result_set, query)
        if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
            kdtree_radius_search(root.right, db, result_set, query)
    else:
        kdtree_radius_search(root.right, db, result_set, query)
        if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
            kdtree_radius_search(root.left, db, result_set, query)

    # End solution
 
    return False
 
 
 
def main():
    # configuration
    db_size = 64
    dim = 3
    leaf_size = 4  # regions with at most 4 points are not split further
    k = 1

    db_np = np.random.rand(db_size, dim)  # 64 random 3D points
    #print(db_np)
    # build the KD-tree
    root = kdtree_construction(db_np, leaf_size=leaf_size)
 
    depth = [0]
    max_depth = [0]
    traverse_kdtree(root, depth, max_depth)
    print("tree max depth: %d" % max_depth[0])
 
if __name__ == '__main__':
    main()
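The code above imports `KNNResultSet` and `RadiusNNResultSet` from a `result_set` module that the post does not show. The sketch below is an assumption reconstructed only from how the code calls them (`KNNResultSet(capacity=k)`, `RadiusNNResultSet(radius=r)`, `add_point(dist, index)`, `worstDist()`); the course's actual module may differ. For KNN the worst distance is the distance to the current k-th neighbor (infinite until k points are found); for radius search it stays fixed at the radius.

```python
import bisect
import math


class KNNResultSet:
    """Hypothetical stand-in: keeps the k closest (distance, index) pairs, sorted by distance."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.items = []                          # sorted list of (dist, index)

    def worstDist(self) -> float:
        # until k points have been seen, nothing can be pruned
        if len(self.items) < self.capacity:
            return math.inf
        return self.items[-1][0]

    def add_point(self, dist: float, index: int) -> None:
        if dist >= self.worstDist():
            return
        bisect.insort(self.items, (float(dist), int(index)))
        if len(self.items) > self.capacity:
            self.items.pop()                     # drop the farthest candidate

    def __str__(self):
        return "\n".join("%d - %.4f" % (i, d) for d, i in self.items)


class RadiusNNResultSet:
    """Hypothetical stand-in: collects every point within a fixed radius."""

    def __init__(self, radius: float):
        self.radius = radius
        self.items = []                          # unsorted list of (dist, index)

    def worstDist(self) -> float:
        return self.radius                       # never changes during the search

    def add_point(self, dist: float, index: int) -> None:
        if dist <= self.radius:
            self.items.append((float(dist), int(index)))

    def __str__(self):
        return "\n".join("%d - %.4f" % (i, d) for d, i in sorted(self.items))
```

Saved as `result_set.py` next to the tree code, a module like this would let `kdtree_knn_search(root, db_np, KNNResultSet(capacity=k), query)` run, under the assumption that the interface above matches the original.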

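1.4 Octree

A data structure designed specifically for 3D search. Unlike the KD-tree, which splits along one axis per level, it uses all three dimensions at once in every split, so each node has up to eight children.

1.4.1 Steps for building an octree

For simplicity, the steps are illustrated with the 2D analogue, the quadtree.

 All dimensions are split uniformly at the same time, and a node keeps being split as long as it satisfies neither of the stopping constraints.

 Constraints: 1. minimum number of points per node, leaf_size = 1    2. minimum cell size, min_extent (in the code below it is compared against the half side length)

If there are many coincident points and leaf_size is the only constraint (or it is set badly), the splitting never terminates: a cell holding several identical points can never be reduced to leaf_size points. The min_extent constraint prevents this infinite recursion.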
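The quadtree version of these steps can be sketched as follows. This is an illustration, not the course code: `build_quadtree`, `max_depth` and the dict-based nodes are hypothetical, and `extent` is half the side length, as in the octree code below. The example data contains two identical points, so only the `min_extent` check stops the recursion.

```python
import numpy as np


def build_quadtree(pts: np.ndarray, idx: list, center: np.ndarray, extent: float,
                   leaf_size: int = 1, min_extent: float = 1e-3) -> dict:
    """Split a square cell into 4 equal children until a stopping constraint is met.

    Stops when the cell holds <= leaf_size points OR its half side length is <= min_extent.
    """
    node = {"center": center, "extent": extent, "idx": idx, "children": None}
    if len(idx) <= leaf_size or extent <= min_extent:
        return node                                   # leaf
    # child code: bit 0 set if x > center x, bit 1 set if y > center y
    buckets = [[] for _ in range(4)]
    for i in idx:
        code = int(pts[i, 0] > center[0]) | (int(pts[i, 1] > center[1]) << 1)
        buckets[code].append(i)
    node["children"] = []
    for code, child_idx in enumerate(buckets):
        if not child_idx:
            node["children"].append(None)             # empty cells are not created
            continue
        # each child center is offset by -extent/2 or +extent/2 per axis, depending on its bit
        offset = np.array([0.5 if code & 1 else -0.5, 0.5 if code & 2 else -0.5]) * extent
        node["children"].append(build_quadtree(pts, child_idx, center + offset,
                                               extent * 0.5, leaf_size, min_extent))
    return node


def max_depth(node) -> int:
    """Depth of the deepest leaf (the root counts as depth 1)."""
    if node is None:
        return 0
    if node["children"] is None:
        return 1
    return 1 + max(max_depth(c) for c in node["children"])


if __name__ == "__main__":
    # the last two points coincide, so leaf_size = 1 alone could never be satisfied
    pts = np.array([[0.1, 0.1], [0.9, 0.9], [0.3, 0.3], [0.3, 0.3]])
    root = build_quadtree(pts, list(range(len(pts))), np.array([0.5, 0.5]), 0.5,
                          leaf_size=1, min_extent=0.01)
    print("max depth:", max_depth(root))
```

With `min_extent=0` the duplicate pair would keep being split; in practice Python would most likely stop with a `RecursionError` long before anything useful happened.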

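1.4.2 KNN search in an octree

Only the children inside node S2 need to be searched, and the search can terminate early.

In the figure, the red point is the query. The search first finds the closest node, S2, and queries its children; point e in S8 sets the worst distance. It then queries the node within the worst distance, S5, searches S5's children, and finds point a in S12, which is used to update the worst distance.

  • The overlap function

 Not every child is queried: a child is searched only if its region overlaps the ball centered at the query whose radius is the worst distance. In other words, we test whether a ball and a cube intersect. Three cases are distinguished, using the per-axis offset between the ball center and the cube center:

1. Too far: if on any axis the offset is larger than the radius plus half the side length, there is no overlap (in the figure: red + blue < green).

2. Touching a face: if case 1 does not apply and the ball center lies within the cube's range on at least two axes, the ball touches a face of the cube (or its center is inside the cube), so they overlap.

 3. Touching an edge or corner: otherwise the ball can only reach an edge or a corner; they overlap if the distance from the ball center to that edge or corner, computed from max(offset - half side length, 0) on each axis, is smaller than the radius.

  •  Contains: if the distance from the query to the octant's farthest corner is smaller than the radius, the ball contains the whole octant, so all of its points are results and its children need not be checked.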
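The three-case overlap test can be cross-checked against a simpler formulation: a ball overlaps an axis-aligned cube exactly when the distance from the ball center to the closest point of the cube is less than the radius (up to ties exactly on the boundary). The sketch assumes the cube is given by its center and half side length `extent`, as in the `Octant` class; `overlaps_three_cases` mirrors the logic of `overlaps()` in the code below, while `overlaps_clamp` and `ball_contains_cube` are hypothetical helpers written for this comparison.

```python
import numpy as np


def overlaps_three_cases(query, radius, center, extent) -> bool:
    """Same three cases as overlaps(): too far, touching a face, touching an edge/corner."""
    offset = np.abs(query - center)
    # case 1: on some axis the ball is farther than radius + half side length
    if np.any(offset > radius + extent):
        return False
    # case 2: the center is within the cube's range on >= 2 axes, so the ball reaches a face
    if np.sum(offset < extent) >= 2:
        return True
    # case 3: only an edge or a corner can be reached; measure the distance to it
    gap = np.maximum(offset - extent, 0.0)
    return float(np.dot(gap, gap)) < radius * radius


def overlaps_clamp(query, radius, center, extent) -> bool:
    """Reference check: distance from query to the closest point of the cube < radius."""
    closest = np.clip(query, center - extent, center + extent)
    return float(np.linalg.norm(query - closest)) < radius


def ball_contains_cube(query, radius, center, extent) -> bool:
    """The ball contains the cube if the cube's farthest corner is within the radius."""
    farthest = np.abs(query - center) + extent
    return float(np.linalg.norm(farthest)) < radius


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q, c = rng.uniform(-2, 2, size=3), rng.uniform(-1, 1, size=3)
    print(overlaps_three_cases(q, 1.0, c, 0.5), overlaps_clamp(q, 1.0, c, 0.5))
```

The three-case version avoids a square root in the common cases, which is presumably why the course code is written that way.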
# octree implementation, including construction and search
 
import random
import math
import numpy as np
import time
 
from result_set import KNNResultSet, RadiusNNResultSet
 
# Octant: the basic building block of the octree
class Octant:
    # children: array of length 8
    # center: center of the cube
    # extent: half of the side length
    def __init__(self, children, center, extent, point_indices, is_leaf):
        self.children = children
        self.center = center
        self.extent = extent
        self.point_indices = point_indices
        self.is_leaf = is_leaf
 
    def __str__(self):
        output = ''
        output += 'center: [%.2f, %.2f, %.2f], ' % (self.center[0], self.center[1], self.center[2])
        output += 'extent: %.2f, ' % self.extent
        output += 'is_leaf: %d, ' % self.is_leaf
        output += 'children: ' + str([x is not None for x in self.children]) + ", "
        output += 'point_indices: ' + str(self.point_indices)
        return output
 
# Purpose: traverse the octree and print its leaves
# Input:
#     root: the built octree
#     depth: current depth
#     max_depth: maximum depth
def traverse_octree(root: Octant, depth, max_depth):
    depth[0] += 1
    if max_depth[0] < depth[0]:
        max_depth[0] = depth[0]
 
    if root is None:
        pass
    elif root.is_leaf:
        print(root)
    else:
        for child in root.children:
            traverse_octree(child, depth, max_depth)
    depth[0] -= 1
 
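# Purpose: build the octree recursively
# Input:
#     root: root node
#     db: raw data
#     center: center of the current cube
#     extent: half side length of the current cube
#     point_indices: indices of the points in the current cube
#     leaf_size: scalar, maximum number of points in a leaf
#     min_extent: minimum extent; cubes this small are not split further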
def octree_recursive_build(root, db, center, extent, point_indices, leaf_size, min_extent):
    if len(point_indices) == 0:
        return None
 
    if root is None:
        root = Octant([None for i in range(8)], center, extent, point_indices, is_leaf=True)
 
    # determine whether to split this octant
    # stop splitting if the octant holds at most leaf_size points or has reached min_extent
    if len(point_indices) <= leaf_size or extent <= min_extent:
        root.is_leaf = True
    else:
        # Assignment 4
        # Begin solution
        root.is_leaf = False
        children_point_indices = [[] for i in range(8)]
        # assign each point to a child by its Morton code:
        # bit 0 is set if x > center x, bit 1 if y > center y, bit 2 if z > center z
        for point_idx in point_indices:
            point_db = db[point_idx]
            morton_code = 0
            if point_db[0] > center[0]:
                morton_code = morton_code | 1
            if point_db[1] > center[1]:
                morton_code = morton_code | 2
            if point_db[2] > center[2]:
                morton_code = morton_code | 4
            children_point_indices[morton_code].append(point_idx)

        # each child center lies half an extent below (-0.5) or above (+0.5) the parent center on each axis
        factor = [-0.5, 0.5]
        for i in range(8):
            child_center_x = center[0] + factor[(i & 1) > 0] * extent
            child_center_y = center[1] + factor[(i & 2) > 0] * extent
            child_center_z = center[2] + factor[(i & 4) > 0] * extent
            child_extent = 0.5 * extent
            child_center = np.asarray([child_center_x, child_center_y, child_center_z])
            root.children[i] = octree_recursive_build(root.children[i],
                                                      db,
                                                      child_center,
                                                      child_extent,
                                                      children_point_indices[i],
                                                      leaf_size,
                                                      min_extent)
        # End solution
    return root
 
# Purpose: check whether the query ball lies entirely inside the octant
# Input:
#     query: the query point (ball center)
#     radius: ball radius
#     octant: the octant to test
# Output:
#     True/False
def inside(query: np.ndarray, radius: float, octant:Octant):
    """
    Determines if the query ball is inside the octant
    :param query:
    :param radius:
    :param octant:
    :return:
    """
    query_offset = query - octant.center
    query_offset_abs = np.fabs(query_offset)
    possible_space = query_offset_abs + radius
    return np.all(possible_space < octant.extent)
 
# Purpose: check whether the query ball overlaps the octant
# Input:
#     query: the query point (ball center)
#     radius: ball radius
#     octant: the octant to test
# Output:
#     True/False
def overlaps(query: np.ndarray, radius: float, octant:Octant):
    """
    Determines if the query ball overlaps with the octant
    :param query:
    :param radius:
    :param octant:
    :return:
    """
    query_offset = query - octant.center
    # per-axis distance between the ball center and the octant center
    query_offset_abs = np.fabs(query_offset)

    # completely outside, since query is outside the relevant area
    # radius + half side length
    max_dist = radius + octant.extent
    if np.any(query_offset_abs > max_dist):
        return False

    # if the above check passes, consider the case that the ball is contacting a face of the octant
    if np.sum(query_offset_abs < octant.extent) >= 2:
        return True
 
    # conside the case that the ball is contacting the edge or corner of the octant
    # since the case of the ball center (query) inside octant has been considered,
    # we only consider the ball center (query) outside octant
    x_diff = max(query_offset_abs[0] - octant.extent, 0)
    y_diff = max(query_offset_abs[1] - octant.extent, 0)
    z_diff = max(query_offset_abs[2] - octant.extent, 0)
 
    return x_diff * x_diff + y_diff * y_diff + z_diff * z_diff < radius * radius
 
 
# Purpose: check whether the query ball contains the whole octant
# Input:
#     query: the query point (ball center)
#     radius: ball radius
#     octant: the octant to test
# Output:
#     True/False
def contains(query: np.ndarray, radius: float, octant:Octant):
    """
    Determine if the query ball contains the octant
    :param query:
    :param radius:
    :param octant:
    :return:
    """
    query_offset = query - octant.center
    query_offset_abs = np.fabs(query_offset)
 
    query_offset_to_farthest_corner = query_offset_abs + octant.extent
    return np.linalg.norm(query_offset_to_farthest_corner) < radius
 
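# Purpose: radius search in the octree, fast version that also uses contains()
# Input:
#    root: octree
#    db: raw data
#    result_set: collects the search results
#    query: the query point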
def octree_radius_search_fast(root: Octant, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    if root is None:
        return False
 
    # Assignment 5
    # Hint: make use of the inside, overlaps and contains functions above
    # Begin solution
    # if the query ball contains the whole octant, every point in it is a result
    if contains(query, result_set.worstDist(), root):
        # compare the contents of the octant
        leaf_points = db[root.point_indices, :]
        diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for i in range(diff.shape[0]):
            result_set.add_point(diff[i], root.point_indices[i])
        # don't need to check any child
        return False

    if root.is_leaf and len(root.point_indices) > 0:
        # compare the contents of a leaf
        leaf_points = db[root.point_indices, :]
        diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for i in range(diff.shape[0]):
            result_set.add_point(diff[i], root.point_indices[i])
        # check whether we can stop search now
        return inside(query, result_set.worstDist(), root)

    # no need to go to the most relevant child first, because we will go through all children anyway
    for c, child in enumerate(root.children):
        if child is None:
            continue
        if not overlaps(query, result_set.worstDist(), child):
            continue
        if octree_radius_search_fast(child, db, result_set, query):
            return True

    # End solution
 
    return inside(query, result_set.worstDist(), root)
 
 
# Purpose: find the neighbors within radius in the octree
# Input:
#     root: octree
#     db: raw data
#     result_set: collects the search results
#     query: the query point
def octree_radius_search(root: Octant, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    if root is None:
        return False
 
    if root.is_leaf and len(root.point_indices) > 0:
        # compare the contents of a leaf
        leaf_points = db[root.point_indices, :]
        diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for i in range(diff.shape[0]):
            result_set.add_point(diff[i], root.point_indices[i])
        # check whether we can stop search now
        return inside(query, result_set.worstDist(), root)
 
    # Assignment 6
    # Begin solution
    # Morton code of the child octant that contains the query point
    morton_code = 0
    if query[0] > root.center[0]:
        morton_code = morton_code | 1
    if query[1] > root.center[1]:
        morton_code = morton_code | 2
    if query[2] > root.center[2]:
        morton_code = morton_code | 4

    # search the most relevant child first
    if octree_radius_search(root.children[morton_code], db, result_set, query):
        return True

    # then search the other children
    # c: child index, child: the child octant
    for c, child in enumerate(root.children):
        if c == morton_code or child is None:
            continue
        if not overlaps(query, result_set.worstDist(), child):
            continue

        if octree_radius_search(child, db, result_set, query):
            return True
    # End solution
 
    # final check of if we can stop search
    return inside(query, result_set.worstDist(), root)
 
# Purpose: find the k nearest neighbors in the octree
# Input:
#     root: octree
#     db: raw data
#     result_set: collects the search results
#     query: the query point
def octree_knn_search(root: Octant, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    if root is None:
        return False
 
    if root.is_leaf and len(root.point_indices) > 0:
        # compare the contents of a leaf
        leaf_points = db[root.point_indices, :]
        diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for i in range(diff.shape[0]):
            result_set.add_point(diff[i], root.point_indices[i])
        # check whether we can stop search now
        return inside(query, result_set.worstDist(), root)
 
    # Assignment 7
    # Begin solution
    # Morton code of the child octant that contains the query point
    morton_code = 0
    if query[0] > root.center[0]:
        morton_code = morton_code | 1
    if query[1] > root.center[1]:
        morton_code = morton_code | 2
    if query[2] > root.center[2]:
        morton_code = morton_code | 4

    # search the most relevant child first, to shrink the worst distance early
    if octree_knn_search(root.children[morton_code], db, result_set, query):
        return True

    # then search the other children
    # c: child index, child: the child octant
    for c, child in enumerate(root.children):
        if c == morton_code or child is None:
            continue
        if not overlaps(query, result_set.worstDist(), child):
            continue

        if octree_knn_search(child, db, result_set, query):
            return True
    # End solution
 
    # final check of if we can stop search
    return inside(query, result_set.worstDist(), root)
 
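# Purpose: build the octree (the public entry point that calls octree_recursive_build)
# Input:
#    db_np: raw data
#    leaf_size: scalar, maximum number of points in a leaf
#    min_extent: minimum half side length; smaller cubes are not split further
def octree_construction(db_np, leaf_size, min_extent):
    N, dim = db_np.shape[0], db_np.shape[1]
    db_np_min = np.amin(db_np, axis=0)
    db_np_max = np.amax(db_np, axis=0)
    db_extent = np.max(db_np_max - db_np_min) * 0.5
    # center the root cube on the bounding box (not on the mean) so that it encloses every point
    db_center = (db_np_min + db_np_max) * 0.5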
 
    root = None
    root = octree_recursive_build(root, db_np, db_center, db_extent, list(range(N)),
                                  leaf_size, min_extent)
 
    return root
 
def main():
    # configuration
    db_size = 64000
    dim = 3
    leaf_size = 4
    min_extent = 0.0001
    k = 8
 
    db_np = np.random.rand(db_size, dim)
 
    root = octree_construction(db_np, leaf_size, min_extent)
 
    # depth = [0]
    # max_depth = [0]
    # traverse_octree(root, depth, max_depth)
    # print("tree max depth: %d" % max_depth[0])
 
    # query = np.asarray([0, 0, 0])
    # result_set = KNNResultSet(capacity=k)
    # octree_knn_search(root, db_np, result_set, query)
    # print(result_set)
    #
    # diff = np.linalg.norm(np.expand_dims(query, 0) - db_np, axis=1)
    # nn_idx = np.argsort(diff)
    # nn_dist = diff[nn_idx]
    # print(nn_idx[0:k])
    # print(nn_dist[0:k])
 
    begin_t = time.time()
    print("Radius search normal:")
    for i in range(100):
        query = np.random.rand(3)
        result_set = RadiusNNResultSet(radius=0.5)
        octree_radius_search(root, db_np, result_set, query)
    # print(result_set)
    print("Search takes %.3fms\n" % ((time.time() - begin_t) * 1000))
 
    begin_t = time.time()
    print("Radius search fast:")
    for i in range(100):
        query = np.random.rand(3)
        result_set = RadiusNNResultSet(radius = 0.5)
        octree_radius_search_fast(root, db_np, result_set, query)
    # print(result_set)
    print("Search takes %.3fms\n" % ((time.time() - begin_t)*1000))
 
 
 
if __name__ == '__main__':
    main()

Author: 桦树无泪
