kdtree
1. 建立树
递归建立:
def kdtree_recursive_build(root, db, point_indices, axis, leaf_size):
    """Recursively build a kd-tree over the points selected by ``point_indices``.

    Args:
        root: Existing node to fill in, or ``None`` to create a new ``Node``.
        db: 2-D array of points; ``db[i, axis]`` is the coordinate of point
            ``i`` along ``axis`` and ``db.shape[1]`` is the dimensionality.
        point_indices: Indices into ``db`` of the points owned by this node
            (assumed to be an index array usable as ``db[point_indices, axis]``).
        axis: Axis along which this node splits its points.
        leaf_size: A node holding more than this many points is split.

    Returns:
        The node for this subtree (``root`` itself when one was passed in).
    """
    if root is None:
        root = Node(axis, None, None, None, point_indices)

    # A single point can never be split into two non-empty halves, so never
    # split below two points even if leaf_size < 1 (that would recurse forever).
    if len(point_indices) > max(leaf_size, 1):
        # Order the indices by their coordinate on the split axis.
        # NOTE(review): sort_key_by_vale is assumed to return the keys sorted
        # by ascending value as its first result -- confirm against its definition.
        point_indices_sorted, _ = sort_key_by_vale(point_indices, db[point_indices, axis])

        # The left child receives [0, middle_idx), the right child the rest.
        # With at least two points, 1 <= middle_idx <= n - 1, so both are non-empty.
        middle_idx = math.ceil(len(point_indices) / 2.0)

        # The splitting plane lies halfway between the last point of the left
        # half and the first point of the right half. (Using middle_idx and
        # middle_idx + 1 would index past the end for two points and would
        # place the plane inside the right half.)
        last_left = db[point_indices_sorted[middle_idx - 1], axis]
        first_right = db[point_indices_sorted[middle_idx], axis]

        root.axis = axis
        root.value = (last_left + first_right) / 2.0
        root.point_indices = point_indices

        # Children split along the next axis, cycling through all dimensions.
        new_axis = (axis + 1) % db.shape[1]
        root.left = kdtree_recursive_build(
            None, db, point_indices_sorted[0:middle_idx], new_axis, leaf_size)
        root.right = kdtree_recursive_build(
            None, db, point_indices_sorted[middle_idx:], new_axis, leaf_size)
    return root
2. 查询
knn search 递归实现方式:
def kdtree_knn_search(root: Node, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    """Recursively collect nearest-neighbour candidates for ``query`` from a kd-tree.

    Every point of each visited leaf is offered to ``result_set`` through
    ``add_point(distance, index)``. The subtree on the query's side of the
    splitting plane is searched first; the other side is searched only if
    the plane is closer to the query than ``result_set.worst_dist``.

    Args:
        root: Current kd-tree node; ``None`` is an empty subtree.
        db: Point array indexed by the indices stored in the tree.
        result_set: Accumulator of the best candidates found so far.
        query: Query point.

    Returns:
        Always ``False``.
    """
    if root is None:
        return False

    if root.is_leaf():
        # Leaf: measure the Euclidean distance to every stored point.
        candidates = db[root.point_indices, :]
        distances = np.linalg.norm(query[np.newaxis] - candidates, axis=1)
        for distance, point_index in zip(distances, root.point_indices):
            result_set.add_point(distance, point_index)
        return False

    # Pick the child on the query's side of the splitting plane as "near".
    query_coord = query[root.axis]
    if query_coord < root.value:
        near_child, far_child = root.left, root.right
    else:
        near_child, far_child = root.right, root.left

    kdtree_knn_search(near_child, db, result_set, query)

    # worst_dist is read only after the near side has been searched, so the
    # far side is pruned against the tightest bound available.
    if np.abs(root.value - query_coord) < result_set.worst_dist:
        kdtree_knn_search(far_child, db, result_set, query)

    return False
radius search 的实现与上面的 knn search 相同。
Octree
1. 建立树
def octree_recursive_build(root, db, center, extent, point_indices, leaf_size, min_extent):
    """Recursively build an octree over the points selected by ``point_indices``.

    Args:
        root: Existing octant to fill in, or ``None`` to create a new ``Octant``.
        db: Point array; ``db[i]`` gives the x, y, z coordinates of point ``i``.
        center: Centre of this octant (indexable by 0, 1, 2).
        extent: Size of this octant; each child is built with half of it and a
            centre shifted by ``0.5 * extent`` per axis.
        point_indices: Indices into ``db`` of the points inside this octant.
        leaf_size: An octant with at most this many points is a leaf.
        min_extent: An octant whose extent is at most this value is a leaf.

    Returns:
        The octant for this region, or ``None`` when it contains no points.
    """
    if len(point_indices) == 0:
        return None

    if root is None:
        root = Octant([None] * 8, center, extent, point_indices, is_leaf=True)

    # Stop subdividing once the octant is sparse enough or small enough.
    if len(point_indices) <= leaf_size or extent <= min_extent:
        root.is_leaf = True
        return root

    root.is_leaf = False

    # Bucket each point by octant code: bit 0 is set when x is above the
    # centre, bit 1 for y, bit 2 for z.
    buckets = [[] for _ in range(8)]
    for point_idx in point_indices:
        point = db[point_idx]
        code = 0
        for dim in range(3):
            if point[dim] > center[dim]:
                code |= 1 << dim
        buckets[code].append(point_idx)

    # Build the eight children; a set bit moves the child centre in the
    # positive direction along that axis, a clear bit in the negative one.
    half_extent = extent * 0.5
    for code in range(8):
        child_center = np.array([
            center[dim] + (0.5 if code & (1 << dim) else -0.5) * extent
            for dim in range(3)
        ])
        root.children[code] = octree_recursive_build(
            root.children[code], db, child_center, half_extent,
            buckets[code], leaf_size, min_extent)

    return root
2. 查询
knn search 查询。为了快速地缩小搜索范围,需要从查询点所在的区域开始:
def octree_knn_search(root: Octant, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    """Recursively collect nearest-neighbour candidates for ``query`` from an octree.

    The child octant containing the query is searched first so that
    ``result_set.worst_dist`` shrinks early; the remaining children are
    visited only if ``overlaps`` reports that they can intersect the current
    search ball.

    Args:
        root: Current octant; ``None`` is an empty region.
        db: Point array indexed by the indices stored in the octants.
        result_set: Accumulator of the best candidates found so far.
        query: Query point (indexed by 0, 1, 2).

    Returns:
        ``True`` if the search can stop, as decided by ``inside`` (presumably:
        the search ball lies entirely within the octant just examined --
        confirm against ``inside``); ``False`` otherwise.
    """
    if root is None:
        return False

    if root.is_leaf and len(root.point_indices) > 0:
        # Leaf: compare the query against every stored point.
        leaf_points = db[root.point_indices, :]
        distances = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for distance, point_index in zip(distances, root.point_indices):
            result_set.add_point(distance, point_index)
        # Check whether the search can stop here.
        return inside(query, result_set.worst_dist, root)

    # Octant code of the child containing the query: bit 0 is set when x is
    # above the centre, bit 1 for y, bit 2 for z.
    morton_code = 0
    if query[0] > root.center[0]:
        morton_code |= 1
    if query[1] > root.center[1]:
        morton_code |= 2
    if query[2] > root.center[2]:
        morton_code |= 4

    # Search the most relevant child first.
    if octree_knn_search(root.children[morton_code], db, result_set, query):
        return True

    # Then the other children, skipping those the search ball cannot reach.
    for child_code, child in enumerate(root.children):
        if child_code == morton_code or child is None:
            continue
        if not overlaps(query, result_set.worst_dist, child):
            continue
        if octree_knn_search(child, db, result_set, query):
            return True

    # Final check of whether the search can stop at this level.
    return inside(query, result_set.worst_dist, root)
为了加速,radius search可以增加contain的判断:
def octree_radius_search_fast(root: Octant, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    """Recursively collect radius-search candidates for ``query`` from an octree.

    Besides the usual leaf comparison, an inner octant that ``contains``
    reports as covered by the search ball has all of its points compared
    at once, without descending into its children.

    Args:
        root: Current octant; ``None`` is an empty region.
        db: Point array indexed by the indices stored in the octants.
        result_set: Accumulator of the points found; ``result_set.worst_dist``
            is used as the search radius.
        query: Query point.

    Returns:
        ``True`` if the search can stop, as decided by ``inside`` (presumably:
        the search ball lies entirely within the octant just examined --
        confirm against ``inside``); ``False`` otherwise.
    """
    if root is None:
        return False

    is_nonempty_leaf = root.is_leaf and len(root.point_indices) > 0

    # Compare against all points of this octant when it is a leaf, or when
    # ``contains`` says the whole octant can be taken without descending.
    if is_nonempty_leaf or contains(query, result_set.worst_dist, root):
        octant_points = db[root.point_indices, :]
        distances = np.linalg.norm(np.expand_dims(query, 0) - octant_points, axis=1)
        for distance, point_index in zip(distances, root.point_indices):
            result_set.add_point(distance, point_index)
        if is_nonempty_leaf:
            # Check whether the search can stop here.
            return inside(query, result_set.worst_dist, root)
        # Fully handled inner octant: nothing to descend into, keep searching
        # the rest of the tree.
        return False

    # Visit only the children the search ball can reach.
    for child in root.children:
        if child is None:
            continue
        if not overlaps(query, result_set.worst_dist, child):
            continue
        if octree_radius_search_fast(child, db, result_set, query):
            return True

    # Final check of whether the search can stop at this level.
    return inside(query, result_set.worst_dist, root)
3. 实验结果
Octree 的建立需要花费较多时间,但查询比较快;相比之下,radius 查询会比 knn 查询更快一些。
完整的代码见链接