3D Point Cloud Processing (4): KNN

Nearest Neighbor (NN) Problem

K-NN

  • Given a set of points $S$ in a space $M$ and a query point $q \in M$, find the $k$ closest points in $S$

Fixed Radius-NN

  • Given a set of points $S$ in a space $M$ and a query point $q \in M$, find all points $s \in S$ such that $\|s - q\| < r$
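
Both problem statements can be checked against a brute-force baseline before any tree is built. The following is a minimal sketch, assuming the points are rows of a NumPy array and the distance is Euclidean; the function names `brute_force_knn` and `brute_force_radius_nn` are hypothetical and do not come from the original notes.

```python
import numpy as np

def brute_force_knn(points: np.ndarray, query: np.ndarray, k: int) -> np.ndarray:
    """Return the indices of the k points closest to query, nearest first."""
    dists = np.linalg.norm(points - query, axis=1)
    return np.argsort(dists)[:k]

def brute_force_radius_nn(points: np.ndarray, query: np.ndarray, r: float) -> np.ndarray:
    """Return the indices of all points s with ||s - q|| < r."""
    dists = np.linalg.norm(points - query, axis=1)
    return np.nonzero(dists < r)[0]

points = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])
query = np.array([0.1, 0.0])
knn_idx = brute_force_knn(points, query, k=2)
radius_idx = brute_force_radius_nn(points, query, r=1.0)
```

Every query costs O(n) distance evaluations here, which is the cost the space-partitioning structures in the rest of these notes try to avoid.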

Core Ideas

  • NN by space partition
  • Stopping criteria: worst distance

Binary Search Tree (BST)

  • Node
# Node
class Node:
	def __init__(self, key, value=-1):
		self.left = None
		self.right = None
		self.key = key
		self.value = value
  • Construction / Insertion
# insert
def insert(root, key, value=-1):
	if root is None:
		root = Node(key, value)
	else:
		if key < root.key:
			root.left = insert(root.left, key, value)
		elif key > root.key:
			root.right = insert(root.right, key, value)
		else: # don't insert if key already exist in the tree
			pass
	return root
  • Search
# search_recursive
def search_recursive(root, key):
	if root is None or root.key == key:
		return root
	if key < root.key:
		return search_recursive(root.left, key)
	elif key > root.key:
		return search_recursive(root.right, key)
# search_iterative
def search_iterative(root, key):
	current_node = root
	while current_node is not None:
		if current_node.key == key:
			return current_node
		if key < current_node.key:
			current_node = current_node.left
		elif key > current_node.key:
			current_node = current_node.right
	return current_node
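
To see insertion and search working together, here is a small self-contained usage sketch. It restates `Node`, `insert`, and `search_iterative` in compact form so that it runs on its own; the `inorder` helper and the sample keys are assumptions added for illustration.

```python
class Node:
    def __init__(self, key, value=-1):
        self.left = None
        self.right = None
        self.key = key
        self.value = value

def insert(root, key, value=-1):
    if root is None:
        return Node(key, value)
    if key < root.key:
        root.left = insert(root.left, key, value)
    elif key > root.key:
        root.right = insert(root.right, key, value)
    return root  # a key that already exists is not inserted again

def search_iterative(root, key):
    current = root
    while current is not None and current.key != key:
        current = current.left if key < current.key else current.right
    return current

def inorder(root):
    # hypothetical helper: in-order traversal yields the keys in sorted order
    if root is None:
        return []
    return inorder(root.left) + [root.key] + inorder(root.right)

keys = [5, 2, 8, 1, 3, 7, 9]
root = None
for i, key in enumerate(keys):
    root = insert(root, key, value=i)  # value stores the position in `keys`

found = search_iterative(root, 7)
missing = search_iterative(root, 4)
sorted_keys = inorder(root)
```

A successful search returns the node, whose `value` can then be used as an index into the original data; a failed search returns `None`.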

Worst Distance for kNN

  • Build a container to store the kNN results
  • k results are sorted
  • worst_dist is the last one
  • Add a result if $\text{dist} < \text{worst\_dist}$
import copy
import math

class KNNResultSet:
	def __init__(self, capacity):
		self.capacity = capacity
		self.count = 0
		self.worst_dist = 1e10
		self.dist_index_list = []
		for i in range(capacity):
			self.dist_index_list.append(DistIndex(self.worst_dist, 0))
		self.comparison_counter = 0
	
	def size(self):
		return self.count
	
	def full(self):
		return self.count == self.capacity
	
	def worstDist(self):
		return self.worst_dist
	
	def add_point(self, dist, index):
		self.comparison_counter += 1
		if dist > self.worst_dist:
			return
		
		if self.count < self.capacity:
			self.count += 1
		i = self.count - 1
		while i > 0:
			if self.dist_index_list[i-1].distance > dist:
				self.dist_index_list[i] = copy.deepcopy(self.dist_index_list[i-1])
				i -= 1
			else:
				break
		self.dist_index_list[i].distance = dist
		self.dist_index_list[i].index = index
		self.worst_dist = self.dist_index_list[self.capacity-1].distance
class DistIndex:
	def __init__(self, distance, index):
		self.distance = distance
		self.index = index
	def __lt__(self, other):
		return self.distance < other.distance
def knn_search(root: Node, result_set: KNNResultSet, key):
	if root is None:
		return False
	
	# compare the root itself
	result_set.add_point(math.fabs(root.key - key), root.value)
	if result_set.worstDist() == 0:
		return True
	if root.key >= key:
		# iterate left branch first
		if knn_search(root.left, result_set, key):
			return True
		elif math.fabs(root.key - key) < result_set.worstDist():
			return knn_search(root.right, result_set, key)
		return False
	else:
		# iterate right branch first
		if knn_search(root.right, result_set, key):
			return True
		elif math.fabs(root.key - key) < result_set.worstDist():
			return knn_search(root.left, result_set, key)
		return False
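
The sorted-list container shifts entries on every insertion. As an alternative (a different technique from the one in these notes), the same "worst distance" bookkeeping can be sketched with a max-heap built on Python's `heapq`. The class name `HeapKNNResultSet` and the `sorted_results` method are hypothetical; only `full`, `worstDist`, and `add_point` mirror the interface used by the search code.

```python
import heapq

class HeapKNNResultSet:
    def __init__(self, capacity):
        self.capacity = capacity
        # heapq is a min-heap, so distances are stored negated:
        # heap[0] then holds the largest (worst) distance kept so far
        self.heap = []

    def full(self):
        return len(self.heap) == self.capacity

    def worstDist(self):
        # until k results are collected, nothing can be pruned
        return -self.heap[0][0] if self.full() else float("inf")

    def add_point(self, dist, index):
        if not self.full():
            heapq.heappush(self.heap, (-dist, index))
        elif dist < self.worstDist():
            # drop the current worst result and insert the new one
            heapq.heapreplace(self.heap, (-dist, index))

    def sorted_results(self):
        # (distance, index) pairs, nearest first
        return sorted((-neg_dist, index) for neg_dist, index in self.heap)

result_set = HeapKNNResultSet(capacity=3)
for index, dist in enumerate([4.0, 1.0, 3.0, 0.5, 2.0]):
    result_set.add_point(dist, index)
```

Each insertion costs O(log k) instead of O(k). One behavioural difference from the list version: the worst distance here stays infinite until the container is full, whereas the list version starts from the sentinel `1e10`.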

Radius Result Set Manager

def add_point(self, dist, index):
	self.comparison_counter += 1
	if dist > self.radius:
		return
	self.count += 1
	self.dist_index_list.append(DistIndex(dist, index))
def radius_search(root: Node, result_set: RadiusNNResultSet, key):
	if root is None:
		return False
	# compare the root itself
	result_set.add_point(math.fabs(root.key - key), root.value)
	if root.key >= key:
		# iterate left branch first
		if radius_search(root.left, result_set, key):
			return True
		elif math.fabs(root.key - key) < result_set.worstDist():
			return radius_search(root.right, result_set, key)
		return False
	else:
		# iterate right branch first
		if radius_search(root.right, result_set, key):
			return True
		elif math.fabs(root.key - key) < result_set.worstDist():
			return radius_search(root.left, result_set, key)
		return False
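
Only `add_point` of the radius container is listed here, while `radius_search` also calls `worstDist()`. A minimal sketch of the surrounding class follows; the constructor and `worstDist` are assumptions inferred from how the search code uses them, not a copy of the original implementation.

```python
class DistIndex:
    def __init__(self, distance, index):
        self.distance = distance
        self.index = index

    def __lt__(self, other):
        return self.distance < other.distance

class RadiusNNResultSet:
    def __init__(self, radius):
        self.radius = radius
        self.count = 0
        self.dist_index_list = []
        self.comparison_counter = 0

    def size(self):
        return self.count

    def worstDist(self):
        # assumption: the pruning bound is the fixed radius and never shrinks
        return self.radius

    def add_point(self, dist, index):
        self.comparison_counter += 1
        if dist > self.radius:
            return
        self.count += 1
        self.dist_index_list.append(DistIndex(dist, index))

result_set = RadiusNNResultSet(radius=1.5)
for index, dist in enumerate([0.2, 3.0, 1.5, 1.0]):
    result_set.add_point(dist, index)
kept = sorted(result_set.dist_index_list)  # sorting relies on DistIndex.__lt__
```

Because `add_point` rejects only `dist > radius`, a point at exactly the radius is kept, which is slightly looser than the strict inequality in the problem definition.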

Kd-tree

  • It extends the BST to higher dimensions
  • Invented by Jon Louis Bentley in 1975
  • The kd-tree is a binary tree where every leaf node is a k-dimensional point

Division / Splitting Strategy

  • round-robin
  • adaptive
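
The two strategies differ only in how the next splitting axis is chosen. The sketch below assumes "adaptive" means splitting along the dimension in which the current points spread the most, which is one common interpretation; the notes do not define it further, and the function name `axis_adaptive` is hypothetical.

```python
import numpy as np

def axis_round_robin(axis: int, dim: int) -> int:
    """Cycle through the axes: 0, 1, ..., dim - 1, 0, ..."""
    return (axis + 1) % dim

def axis_adaptive(db: np.ndarray, point_indices: np.ndarray) -> int:
    """Pick the axis along which the selected points have the largest extent."""
    points = db[point_indices]
    spread = points.max(axis=0) - points.min(axis=0)
    return int(np.argmax(spread))

db = np.array([[0.0, 0.0, 0.0],
               [1.0, 10.0, 0.0],
               [0.5, 5.0, 1.0]])
point_indices = np.array([0, 1, 2])
next_axis_rr = axis_round_robin(2, dim=db.shape[1])
next_axis_adaptive = axis_adaptive(db, point_indices)
```

Round-robin needs no look at the data, while the adaptive choice costs one pass over the points in the node but tends to produce more cube-like cells for elongated point sets.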

Node Representation

class Node:
	def __init__(self, axis, value, left, right, point_indices):
		self.axis = axis
		self.value = value
		self.left = left
		self.right = right
		self.point_indices = point_indices
	def is_leaf(self):
		# a node without a splitting value has not been split, so it is a leaf
		return self.value is None

Recursive Build

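import math
import numpy as np

def kdtree_recursive_build(root, db, point_indices, axis, leaf_size):
	"""
	:param root: current node, or None to create a new one
	:param db: N×D array of points
	:param point_indices: indices (into db) of the M points in this node
	:param axis: scalar, splitting axis of this node
	:param leaf_size: scalar, maximum number of points in a leaf
	:return: root of the (sub)tree
	"""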
	if root is None:
		root = Node(axis, None, None, None, point_indices)
	# determine whether to split into left and right
	if len(point_indices) > leaf_size:
		# --- get the split position ---
		# sort the points in this node, get the median position
		point_indices_sorted, _ = sort_key_by_vale(point_indices, db[point_indices, axis]) # M
		middle_left_idx = math.ceil(point_indices_sorted.shape[0] / 2) - 1
		middle_left_point_idx = point_indices_sorted[middle_left_idx]
		middle_left_point_value = db[middle_left_point_idx, axis]
		
		middle_right_idx = middle_left_idx + 1
		middle_right_point_idx = point_indices_sorted[middle_right_idx]
		middle_right_point_value = db[middle_right_point_idx, axis]
		
		root.value = (middle_left_point_value + middle_right_point_value) * 0.5
		# === get the split position ===
		root.left = kdtree_recursive_build(root.left,
										   db,
										   point_indices_sorted[0:middle_right_idx],
										   axis_round_robin(axis, dim=db.shape[1]),
										   leaf_size)
		root.right = kdtree_recursive_build(root.right,
										   db,
										   point_indices_sorted[middle_right_idx:],
										   axis_round_robin(axis, dim=db.shape[1]),
										   leaf_size)
	return root
def axis_round_robin(axis, dim):
	if axis == dim - 1:
		return 0
	else:
		return axis + 1
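
The build routine calls a helper `sort_key_by_vale` that is not listed in these notes. A minimal sketch of what it is assumed to do, based only on how its return value is used (keys reordered by ascending value, plus the sorted values), could look like this:

```python
import math
import numpy as np

def sort_key_by_vale(key: np.ndarray, value: np.ndarray):
    """Assumed behaviour: reorder `key` so that `value` is ascending."""
    assert key.shape == value.shape
    assert key.ndim == 1
    sorted_idx = np.argsort(value)
    return key[sorted_idx], value[sorted_idx]

db = np.array([[3.0, 0.0], [1.0, 5.0], [2.0, 4.0], [0.0, 1.0]])
point_indices = np.array([0, 1, 2, 3])
axis = 0

point_indices_sorted, values_sorted = sort_key_by_vale(point_indices, db[point_indices, axis])

# same median arithmetic as in the recursive build
middle_left_idx = math.ceil(point_indices_sorted.shape[0] / 2) - 1
middle_right_idx = middle_left_idx + 1
split_value = 0.5 * (values_sorted[middle_left_idx] + values_sorted[middle_right_idx])
```

With four points whose x-coordinates are 3, 1, 2, 0, the split lands halfway between the two middle values, and each child receives two points.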

Complexity

  • Time complexity [1]: select the median (from a subset, or use the mean) instead of sorting
    $O(n \log n \log n) \rightarrow O(kn \log n)$
  • Space complexity: only store points at the leaves
    $O(kn + \log n) \rightarrow O(kn + n)$
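
One way to avoid a full sort at every level is to select the median with a partial partition. The sketch below uses `np.argpartition`, which places the requested positions in their sorted order without ordering the rest; it is an illustration of median selection in general, not the algorithm from reference [1], and the function name `median_split` is hypothetical.

```python
import math
import numpy as np

def median_split(point_indices: np.ndarray, values: np.ndarray):
    """Split indices around the median of `values` without a full sort.

    Assumes at least two points, as in the build routine when
    len(point_indices) > leaf_size >= 1.
    """
    n = values.shape[0]
    left_pos = math.ceil(n / 2) - 1
    right_pos = left_pos + 1
    # after this call, positions left_pos and right_pos hold the elements they
    # would hold in a full sort; smaller elements come before, larger ones after
    order = np.argpartition(values, [left_pos, right_pos])
    split_value = 0.5 * (values[order[left_pos]] + values[order[right_pos]])
    left = point_indices[order[:right_pos]]
    right = point_indices[order[right_pos:]]
    return left, right, split_value

point_indices = np.array([10, 11, 12, 13, 14, 15])
values = np.array([9.0, 1.0, 7.0, 3.0, 5.0, 2.0])
left, right, split_value = median_split(point_indices, values)
```

The left and right halves are not internally sorted, which is fine for the build because each child re-partitions its own points along the next axis.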

kNN Search

  • Start from root
  • Reach the leaf node that covers the query point
  • Go up and traverse the tree
  • Criteria for a partition to intersect the worst-distance range $w$:
  1. $q[\text{axis}]$ lies inside the partition
  2. $|q[\text{axis}] - \text{splitting\_value}| < w$
def knn_search(root: Node, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
	if root is None:
		return False
	# Compare query to every point inside the leaf, put into the result set
	if root.is_leaf():
		leaf_points = db[root.point_indices, :]
		diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis = 1)
		# update worstDist()
		for i in range(diff.shape[0]):
			result_set.add_point(diff[i], root.point_indices[i])
		return False
	# q[axis] inside the partition
	if query[root.axis] <= root.value:
		knn_search(root.left, db, result_set, query)
		if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
			knn_search(root.right, db, result_set, query)
	else:
		knn_search(root.right, db, result_set, query)
		if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
			knn_search(root.left, db, result_set, query)
	
	return False
# RadiusNN: the non-leaf part of radius_search follows the same pattern as knn_search
if query[root.axis] <= root.value:
	radius_search(root.left, db, result_set, query)
	if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
		radius_search(root.right, db, result_set, query)
else:
	radius_search(root.right, db, result_set, query)
	if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
		radius_search(root.left, db, result_set, query)
return False

Octree

  • Each node has 8 children
  • Specifically for 3D, $2^3 = 8$
  • In kd-tree, it is non-trivial to determine whether the NN search is done, so we have to go back to root every time
  • Octree is more efficient because we can stop without going back to root

Octree Construction

  • Determine the extent of the first octant
  • Octant is an element in the octree
  • Octant is a cube
  • Determine whether to further split the octant (leaf_size, min_extent)
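
The first step, determining the extent of the first octant, can be sketched as follows. This assumes the root octant is the smallest axis-aligned cube centred on the bounding box of the data; the function name `root_octant_bounds` is hypothetical.

```python
import numpy as np

def root_octant_bounds(db: np.ndarray):
    """Return (center, extent) of a cube that covers all points in db.

    extent is half the edge length, matching the Octant convention.
    """
    db_min = np.amin(db, axis=0)
    db_max = np.amax(db, axis=0)
    center = 0.5 * (db_min + db_max)
    extent = 0.5 * np.max(db_max - db_min)  # the longest side decides the cube size
    return center, extent

db = np.array([[0.0, 0.0, 0.0],
               [4.0, 2.0, 1.0],
               [2.0, 1.0, 0.5]])
center, extent = root_octant_bounds(db)
all_covered = bool(np.all(np.abs(db - center) <= extent))
```

The resulting `center` and `extent` would be passed to the recursive build together with all point indices.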

Octant Representation

class Octant:
	def __init__(self, children, center, extent, point_indices, is_leaf):
		self.children = children	# Array of length 8
		self.center = center	# Center of the cube
		self.extent = extent	# 0.5 * length
		self.point_indices = point_indices	# Point inside octant
		self.is_leaf = is_leaf

Recursive Build

def octree_recursive_build(root, db, center, extent, point_indices, leaf_size, min_extent):
	if len(point_indices) == 0:
		return None
	if root is None:
		root = Octant([None for i in range(8)], center, extent, point_indices, is_leaf=True)
	
	# determine whether to split this octant
	if len(point_indices) <= leaf_size or extent <= min_extent:
		root.is_leaf = True
	else:
		root.is_leaf = False
		children_point_indices = [[] for i in range(8)]
		# Determine which child a point belongs to
		# (a set bit means the coordinate lies above the center on that axis,
		# which matches the child-center computation and the search code)
		for point_idx in point_indices:
			point_db = db[point_idx]
			morton_code = 0
			if point_db[0] > center[0]:
				morton_code = morton_code | 1
			if point_db[1] > center[1]:
				morton_code = morton_code | 2
			if point_db[2] > center[2]:
				morton_code = morton_code | 4
			children_point_indices[morton_code].append(point_idx)
	
		# create children (only when this octant is split)
		factor = [-0.5, 0.5]
		# Determine child center & extent
		for i in range(8):
			child_center_x = center[0] + factor[(i & 1) > 0] * extent
			child_center_y = center[1] + factor[(i & 2) > 0] * extent
			child_center_z = center[2] + factor[(i & 4) > 0] * extent
			child_extent = 0.5 * extent
			child_center = np.asarray([child_center_x, child_center_y, child_center_z])
			root.children[i] = octree_recursive_build(root.children[i],
													  db,
													  child_center,
													  child_extent,
													  children_point_indices[i],
													  leaf_size,
													  min_extent)
	return root
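
The child index is a 3-bit code, one bit per axis, and the child-center formula implies that a set bit selects the half above the parent's center. The sketch below isolates the two mappings and checks that they agree for all eight children; `child_index` and `child_center` are hypothetical helper names introduced for this illustration.

```python
import numpy as np

def child_index(point: np.ndarray, center: np.ndarray) -> int:
    """3-bit code: bit 0 for x, bit 1 for y, bit 2 for z; set when point > center."""
    code = 0
    if point[0] > center[0]:
        code |= 1
    if point[1] > center[1]:
        code |= 2
    if point[2] > center[2]:
        code |= 4
    return code

def child_center(center: np.ndarray, extent: float, code: int) -> np.ndarray:
    """Center of child `code`: shift by half the parent extent along each axis."""
    factor = [-0.5, 0.5]
    return np.array([
        center[0] + factor[(code & 1) > 0] * extent,
        center[1] + factor[(code & 2) > 0] * extent,
        center[2] + factor[(code & 4) > 0] * extent,
    ])

center = np.array([1.0, 2.0, 3.0])
extent = 2.0
# the center of child c must itself be assigned to child c
round_trip = [child_index(child_center(center, extent, code), center) for code in range(8)]
```

If the comparison direction in the assignment step and the sign convention in the center computation disagree, this round trip fails, and points end up in octants that do not contain them.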

Octree KNN Search

def octree_knn_search(root: Octant, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
	if root is None:
		return False
	if root.is_leaf and len(root.point_indices) > 0:
		# compare the contents of a leaf
		leaf_points = db[root.point_indices, :]
		diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis = 1)
		for i in range(diff.shape[0]):
			result_set.add_point(diff[i], root.point_indices[i])
		# check whether we can stop search now
		return inside(query, result_set.worstDist(), root)
	# go to the relevant child first
	morton_code = 0
	if query[0] > root.center[0]:
		morton_code = morton_code | 1
	if query[1] > root.center[1]:
		morton_code = morton_code | 2
	if query[2] > root.center[2]:
		morton_code = morton_code | 4
	if octree_knn_search(root.children[morton_code], db, result_set, query):
		return True
	
	# check other children
	for c, child in enumerate(root.children):
		if c == morton_code or child is None:
			continue
		# If an octant is not overlapping with query ball, skip
		if not overlaps(query, result_set.worstDist(), child):
			continue
		if octree_knn_search(child, db, result_set, query):
			return True
	
	# final check of if we can stop search
	return inside(query, result_set.worstDist(), root)
  • inside()
def inside(query: np.ndarray, radius: float, octant: Octant):
	"""
	Determine if the query ball is inside the octant
	:param query
	:param radius
	:param octant
	:return:
	"""
	query_offset = query - octant.center
	query_offset_abs = np.fabs(query_offset)
	possible_space = query_offset_abs + radius
	return np.all(possible_space < octant.extent)
  • overlaps()
def overlaps(query: np.ndarray, radius: float, octant: Octant):
	"""
	Determines if the query ball overlaps with the octant
	:param query
	:param radius
	:param octant
	:return:
	"""
	query_offset = query - octant.center
	query_offset_abs = np.fabs(query_offset)
	
	# completely outside, since query is outside the relevant area
	max_dist = radius + octant.extent
	if np.any(query_offset_abs > max_dist):
		return False
	
	# if pass the above check, consider the case that the ball is contacting the face of the octant
	# (np.int was removed from recent NumPy releases; summing the boolean array directly is enough)
	if np.sum(query_offset_abs < octant.extent) >= 2:
		return True
	
	# consider the case that the ball is contacting the edge or corner of the octant
	# since the case of the ball center (query) inside octant has been considered,
	# we only consider the ball center (query) outside octant
	x_diff = max(query_offset_abs[0] - octant.extent, 0)
	y_diff = max(query_offset_abs[1] - octant.extent, 0)
	z_diff = max(query_offset_abs[2] - octant.extent, 0)
	
	return x_diff * x_diff + y_diff * y_diff + z_diff * z_diff < radius * radius

radiusNN

  • contains()
def contains(query: np.ndarray, radius: float, octant:Octant):
	"""
	Determine if the query ball contains the octant
	:param query:
	:param radius:
	:param octant
	:return:
	"""
	query_offset = query - octant.center
	query_offset_abs = np.fabs(query_offset)
	
	query_offset_to_farthest_corner = query_offset_abs + octant.extent
	return np.linalg.norm(query_offset_to_farthest_corner) < radius
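
The two containment tests point in opposite directions: `inside()` asks whether the ball fits in the octant, `contains()` asks whether the octant fits in the ball. A small numeric sketch makes the difference concrete; `Box` is a hypothetical stand-in for `Octant` that carries only `center` and `extent`, and both predicates are restated so the example runs on its own.

```python
from collections import namedtuple
import numpy as np

# hypothetical minimal stand-in for Octant (only the fields the predicates read)
Box = namedtuple("Box", ["center", "extent"])

def inside(query, radius, octant):
    """True if the query ball lies entirely inside the octant."""
    query_offset_abs = np.fabs(query - octant.center)
    return bool(np.all(query_offset_abs + radius < octant.extent))

def contains(query, radius, octant):
    """True if the query ball entirely contains the octant."""
    query_offset_abs = np.fabs(query - octant.center)
    # distance from the query to the farthest corner of the cube
    return bool(np.linalg.norm(query_offset_abs + octant.extent) < radius)

box = Box(center=np.zeros(3), extent=1.0)
query = np.array([0.2, 0.0, 0.0])

small_ball_inside = inside(query, 0.5, box)    # 0.2 + 0.5 < 1 on every axis
large_ball_inside = inside(query, 0.9, box)    # 0.2 + 0.9 exceeds the extent on x
big_ball_contains = contains(query, 2.0, box)  # farthest corner is about 1.85 away
mid_ball_contains = contains(query, 1.8, box)  # 1.8 is not enough to reach that corner
```

In a radius search, an octant that passes `contains()` can have all of its points compared directly without descending further, since none of its children can be pruned.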

  1. Russell A. Brown, "Building a Balanced k-d Tree in O(kn log n) Time", Journal of Computer Graphics Techniques, 2015
