目录
1 Nearest Neighbor (NN) Problem
1 Nearest Neighbor (NN) Problem
BST——一维分割
KD-tree——任意维度分割
Octree——三维分割
1.1 方法介绍
点云最近邻问题的难点:1. 数据不规则 2. 数据是三维的 3. 数据量大
算法核心点:1. 分割空间 2. 跳过不需要搜索的空间 3. 提前停止搜索
1.2 BST(二叉树)
小的放左边,大的放右边
- 二叉树代码
- 复杂度取决于树的深度,因此树要尽量平衡
- 遍历方法
- K近邻搜索(KNN Search)
关键是维护最坏距离(worst distance),它在搜索过程中是动态变化的
- RNN(Radius NN,半径近邻搜索)
1.3 KD树
轮流在各个维度上应用二叉树的划分,就是KD树
1.3.1 给定二维点a到i,如何建立KD树
1、任选一个维度方向开始切
2、按维度轮流切,切到leaf_size=1,即区域中只有一个点。每一层需要排序找中间点进行切割。
1.3.2 加速KD树
如果不要求KD树完全平衡,可以加速构建:1. 只选取一部分点来找中值点。2. 用平均值代替中值。
1.3.3 KD树KNN
用 KD树来进行KNN搜索的核心思想是:
给定一个区域,用最坏距离的概念判断要不要对这个区域进行搜索,由此来节省搜索区域
判断需要查找某个区域的条件(满足其一即可):
1. 查询点本身在这个区域中(距离为0)
2. 查询点到这个区域边界的距离小于最坏距离
# kdtree的具体实现,包括构建和查找
import random
import math
import numpy as np
from result_set import KNNResultSet, RadiusNNResultSet
class Node:
    """A node of a k-d tree; nodes are the building blocks of the tree.

    Attributes:
        axis: index of the dimension this node splits on.
        value: split coordinate along ``axis``; ``None`` marks a leaf.
        left: left child node (``None`` when absent).
        right: right child node (``None`` when absent).
        point_indices: indices of every point in this node's region
            (stored before the region is split).
    """

    def __init__(self, axis, value, left, right, point_indices):
        self.axis = axis
        self.value = value
        self.left = left
        self.right = right
        self.point_indices = point_indices

    def is_leaf(self):
        """Return True when the node has no split value (i.e. it is a leaf)."""
        return self.value is None

    def __str__(self):
        parts = ['axis %d, ' % self.axis]
        if self.value is None:
            parts.append('split value: leaf, ')
        else:
            parts.append('split value: %.2f, ' % self.value)
        parts.append('point_indices: ')
        # point_indices is expected to support ``.tolist()`` (a numpy array).
        parts.append(str(self.point_indices.tolist()))
        return ''.join(parts)
def sort_key_by_vale(key, value):
    """Sort ``value`` ascending and permute ``key`` the same way.

    Used before building the tree so that point indices follow the order of
    their coordinates along the current axis.

    Args:
        key: 1-D numpy array of keys (point indices).
        value: 1-D numpy array of the same shape holding the sort values.

    Returns:
        (key_sorted, value_sorted): both arrays reordered by ``np.argsort(value)``.
    """
    assert key.shape == value.shape
    assert len(key.shape) == 1
    order = np.argsort(value)
    return key[order], value[order]
def axis_round_robin(axis, dim):
    """Return the axis to split on after ``axis``, wrapping from dim-1 back to 0."""
    return 0 if axis == dim - 1 else axis + 1
def kdtree_recursive_build(root, db, point_indices, axis, leaf_size):
    """Recursively build a k-d tree over ``db[point_indices]``.

    A region is split at its median along ``axis`` while it holds more than
    ``leaf_size`` points. Both children then split on the next axis
    (round robin).

    Args:
        root: existing node to fill in, or None to create a new one.
        db: 2-D point array, indexed as ``db[indices, axis]``.
        point_indices: numpy array of indices into ``db`` for this region.
        axis: dimension to split on at this level.
        leaf_size: maximum number of points a leaf may hold. Values below 1
            are treated as 1, because a single point cannot be split.

    Returns:
        The root node of the (sub)tree.
    """
    if root is None:
        root = Node(axis, None, None, None, point_indices)

    # Splitting a one-point region would index past the end of the sorted
    # array (there is no right median), so never split below one point.
    if len(point_indices) > max(leaf_size, 1):
        # Indices ordered by their coordinate along the current axis.
        point_indices_sorted, _ = sort_key_by_vale(point_indices, db[point_indices, axis])

        # The left half is the first ceil(n/2) sorted points; the right half is the rest.
        middle_left_idx = math.ceil(point_indices_sorted.shape[0] / 2) - 1
        middle_right_idx = middle_left_idx + 1
        middle_left_value = db[point_indices_sorted[middle_left_idx], axis]
        middle_right_value = db[point_indices_sorted[middle_right_idx], axis]
        # Place the splitting plane halfway between the two middle coordinates.
        root.value = (middle_left_value + middle_right_value) * 0.5

        next_axis = axis_round_robin(axis, dim=db.shape[1])
        root.left = kdtree_recursive_build(root.left, db,
                                           point_indices_sorted[0:middle_right_idx],
                                           next_axis, leaf_size)
        root.right = kdtree_recursive_build(root.right, db,
                                            point_indices_sorted[middle_right_idx:],
                                            next_axis, leaf_size)
    return root
def traverse_kdtree(root: Node, depth, max_depth):
    """Walk a k-d tree depth-first and print every leaf.

    Args:
        root: tree (sub)root; internal nodes are expected to have both children.
        depth: one-element list; ``depth[0]`` is the current level during the walk.
        max_depth: one-element list; ``max_depth[0]`` records the deepest level seen.
    """
    depth[0] += 1
    max_depth[0] = max(max_depth[0], depth[0])
    if root.is_leaf():
        print(root)
    else:
        for child in (root.left, root.right):
            traverse_kdtree(child, depth, max_depth)
    depth[0] -= 1
def kdtree_construction(db_np, leaf_size):
    """Build a k-d tree over every row of ``db_np`` and return its root.

    Public entry point that wraps ``kdtree_recursive_build``.

    Args:
        db_np: 2-D point array (one point per row).
        leaf_size: maximum number of points stored in a leaf.

    Returns:
        The root ``Node`` of the finished tree.
    """
    # Reading shape[1] also rejects inputs that are not at least 2-D.
    num_points, _dim = db_np.shape[0], db_np.shape[1]
    # Start with all points in one region and split along the first axis.
    return kdtree_recursive_build(None,
                                  db_np,
                                  np.arange(num_points),
                                  axis=0,
                                  leaf_size=leaf_size)
def kdtree_knn_search(root: Node, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    """Recursively collect the k nearest neighbours of ``query`` into ``result_set``.

    Args:
        root: k-d (sub)tree to search; None is accepted and ignored.
        db: 2-D point array the tree was built from.
        result_set: accumulator providing ``add_point(dist, index)`` and
            ``worstDist()`` (the current pruning distance).
        query: 1-D query point.

    Returns:
        Always False; results are accumulated in ``result_set``.
    """
    if root is None:
        return False

    if root.is_leaf():
        # Brute-force the leaf: Euclidean distance from the query to each point
        # (norm over axis=1, one value per row).
        leaf_points = db[root.point_indices, :]
        diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for i in range(diff.shape[0]):
            result_set.add_point(diff[i], root.point_indices[i])
        return False

    # The side that contains the query must always be searched. The opposite
    # side is searched only if the splitting plane is closer than the current
    # worst distance. (Previously the right child was searched twice and the
    # left child never, whenever the query fell on the right side.)
    if query[root.axis] <= root.value:
        near_child, far_child = root.left, root.right
    else:
        near_child, far_child = root.right, root.left
    kdtree_knn_search(near_child, db, result_set, query)
    if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
        kdtree_knn_search(far_child, db, result_set, query)
    return False
def kdtree_radius_search(root: Node, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    """Recursively collect the points near ``query`` into ``result_set``.

    Args:
        root: k-d (sub)tree to search; None is accepted and ignored.
        db: 2-D point array the tree was built from.
        result_set: accumulator providing ``add_point(dist, index)`` and
            ``worstDist()`` (presumably the fixed search radius; used for pruning).
        query: 1-D query point.

    Returns:
        Always False; results are accumulated in ``result_set``.
    """
    if root is None:
        return False

    if root.is_leaf():
        offsets = np.expand_dims(query, 0) - db[root.point_indices, :]
        distances = np.linalg.norm(offsets, axis=1)
        for dist, point_index in zip(distances, root.point_indices):
            result_set.add_point(dist, point_index)
        return False

    # Search the half containing the query first; cross the splitting plane
    # only when it is closer than the pruning distance.
    if query[root.axis] <= root.value:
        near_child, far_child = root.left, root.right
    else:
        near_child, far_child = root.right, root.left
    kdtree_radius_search(near_child, db, result_set, query)
    if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
        kdtree_radius_search(far_child, db, result_set, query)
    return False
def main():
    """Build a k-d tree over random 3-D points, then print its leaves and max depth."""
    # configuration
    db_size, dim = 64, 3
    leaf_size = 4  # regions with at most this many points are not split further

    db_np = np.random.rand(db_size, dim)  # db_size random points in [0, 1)^dim

    root = kdtree_construction(db_np, leaf_size=leaf_size)

    depth, max_depth = [0], [0]
    traverse_kdtree(root, depth, max_depth)
    print("tree max depth: %d" % max_depth[0])


# Run the k-d tree demo when this file is executed as a script.
if __name__ == '__main__':
    main()
1.4 八叉树
专门为三维搜索建立的数据结构,与KD树的区别是一次运用多维度信息进行切割。
1.4.1 建立八叉树步骤
以二维的四叉树来简单说明
所有维度同时均匀切割(三维时每次把一个立方体切成8个子立方体,二维时为4个);如果节点还不满足停止约束,就继续切割
停止约束:1. 节点内点数不超过leaf_size(例如leaf_size = 1) 2. 半边长extent不大于最小值min_extent
假设有很多重合点且数量超过leaf_size,仅靠leaf_size无法停止切割,会造成无限递归,因此还需要min_extent约束。
1.4.2 八叉树KNN查找
只需要搜索S2根节点内的子节点,可以提前终止搜索。
红色点为搜寻点,首先找到最近的根节点s2查询所有子节点,搜寻到s8中的点e确定最坏距离,再查询最坏距离内的节点s5,再搜索s5子节点,找到s12中的a,来判断更新最坏距离。
- overlap函数
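并不是所有子节点都会被查询,而是看节点区域与以最坏距离为半径的球是否有重叠(overlap),即判断球与立方体是否有交集。一共分三种情况判断:
1. 离得太远:在某一维度上 红+蓝<绿(即 半径+半边长 < 球心到立方体中心在该维度上的距离),没有交集
2. 在不太远的基础上,球与某个面接触:球心至少在两个维度上落在正方体范围内,有交集
3. 球与正方体的棱或角接触:球心到立方体最近点的距离小于半径,有交集
- 包含(contains):球完全包含立方体时,可以直接把立方体内所有点加入结果,无需再搜索其子节点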
# octree的具体实现,包括构建和查找
import random
import math
import numpy as np
import time
from result_set import KNNResultSet, RadiusNNResultSet
class Octant:
    """A node (cube) of an octree; octants are the building blocks of the tree.

    Attributes:
        children: list of length 8 with the child octants (``None`` where empty).
        center: center of the cube.
        extent: half of the cube's edge length.
        point_indices: indices of the points belonging to this cube.
        is_leaf: True when the octant is not subdivided.
    """

    def __init__(self, children, center, extent, point_indices, is_leaf):
        self.children = children
        self.center = center
        self.extent = extent
        self.point_indices = point_indices
        self.is_leaf = is_leaf

    def __str__(self):
        occupied = [child is not None for child in self.children]
        fields = [
            'center: [%.2f, %.2f, %.2f], ' % (self.center[0], self.center[1], self.center[2]),
            'extent: %.2f, ' % self.extent,
            'is_leaf: %d, ' % self.is_leaf,
            'children: ' + str(occupied) + ", ",
            'point_indices: ' + str(self.point_indices),
        ]
        return ''.join(fields)
def traverse_octree(root: Octant, depth, max_depth):
    """Walk an octree depth-first and print every leaf octant.

    Args:
        root: octant to visit; ``None`` (an empty child) is allowed.
        depth: one-element list; ``depth[0]`` is the current level during the walk.
        max_depth: one-element list; ``max_depth[0]`` records the deepest level seen.
            Empty (None) children also count as one level deeper.
    """
    depth[0] += 1
    max_depth[0] = max(max_depth[0], depth[0])
    if root is not None:
        if root.is_leaf:
            print(root)
        else:
            for child in root.children:
                traverse_octree(child, depth, max_depth)
    depth[0] -= 1
def octree_recursive_build(root, db, center, extent, point_indices, leaf_size, min_extent):
    """Recursively build an octree over ``db[point_indices]``.

    Args:
        root: existing octant to fill in, or None to create a new one.
        db: point array; only the first three coordinates of each row are used.
        center: center of this octant's cube.
        extent: half of the cube's edge length.
        point_indices: indices into ``db`` of the points in this octant.
        leaf_size: an octant with at most this many points becomes a leaf.
        min_extent: an octant whose extent is at most this value becomes a
            leaf. This bounds the recursion when many points coincide.

    Returns:
        The built octant, or None if ``point_indices`` is empty.
    """
    if len(point_indices) == 0:
        return None
    if root is None:
        root = Octant([None for _ in range(8)], center, extent, point_indices, is_leaf=True)

    # Stop splitting when few enough points remain or the cube is too small.
    if len(point_indices) <= leaf_size or extent <= min_extent:
        root.is_leaf = True
        return root

    root.is_leaf = False
    # Bucket the points by child octant. Bits 0/1/2 of the code are set when
    # the point lies above the center along x/y/z respectively.
    children_point_indices = [[] for _ in range(8)]
    for point_idx in point_indices:
        point = db[point_idx]
        morton_code = 0
        if point[0] > center[0]:
            morton_code |= 1
        if point[1] > center[1]:
            morton_code |= 2
        if point[2] > center[2]:
            morton_code |= 4
        children_point_indices[morton_code].append(point_idx)

    # A child's center is half an extent below (bit clear) or above (bit set)
    # the parent center on each axis. The old table [0.5, 0.5] shifted every
    # child to the positive side, so child cubes no longer matched their points.
    factor = [-0.5, 0.5]
    child_extent = 0.5 * extent
    for i in range(8):
        child_center = np.asarray([
            center[0] + factor[(i & 1) > 0] * extent,
            center[1] + factor[(i & 2) > 0] * extent,
            center[2] + factor[(i & 4) > 0] * extent,
        ])
        root.children[i] = octree_recursive_build(root.children[i],
                                                  db,
                                                  child_center,
                                                  child_extent,
                                                  children_point_indices[i],
                                                  leaf_size,
                                                  min_extent)
    return root
def inside(query: np.ndarray, radius: float, octant:Octant):
    """Return whether the ball (query, radius) lies strictly inside ``octant``.

    On every axis, the query's distance from the octant center plus the
    radius must be smaller than the octant's half edge length (``extent``).

    :param query: query point (ball center)
    :param radius: ball radius
    :param octant: octant whose cube is tested
    :return: True if the whole ball fits in the cube
    """
    reach = np.fabs(query - octant.center) + radius
    return np.all(reach < octant.extent)
def overlaps(query: np.ndarray, radius: float, octant: "Octant"):
    """Return whether the ball (query, radius) intersects the cube of ``octant``.

    Three cases are checked in order:
      1. On some axis the query is farther than ``radius + extent`` from the
         octant center: the ball cannot touch the cube.
      2. The query lies within the cube's slab on at least two axes: the
         nearest point of the cube differs from the query only along the
         remaining axis, and case 1 already bounds that gap, so they overlap.
      3. Otherwise the ball can only touch an edge or a corner: compare the
         distance from the query to the nearest point of the cube with radius.

    :param query: query point (ball center)
    :param radius: ball radius
    :param octant: object exposing ``center`` and ``extent`` (half edge length)
    :return: True if ball and cube intersect
    """
    query_offset_abs = np.fabs(query - octant.center)

    # Case 1: completely outside along at least one axis.
    if np.any(query_offset_abs > radius + octant.extent):
        return False

    # Case 2: the ball touches a face, or its center is inside the cube.
    # (``np.int`` was removed in NumPy 1.24, so count the axes directly.)
    if np.count_nonzero(query_offset_abs < octant.extent) >= 2:
        return True

    # Case 3: the ball touches an edge or corner. Per-axis gap between the
    # query and the cube (0 on axes where the query lies within the slab).
    gaps = np.maximum(query_offset_abs - octant.extent, 0)
    return np.sum(gaps * gaps) < radius * radius
def contains(query: np.ndarray, radius: float, octant:Octant):
    """Return whether the ball (query, radius) fully contains ``octant``'s cube.

    The cube corner farthest from the query is offset by
    ``|query - center| + extent`` on each axis. The cube is contained when
    that corner is strictly closer to the query than ``radius``.

    :param query: query point (ball center)
    :param radius: ball radius
    :param octant: octant whose cube is tested
    :return: True if the whole cube lies inside the ball
    """
    farthest_corner_offset = np.fabs(query - octant.center) + octant.extent
    return np.linalg.norm(farthest_corner_offset) < radius
def octree_radius_search_fast(root: Octant, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    """Radius search that also short-circuits octants fully covered by the ball.

    Args:
        root: octree (sub)root; None is accepted and ignored.
        db: point array the octree was built from.
        result_set: accumulator providing ``add_point(dist, index)`` and
            ``worstDist()`` (used as the ball radius).
        query: 1-D query point.

    Returns:
        True when the search may stop early because the query ball lies
        inside the current octant, otherwise False.
    """
    if root is None:
        return False

    def add_all_points(octant):
        # Brute-force the Euclidean distance from the query to every point of ``octant``.
        points = db[octant.point_indices, :]
        dists = np.linalg.norm(np.expand_dims(query, 0) - points, axis=1)
        for dist, point_index in zip(dists, octant.point_indices):
            result_set.add_point(dist, point_index)

    # The whole octant lies inside the ball: take every point and skip the children.
    if contains(query, result_set.worstDist(), root):
        add_all_points(root)
        return False

    if root.is_leaf and len(root.point_indices) > 0:
        add_all_points(root)
        # Stop if the ball cannot reach beyond this octant.
        return inside(query, result_set.worstDist(), root)

    # Every overlapping child must be visited anyway, so child order does not matter.
    for child in root.children:
        if child is None or not overlaps(query, result_set.worstDist(), child):
            continue
        if octree_radius_search_fast(child, db, result_set, query):
            return True

    return inside(query, result_set.worstDist(), root)
def octree_radius_search(root: Octant, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    """Recursively collect the points near ``query`` into ``result_set``.

    Args:
        root: octree (sub)root; None is accepted and ignored.
        db: point array the octree was built from.
        result_set: accumulator providing ``add_point(dist, index)`` and
            ``worstDist()`` (used as the ball radius for pruning).
        query: 1-D query point; its first three coordinates select the child octant.

    Returns:
        True when the search may stop early because the query ball lies
        inside the current octant, otherwise False.
    """
    if root is None:
        return False

    if root.is_leaf and len(root.point_indices) > 0:
        # Compare the query against every point of the leaf.
        leaf_points = db[root.point_indices, :]
        diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for i in range(diff.shape[0]):
            result_set.add_point(diff[i], root.point_indices[i])
        # Stop if the ball cannot reach beyond this octant.
        return inside(query, result_set.worstDist(), root)

    # Visit the child octant that contains the query first.
    morton_code = 0
    if query[0] > root.center[0]:
        morton_code |= 1
    if query[1] > root.center[1]:
        morton_code |= 2
    if query[2] > root.center[2]:
        morton_code |= 4
    # Recurse into radius search itself; this previously called
    # octree_knn_search by mistake.
    if octree_radius_search(root.children[morton_code], db, result_set, query):
        return True

    # Then visit the remaining children whose cubes intersect the query ball.
    for c, child in enumerate(root.children):
        if c == morton_code or child is None:
            continue
        if not overlaps(query, result_set.worstDist(), child):
            continue
        if octree_radius_search(child, db, result_set, query):
            return True

    # Final check: can the search stop here?
    return inside(query, result_set.worstDist(), root)
def octree_knn_search(root: Octant, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    """Recursively collect the k nearest neighbours of ``query`` into ``result_set``.

    Args:
        root: octree (sub)root; None is accepted and ignored.
        db: point array the octree was built from.
        result_set: accumulator providing ``add_point(dist, index)`` and
            ``worstDist()`` (the current pruning distance).
        query: 1-D query point; its first three coordinates select the child octant.

    Returns:
        True when the search may stop early because the ball of radius
        ``worstDist()`` around the query lies inside the current octant.
    """
    if root is None:
        return False

    if root.is_leaf and len(root.point_indices) > 0:
        offsets = np.expand_dims(query, 0) - db[root.point_indices, :]
        distances = np.linalg.norm(offsets, axis=1)
        for dist, point_index in zip(distances, root.point_indices):
            result_set.add_point(dist, point_index)
        return inside(query, result_set.worstDist(), root)

    # Child containing the query: the bits are disjoint, so summing equals OR-ing.
    # It is visited first, presumably to shrink the worst distance early.
    home = sum(bit for bit, axis in ((1, 0), (2, 1), (4, 2))
               if query[axis] > root.center[axis])
    if octree_knn_search(root.children[home], db, result_set, query):
        return True

    # Remaining children are visited only if they intersect the current ball.
    for idx, child in enumerate(root.children):
        if idx == home or child is None:
            continue
        if not overlaps(query, result_set.worstDist(), child):
            continue
        if octree_knn_search(child, db, result_set, query):
            return True

    return inside(query, result_set.worstDist(), root)
def octree_construction(db_np, leaf_size, min_extent):
    """Build an octree over every row of ``db_np`` and return its root.

    The root cube is centered on the midpoint of the axis-aligned bounding
    box. Its half edge (``extent``) is half of the largest bounding-box side,
    so every point lies inside the root cube. Centering on the mean instead
    can leave points of skewed data outside the cube, where the overlap-based
    pruning in the searches may then miss them.

    Args:
        db_np: 2-D point array (one point per row, at least 3 columns used).
        leaf_size: octants with at most this many points become leaves.
        min_extent: octants whose extent is at most this value become leaves.

    Returns:
        The root ``Octant``.
    """
    db_np_min = np.amin(db_np, axis=0)
    db_np_max = np.amax(db_np, axis=0)
    db_extent = np.max(db_np_max - db_np_min) * 0.5
    db_center = (db_np_min + db_np_max) * 0.5
    return octree_recursive_build(None, db_np, db_center, db_extent,
                                  list(range(db_np.shape[0])),
                                  leaf_size, min_extent)
def main():
    """Build an octree over random 3-D points and time the two radius searches."""
    # configuration
    db_size = 64000
    dim = 3
    leaf_size = 4
    min_extent = 0.0001
    num_queries = 100
    radius = 0.5

    db_np = np.random.rand(db_size, dim)
    root = octree_construction(db_np, leaf_size, min_extent)

    benchmarks = [
        ("Radius search normal:", octree_radius_search),
        ("Radius search fast:", octree_radius_search_fast),
    ]
    for title, search_fn in benchmarks:
        begin_t = time.time()
        print(title)
        for _ in range(num_queries):
            query = np.random.rand(3)
            result_set = RadiusNNResultSet(radius=radius)
            search_fn(root, db_np, result_set, query)
        print("Search takes %.3fms\n" % ((time.time() - begin_t) * 1000))


# Run the octree demo when this file is executed as a script.
if __name__ == '__main__':
    main()