三维点云处理08-KDTree

三维点云处理08-KDTree

result_set.py
import copy

# distance-index类
class DistIndex:
  def __init__(self,distance,index):
    self.distance = distance
    self.index = index

  def __lt__(self,other):
    return self.distance < other.distance

# 创建用来存放KNN结果的集合
class KNNResultSet:
  def __init__(self,capacity):
    self.capacity = capacity
    self.count = 0
    self.worst_dist = 1e10
    self.dist_index_list = []

    for i in range(self.capacity):
      self.dist_index_list.append(DistIndex(self.worst_dist,0))

    self.comparison_counter = 0

  def size(self):
    return self.count

  def full(self):
    return self.count == self,capacity

  def worstDist(self):
    return self.worst_dist

  def add_point(self,dist,index):
    self.comparison_counter += 1
    if dist > self.worst_dist:
      return

    if self.count < self.capacity:
      self.count += 1

    i = self.count - 1

    while i > 0:
      if self.dist_index_list[i-1].distance > dist:
        self.dist_index_list[i] = copy.deepcopy(self.dist_index_list[i-1])
        i -= 1
      else:
        break
    self.dist_index_list[i].distance = dist
    self.dist_index_list[i].index = index
    self.worst_dist = self.dist_index_list[self.capacity-1].distance
  
  # 打印输出函数
  def __str__(self):
    output = ''
    for i,dist_index in enumerate(self.dist_index_list):
      output += '%d - %.2f\n' % (dist_index.index,dist_index.distance)
      output += 'In total %d comparison operations.' % self.comparison_counter
    return output

# 创建用来存放RNN结果的集合
class RadiusNNResultSet:
  def __init__(self,radius):
    self.count = 0
    self.comparison_counter = 0
    self.worst_dist = radius
    self.dist_index_list = []

  def size(self):
    return self.count

  def worstDist(self):
    return self.worst_dist

  def add_point(self,dist,index):
    self.comparison_counter += 1
    if dist > self.worst_dist:
      return

    self.count += 1
    self.dist_index_list.append(DistIndex(dist,index))

  # 打印输出函数
  def __str__(self):
    self.dist_index_list.sort()
    output = ''
    for i,dist_index in enumerate(self.dist_index_list):
      output += '%d - %.2f\n' % (dist_index.index,dist_index.distance)
      output += 'In total %d neighbors within %.f \n There are %d comparsion operations' %(self.count,self.worst_dist,self.comparison_counter)
    return output
kdtree.py
import random
import math
import numpy as np

from result_set import KNNResultSet,RadiusNNResultSet

# 创建kdtree节点
class Node:
  def __init__(self,axis,value,left,right,point_indices):
    '''
    axis:表示当前节点基于哪个轴进行空间划分
    value:划分超平面在该轴上的坐标
    left:左子树,也是一个Node
    right:右子树,也是一个Node
    point_indices:当前子树所包含的点在原始数组中的index
    '''
    self.axis = axis
    self.value = value
    self.left = left
    self.right = right
    self.point_indices = point_indices

  # 判断当前节点是否是叶子节点的函数
  # 思考下为什么self.value的值可以用来判断是否是叶子节点
  # 首先一个kdtree节点只包含axis,value,left,right,point_indices五种信息,五种信息排除后发现,只有value可以表示是否为叶子节点
  def is_leaf(self):
    if self.value is None:
      return True
    else:
      return False

  # 创建打印函数
  def __str__(self):
    output = ''
    output += 'axis %d,'% self.axis
    if self.value is None:
      output += 'split value:leaf,'
    else:
      output += 'split value:%.2f,'% self.value
    output += 'point_indices:'
    output += str(self.point_indices.tolist())
    return output

# 根据传入节点值的大小进行排序
def sort_key_by_value(key,value):
  '''
  key:point_indices:点在原始数组中的indices
  value:point_indices对应点在当前axis上的值
  '''
  assert key.shape == value.shape
  assert len(key.shape) == 1
  sorted_idx = np.argsort(value)
  key_sorted = key[sorted_idx]
  value_sorted = value[sorted_idx]
  return key_sorted,value_sorted

# kdtree每个轴轮流进行超平面划分 
def axis_round_robin(axis,dim):
  if axis == dim -1:
    return 0
  else:
    return axis + 1

# 递归建立kdtree
def kdtree_recursive_build(root,db,point_indices,axis,leaf_size):
  '''
  root
  db NxD
  db_sorted_idx_inv:NxD
  point_indices:M
  axis:scalar
  leaf_size:scalar
  '''
  # 传入节点为空,尚未创建,使用Node创建
  if root is None:
    root = Node(axis,None,None,None,point_indices)
  
  # 如果当前节点所包含的点数大于叶子节点要求的点数,需要继续划分
  if len(point_indices) > leaf_size:
    # 根据当前进行超平面划分的轴对point_indices进行排序,获得排序后的point_indices_sorted
    point_indices_sorted,_ = sort_key_by_value(point_indices,db[point_indices,axis])
    middle_left_idx = math.ceil(point_indices_sorted.shape[0] / 2)-1
    middle_left_point_idx = point_indices_sorted[middle_left_idx]
    middle_left_point_value = db[middle_left_point_idx,axis]

    middle_right_idx = middle_left_idx + 1
    middle_right_point_idx = point_indices_sorted[middle_right_idx]
    middle_right_point_value = db[middle_right_point_idx,axis]

    root.value = (middle_left_point_value + middle_right_point_value) * 0.5

    root.left = kdtree_recursive_build(root.left,
                                       db,
                                       point_indices_sorted[0:middle_right_idx],
                                       axis_round_robin(axis,dim=db.shape[1]),
                                       leaf_size)
    root.right = kdtree_recursive_build(root.right,
                                        db,
                                        point_indices_sorted[middle_right_idx:],
                                        axis_round_robin(axis,dim=db.shape[1]),
                                        leaf_size)
  return root

# 前序遍历kdtree,同时记录最大深度
def traverse_kdtree(root:Node,depth,max_depth):
  depth[0] += 1
  if max_depth[0] < depth[0]:
    max_depth[0] = depth[0]
  
  if root.is_leaf():
    print(root)
  else:
    traverse_kdtree(root.left,depth,max_depth)
    traverse_kdtree(root.right,depth,max_depth)

  depth[0] -= 1

# 建立kdtree
def kdtree_construction(db_np,leaf_size):
  N,dim = db_np.shape[0],db_np.shape[1]

  # 调用kdtree_recursive_build递归的建立kdtree
  root = None
  root = kdtree_recursive_build(root,
                                db_np,
                                np.arange(N),
                                axis=0,
                                leaf_size=leaf_size)
  return root

# kdtree的k近邻搜索
def kdtree_knn_search(root:Node,db:np.ndarray,result_set:KNNResultSet,query:np.ndarray):
  # 如果当前节点为空,直接返回False
  if root is None:
    return False
  
  # 如果当前节点为叶子节点,对该叶子节点中的所有节点进行比较
  if root.is_leaf():
    leaf_points = db[root.point_indices,:]
    diff = np.linalg.norm(np.expand_dims(query,0)-leaf_points,axis=1)
    for i in range(diff.shape[0]):
      result_set.add_point(diff[i],root.point_indices[i])
    return False
  # 如果查询点在当前axis轴上的坐标小于等于当前节点在该axis上的坐标
  # 先搜索kdtree的左侧
  if query[root.axis] <= root.value:
    kdtree_knn_search(root.left,db,result_set,query)
    if math.fabs(query[root.axis]-root.value) < result_set.worstDist():
      kdtree_knn_search(root.right,db,result_set,query)
  else:
    kdtree_knn_search(root.right,db,result_set,query)
    if math.fabs(query[root.axis]-root.value) < result_set.worstDist():
      kdtree_knn_search(root.left,db,result_set,query)

    return False

# kdtree的Radius近邻搜索
def kdtree_radius_search(root:Node,db:np.ndarray,result_set:RadiusNNResultSet,query:np.ndarray):
  if root is None:
    return False
  
  if root.is_leaf():
    leaf_points = db[root.point_indices,:]
    diff = np.linalg.norm(np.expand_dims(query,0)-leaf_points,axis=1)
    for i in range(diff.shape[0]):
      result_set.add_point(diff[i],root.point_indices[i])
    return False

  if query[root.axis] <= root.value:
    kdtree_radius_search(root.left,db,result_set,query)
    if math.fabs(query[root.axis]-root.value) < result_set.worstDist():
      kdtree_radius_search(root.right,db,result_set,query)
    
  else:
    kdtree_radius_search(root.right,db,result_set,query)
    if math.fabs(query[root.axis]-root.value) < result_set.worstDist():
      kdtree_radius_search(root.left,db,result_set,query)
  
  return False


def main():
  db_size = 64
  dim = 3
  leaf_size = 4
  k = 1

  db_np = np.random.rand(db_size,dim)

  root = kdtree_construction(db_np,leaf_size=leaf_size)

  depth = [0]
  max_depth = [0]
  traverse_kdtree(root,depth,max_depth)
  print("tree max depth:%d" % max_depth[0])

  query = np.asarray([0,0,0])
  result_set = KNNResultSet(capacity=k)
  kdtree_knn_search(root,db_np,result_set,query)

  print(result_set)

  diff = np.linalg.norm(np.expand_dims(query,0)-db_np,axis=1)
  nn_idx = np.argsort(diff)
  nn_dist = diff[nn_idx]
  print(nn_idx[0:k])
  print(nn_dist[0:k])

  print('Radius search')
  query = np.asarray([0,0,0])
  result_set = RadiusNNResultSet(radius=0.5)
  kdtree_radius_search(root,db_np,result_set,query)
  print(result_set)

if __name__ == '__main__':
  main()
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值