三维点云处理08-KDTree
result_set.py
import copy
class DistIndex:
    """A (distance, index) pair that orders by distance.

    Only ``<`` is defined, and it compares ``distance`` alone, which is
    enough for ``list.sort()`` / ``sorted()`` to arrange entries in
    ascending order of distance.
    """

    def __init__(self, distance, index):
        # distance: value the entry is ordered by
        # index: caller-supplied identifier stored alongside the distance
        self.distance = distance
        self.index = index

    def __lt__(self, other):
        """Return True when this entry's distance is strictly smaller."""
        return self.distance < other.distance

    def __repr__(self):
        """Return a debugging representation showing both fields."""
        return '%s(distance=%r, index=%r)' % (
            type(self).__name__, self.distance, self.index)
class KNNResultSet:
    """Fixed-capacity container holding the smallest distances added so far.

    ``dist_index_list`` always contains exactly ``capacity`` entries kept in
    ascending order of distance. Slots that have not been filled yet hold a
    placeholder with distance ``1e10`` and index ``0``.
    """

    def __init__(self, capacity):
        """Create an empty result set with room for *capacity* entries.

        Raises:
            ValueError: if *capacity* is smaller than 1 (there would be no
                slot for ``add_point`` to write to).
        """
        if capacity < 1:
            raise ValueError('capacity must be at least 1, got %r' % (capacity,))
        self.capacity = capacity
        self.count = 0
        # Stays at this sentinel until every slot holds a real entry.
        self.worst_dist = 1e10
        self.dist_index_list = [
            DistIndex(self.worst_dist, 0) for _ in range(self.capacity)
        ]
        self.comparison_counter = 0

    def size(self):
        """Return how many real (non-placeholder) entries are stored."""
        return self.count

    def full(self):
        """Return True once ``capacity`` real entries are stored."""
        return self.count == self.capacity

    def worstDist(self):
        """Return the distance stored in the last slot of the list."""
        return self.worst_dist

    def add_point(self, dist, index):
        """Offer a candidate; keep it if it is not farther than the worst.

        Every call increments ``comparison_counter``. A candidate whose
        distance exceeds ``worst_dist`` is dropped. Otherwise it is inserted
        at its sorted position, pushing the current last entry out when the
        set is already full.
        """
        self.comparison_counter += 1
        if dist > self.worst_dist:
            return
        if self.count < self.capacity:
            self.count += 1

        entries = self.dist_index_list
        slot = self.count - 1
        # The object in the starting slot is either an unused placeholder or
        # the entry being pushed out, so it can be reused for the new point.
        recycled = entries[slot]
        # Shift larger entries one position right (references only).
        while slot > 0 and entries[slot - 1].distance > dist:
            entries[slot] = entries[slot - 1]
            slot -= 1
        recycled.distance = dist
        recycled.index = index
        entries[slot] = recycled

        self.worst_dist = entries[self.capacity - 1].distance

    def __str__(self):
        """Return one ``index - distance`` line per slot plus a counter line."""
        parts = [
            '%d - %.2f\n' % (entry.index, entry.distance)
            for entry in self.dist_index_list
        ]
        parts.append('In total %d comparison operations.' % self.comparison_counter)
        return ''.join(parts)
class RadiusNNResultSet:
    """Unbounded container collecting every candidate within a fixed radius."""

    def __init__(self, radius):
        """Create an empty result set that accepts distances up to *radius*."""
        self.count = 0
        self.comparison_counter = 0
        # The acceptance threshold never changes after construction.
        self.worst_dist = radius
        self.dist_index_list = []

    def size(self):
        """Return how many candidates have been accepted."""
        return self.count

    def worstDist(self):
        """Return the radius given at construction."""
        return self.worst_dist

    def add_point(self, dist, index):
        """Offer a candidate; keep it when *dist* does not exceed the radius.

        Every call increments ``comparison_counter``; accepted candidates are
        appended in arrival order.
        """
        self.comparison_counter += 1
        if dist > self.worst_dist:
            return
        self.count += 1
        self.dist_index_list.append(DistIndex(dist, index))

    def __str__(self):
        """Return the accepted entries, nearest first, plus a summary line.

        Note: this sorts ``dist_index_list`` in place, so the stored list is
        in ascending order of distance after the call.
        """
        self.dist_index_list.sort(key=lambda entry: entry.distance)
        parts = [
            '%d - %.2f\n' % (entry.index, entry.distance)
            for entry in self.dist_index_list
        ]
        # %.2f so that a fractional radius such as 0.5 is not printed as "0".
        parts.append(
            'In total %d neighbors within %.2f \n There are %d comparison operations'
            % (self.count, self.worst_dist, self.comparison_counter)
        )
        return ''.join(parts)
kdtree.py
import random
import math
import numpy as np
from result_set import KNNResultSet,RadiusNNResultSet
class Node:
    """One node of the kd-tree."""

    def __init__(self, axis, value, left, right, point_indices):
        '''
        axis: the axis this node uses to partition space
        value: coordinate of the splitting hyperplane on that axis
               (None means the node is a leaf)
        left: left subtree, also a Node
        right: right subtree, also a Node
        point_indices: indices, in the original array, of the points
                       contained in this subtree
        '''
        self.axis = axis
        self.value = value
        self.left = left
        self.right = right
        self.point_indices = point_indices

    def is_leaf(self):
        """Return True when no split value has been assigned to this node."""
        return self.value is None

    def __str__(self):
        """Return a one-line summary: axis, split value (or 'leaf'), indices."""
        if self.value is None:
            split_text = 'split value:leaf,'
        else:
            split_text = 'split value:%.2f,' % self.value
        # point_indices must provide .tolist() (e.g. a numpy array).
        return ''.join([
            'axis %d,' % self.axis,
            split_text,
            'point_indices:',
            str(self.point_indices.tolist()),
        ])
def sort_key_by_value(key, value):
    '''
    Reorder key and value together so that value is ascending.

    key: 1-D array of point indices into the original array
    value: 1-D array, same shape as key, holding each point's coordinate
           on the current axis
    Returns (key_sorted, value_sorted).
    '''
    assert key.shape == value.shape, 'key and value must have the same shape'
    assert len(key.shape) == 1, 'key and value must be 1-D'
    order = np.argsort(value)
    return key[order], value[order]
def axis_round_robin(axis, dim):
    """Return the axis after *axis*, wrapping from ``dim - 1`` back to 0."""
    return 0 if axis == dim - 1 else axis + 1
def kdtree_recursive_build(root, db, point_indices, axis, leaf_size):
    '''
    Recursively build the subtree that covers point_indices.

    root: existing Node to fill in, or None to create a new one
    db: NxD array of points
    point_indices: M indices into db belonging to this subtree
    axis: scalar, axis this node splits on
    leaf_size: scalar, a subset with more points than this is split further
    Returns the (possibly newly created) root Node.
    '''
    if root is None:
        root = Node(axis, None, None, None, point_indices)

    # A split needs at least two points; clamping keeps a single-point subset
    # a leaf even when leaf_size < 1, where the median lookup below would
    # otherwise index past the end.
    if len(point_indices) > max(leaf_size, 1):
        point_indices_sorted, _ = sort_key_by_value(
            point_indices, db[point_indices, axis])

        # The two points straddling the median along this axis.
        middle_left_idx = math.ceil(point_indices_sorted.shape[0] / 2) - 1
        middle_right_idx = middle_left_idx + 1
        middle_left_point_value = db[point_indices_sorted[middle_left_idx], axis]
        middle_right_point_value = db[point_indices_sorted[middle_right_idx], axis]

        # Split halfway between them.
        root.value = (middle_left_point_value + middle_right_point_value) * 0.5

        next_axis = axis_round_robin(axis, dim=db.shape[1])
        root.left = kdtree_recursive_build(
            root.left,
            db,
            point_indices_sorted[0:middle_right_idx],
            next_axis,
            leaf_size)
        root.right = kdtree_recursive_build(
            root.right,
            db,
            point_indices_sorted[middle_right_idx:],
            next_axis,
            leaf_size)
    return root
def traverse_kdtree(root: Node, depth, max_depth):
    """Walk the tree depth-first, printing every leaf and tracking depth.

    depth: one-element list holding the current depth; incremented on entry
           and decremented on exit, so it is unchanged when the call returns
    max_depth: one-element list updated to the largest depth reached
    """
    depth[0] += 1
    max_depth[0] = max(max_depth[0], depth[0])

    if root.is_leaf():
        print(root)
    else:
        traverse_kdtree(root.left, depth, max_depth)
        traverse_kdtree(root.right, depth, max_depth)

    depth[0] -= 1
def kdtree_construction(db_np, leaf_size):
    """Build a kd-tree over every row of *db_np* and return its root Node.

    db_np: array of points, one point per row
    leaf_size: forwarded to kdtree_recursive_build
    The first split is along axis 0.
    """
    num_points = db_np.shape[0]
    return kdtree_recursive_build(None,
                                  db_np,
                                  np.arange(num_points),
                                  axis=0,
                                  leaf_size=leaf_size)
def kdtree_knn_search(root: Node, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    """Feed result_set with candidates for the nearest neighbours of query.

    root: subtree to search; None is ignored
    db: array of points the tree indexes, one point per row
    result_set: receives add_point(distance, index) for every point of every
                visited leaf; its worstDist() drives the pruning
    query: the query point
    Always returns False.
    """
    if root is None:
        return False

    if root.is_leaf():
        # Compare against every point stored in this leaf.
        leaf_points = db[root.point_indices, :]
        dists = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for dist, point_index in zip(dists, root.point_indices):
            result_set.add_point(dist, point_index)
        return False

    # Search the side of the split that contains the query first.
    if query[root.axis] <= root.value:
        near, far = root.left, root.right
    else:
        near, far = root.right, root.left
    kdtree_knn_search(near, db, result_set, query)

    # Visit the other side only if the splitting plane is closer than the
    # current worst distance.
    if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
        kdtree_knn_search(far, db, result_set, query)
    return False
def kdtree_radius_search(root: Node, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    """Feed result_set with candidates for a radius query around query.

    root: subtree to search; None is ignored
    db: array of points the tree indexes, one point per row
    result_set: receives add_point(distance, index) for every point of every
                visited leaf; its worstDist() drives the pruning
    query: the query point
    Always returns False.
    """
    if root is None:
        return False

    if root.is_leaf():
        # Compare against every point stored in this leaf.
        leaf_points = db[root.point_indices, :]
        dists = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for dist, point_index in zip(dists, root.point_indices):
            result_set.add_point(dist, point_index)
        return False

    # Search the side of the split that contains the query first.
    if query[root.axis] <= root.value:
        near, far = root.left, root.right
    else:
        near, far = root.right, root.left
    kdtree_radius_search(near, db, result_set, query)

    # Visit the other side only if the splitting plane is closer than the
    # result set's worst distance.
    if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
        kdtree_radius_search(far, db, result_set, query)
    return False
def main():
    """Build a kd-tree over random points and demo kNN and radius search."""
    # configuration
    db_size = 64
    dim = 3
    leaf_size = 4
    k = 1

    db_np = np.random.rand(db_size, dim)

    root = kdtree_construction(db_np, leaf_size=leaf_size)

    # Print every leaf and report the deepest level reached.
    depth = [0]
    max_depth = [0]
    traverse_kdtree(root, depth, max_depth)
    print("tree max depth:%d" % max_depth[0])

    # kNN search through the tree
    query = np.asarray([0, 0, 0])
    result_set = KNNResultSet(capacity=k)
    kdtree_knn_search(root, db_np, result_set, query)
    print(result_set)

    # Brute-force reference to compare against the tree result.
    diff = np.linalg.norm(np.expand_dims(query, 0) - db_np, axis=1)
    nn_idx = np.argsort(diff)
    nn_dist = diff[nn_idx]
    print(nn_idx[0:k])
    print(nn_dist[0:k])

    print('Radius search')
    query = np.asarray([0, 0, 0])
    result_set = RadiusNNResultSet(radius=0.5)
    kdtree_radius_search(root, db_np, result_set, query)
    print(result_set)


if __name__ == '__main__':
    main()