三维点云学习(2)五种算法比较
代码参考来自 黎老师github
本次测试包含五种算法比较:
octree
print("octree --------------")
# timing accumulators
construction_time_sum = 0
knn_time_sum = 0
# construction
begin_t = time.time()
root = octree.octree_construction(db_np, leaf_size, min_extent)  # build the octree
construction_time_sum += time.time() - begin_t  # accumulate build time
# Octree k-NN search: one query per point, each with its own result set
begin_t = time.time()
for i in range(len(db_np)):
    query = db_np[i, :]
    result_set = KNNResultSet(capacity=k)
    octree.octree_knn_search(root, db_np, result_set, query)
knn_time_sum += time.time() - begin_t
print("Octree: build %.3fms, knn %.3fms" % (construction_time_sum * 1000, knn_time_sum * 1000))
spatial_kdtree
调用 scipy.spatial 中的 KDTree 进行近邻点查找
# scipy.spatial.KDTree
construction_time_sum = 0
knn_time_sum = 0
# construction (same leaf size as the hand-written trees)
begin_t = time.time()
tree = spatial.KDTree(db_np, leaf_size)
construction_time_sum += time.time() - begin_t
# search: one batched query for the k nearest neighbours of every point
begin_t = time.time()
tree.query(db_np, k=k)
knn_time_sum += time.time() - begin_t
print("Kdtree_spatial: build %.3fms, knn %.3fms" % (construction_time_sum * 1000, knn_time_sum * 1000))
origin_kdtree
Origin 为使用老师 git 仓库中 kdtree.py 的代码,切分轴按 x、y、z 顺序轮流选取(顺序建轴)
construction_time_sum = 0
knn_time_sum = 0
# origin kd-tree: split axes cycle x -> y -> z
begin_t = time.time()
root = kdtree.kdtree_construction(db_np, leaf_size)
construction_time_sum += time.time() - begin_t
begin_t = time.time()
for i in range(len(db_np)):
    query = db_np[i, :]
    # A fresh result set per query: reusing one would carry the previous
    # query's neighbours (and its shrunken worst distance) into this search,
    # giving wrong neighbours and an artificially fast timing.
    result_set = KNNResultSet(capacity=k)
    kdtree.kdtree_knn_search(root, db_np, result_set, query)
knn_time_sum += time.time() - begin_t
print("Kdtree_Origin: build %.3fms, knn %.3fms" % (construction_time_sum * 1000, knn_time_sum * 1000))
new_kdtree
new 为按最大方差选取切分轴(方差建轴)的版本
关于方差建轴一些理论的讨论和见解如下:
kdtree划分空间维度选择使用“最大方差法”的好处
附上老师的答疑截图
# new kd-tree: split on the axis with the largest variance
construction_time_sum = 0
knn_time_sum = 0
begin_t = time.time()
root = kdtree_new.kdtree_construction(db_np, leaf_size)
construction_time_sum += time.time() - begin_t
begin_t = time.time()
for i in range(len(db_np)):
    query = db_np[i, :]
    # A fresh result set per query so earlier neighbours never leak into
    # (and prune) the current search.
    result_set = KNNResultSet(capacity=k)
    kdtree_new.kdtree_knn_search(root, db_np, result_set, query)
knn_time_sum += time.time() - begin_t
print("Kdtree_New: build %.3fms, knn %.3fms" % (construction_time_sum * 1000, knn_time_sum * 1000))
方差建轴的具体实现函数如下
def axis_select(leaf_point):
    """Pick the split axis with the largest variance (the "max-variance" rule).

    :param leaf_point: (M, D) array of the points about to be split
    :return: int index of the coordinate with the largest variance; ties
             resolve to the lowest index (x before y before z)
    """
    # np.argmax returns the first maximum, reproducing the x > y > z
    # tie-break, and works for any number of dimensions instead of only 3.
    return int(np.argmax(np.var(leaf_point, axis=0)))
调用位置:见下方完整代码 kdtree_new.py 中的 kdtree_construction 与 kdtree_recursive_build
Brute
暴力搜索(计算查询点到所有点的距离并排序):
print("Brute --------------")
brute_time_sum = 0
begin_t = time.time()
# Brute force costs O(N log N) per query, so only the first 1000 points are timed.
for i in range(1000):
    query = db_np[i, :]
    diff = np.linalg.norm(np.expand_dims(query, 0) - db_np, axis=1)  # distance to every point
    nn_idx = np.argsort(diff)
    nn_dist = diff[nn_idx]
brute_time_sum += time.time() - begin_t
print("1000 points for : brute %.3fms" % (brute_time_sum * 1000))
测试结果
对数据集中约 12 万个点,取前 1000 个点分别进行 8-NN 搜索,结果如下:
对数据集中约 12 万个点,对每个点分别进行 8-NN 搜索,结果如下:
完整代码:
benchmark.py
# 对数据集中的点云,批量执行构建树和查找,包括kdtree和octree,并评测其运行时间
import random
import math
import numpy as np
import time
import os
import struct
import octree as octree
import kdtree as kdtree
import kdtree_new as kdtree_new
from result_set import KNNResultSet, RadiusNNResultSet
from scipy import spatial
def read_velodyne_bin(path):
    '''
    Read a Velodyne ``.bin`` scan stored as consecutive native-endian float32
    records of four values each; only the first three (x, y, z) are kept.

    :param path: path to the ``.bin`` file
    :return: contiguous float32 array of shape (N, 3); (0, 3) for an empty file
    '''
    # One vectorised read instead of unpacking every record in a Python loop.
    raw = np.fromfile(path, dtype=np.float32)
    return np.ascontiguousarray(raw.reshape(-1, 4)[:, :3])
def _time_knn_queries(knn_search, root, db_np, k):
    """Time one k-NN query per row of ``db_np`` against ``root``; return seconds.

    ``knn_search`` is called as ``knn_search(root, db_np, result_set, query)``,
    the call shape used for octree, kdtree and kdtree_new in this benchmark.
    Each query gets a fresh KNNResultSet so results never leak between queries.
    """
    begin_t = time.time()
    for query in db_np:
        result_set = KNNResultSet(capacity=k)
        knn_search(root, db_np, result_set, query)
    return time.time() - begin_t


def main():
    """Benchmark tree construction and k-NN search on one Velodyne scan.

    Compares the project octree, scipy's ``spatial.KDTree``, the round-robin
    kd-tree (``kdtree``), the max-variance kd-tree (``kdtree_new``) and a
    brute-force baseline, printing build/search times in milliseconds.
    """
    # configuration
    leaf_size = 32  # at most 32 points per leaf
    min_extent = 0.0001  # minimum octant size
    k = 8  # number of nearest neighbours per query
    brute_queries = 1000  # brute force is O(N) per query, so only a prefix is timed

    db_np = read_velodyne_bin('data/000000.bin')

    print("octree --------------")
    begin_t = time.time()
    root = octree.octree_construction(db_np, leaf_size, min_extent)
    construction_time = time.time() - begin_t
    knn_time = _time_knn_queries(octree.octree_knn_search, root, db_np, k)
    print("Octree: build %.3fms, knn %.3fms" % (construction_time * 1000, knn_time * 1000))

    print("kdtree --------------")
    begin_t = time.time()
    tree = spatial.KDTree(db_np, leaf_size)
    construction_time = time.time() - begin_t
    begin_t = time.time()
    # Query every point with the configured k so this timing covers the same
    # workload as the other trees (it previously covered only 30000 points).
    tree.query(x=db_np, k=k)
    knn_time = time.time() - begin_t
    print("Kdtree_spatial: build %.3fms, knn %.3fms" % (construction_time * 1000, knn_time * 1000))

    # kd-tree whose split axis cycles x -> y -> z
    begin_t = time.time()
    root = kdtree.kdtree_construction(db_np, leaf_size)
    construction_time = time.time() - begin_t
    knn_time = _time_knn_queries(kdtree.kdtree_knn_search, root, db_np, k)
    print("Kdtree_Origin: build %.3fms, knn %.3fms" % (construction_time * 1000, knn_time * 1000))

    # kd-tree that splits on the axis of largest variance
    begin_t = time.time()
    root = kdtree_new.kdtree_construction(db_np, leaf_size)
    construction_time = time.time() - begin_t
    knn_time = _time_knn_queries(kdtree_new.kdtree_knn_search, root, db_np, k)
    print("Kdtree_New: build %.3fms, knn %.3fms" % (construction_time * 1000, knn_time * 1000))

    print("Brute --------------")
    n_brute = min(brute_queries, len(db_np))  # do not index past a small scan
    begin_t = time.time()
    for query in db_np[:n_brute]:
        diff = np.linalg.norm(np.expand_dims(query, 0) - db_np, axis=1)
        _ = diff[np.argsort(diff)]  # full sort + gather is the brute-force baseline
    brute_time = time.time() - begin_t
    print("%d points for : brute %.3fms" % (n_brute, brute_time * 1000))


if __name__ == '__main__':
    main()
kdtree.py
# kdtree的具体实现,包括构建和查找
import random
import math
import numpy as np
import time
from result_set import KNNResultSet, RadiusNNResultSet
# Node: the building block of the kd-tree.
class Node:
    """A kd-tree node.

    Attributes are stored verbatim from the constructor: ``axis`` is the split
    dimension, ``value`` the split coordinate (``None`` marks a leaf),
    ``left``/``right`` the children and ``point_indices`` the row indices of
    the points owned by this node.
    """

    def __init__(self, axis, value, left, right, point_indices):
        self.axis, self.value = axis, value
        self.left, self.right = left, right
        self.point_indices = point_indices

    def is_leaf(self):
        """Return True when the node carries no split value."""
        return self.value is None

    def __str__(self):
        split = 'split value: leaf, ' if self.value is None else 'split value: %.2f, ' % self.value
        return ''.join(['axis %d, ' % self.axis,
                        split,
                        'point_indices: ',
                        str(self.point_indices.tolist())])
def sort_key_by_vale(key, value):
    """Sort ``value`` ascending and apply the same permutation to ``key``.

    Both arguments must be 1-D arrays of identical shape (checked with
    ``assert``). Returns ``(key_sorted, value_sorted)``.
    """
    assert key.shape == value.shape
    assert len(key.shape) == 1
    order = np.argsort(value)
    return key[order], value[order]
def axis_round_robin(axis, dim):
    """Return the next split axis: 0, 1, ..., dim-1, then back to 0."""
    return 0 if axis == dim - 1 else axis + 1
# Build the kd-tree recursively, cycling the split axis x -> y -> z.
def kdtree_recursive_build(root, db, point_indices, axis, leaf_size):
    """Recursively build a kd-tree over ``db[point_indices]``.

    Args:
        root: node to fill in, or None to create a new one.
        db: point array, indexed as ``db[row, axis]``.
        point_indices: 1-D array of the rows of ``db`` that belong to this node.
        axis: dimension this node splits on.
        leaf_size: a node with at most this many points is kept as a leaf.

    Returns:
        The node for this subtree.
    """
    if root is None:
        root = Node(axis, None, None, None, point_indices)

    # Split only when the node holds more than leaf_size points and at least
    # two, so both halves are non-empty even for leaf_size < 1.
    if len(point_indices) > max(leaf_size, 1):
        # Row indices of this node's points, ordered by their split-axis coordinate.
        point_indices_sorted, _ = sort_key_by_vale(point_indices, db[point_indices, axis])

        # The first ceil(n / 2) points go left, the remainder go right.
        middle = math.ceil(point_indices_sorted.shape[0] / 2)
        left_indices = point_indices_sorted[:middle]
        right_indices = point_indices_sorted[middle:]

        # Entries of point_indices_sorted are already rows of db, so they index
        # db directly. Offsetting them (e.g. by -1) would read an unrelated point
        # and yield a split value that lets the search prune subtrees it must visit.
        left_max = db[left_indices[-1], axis]
        right_min = db[right_indices[0], axis]
        root.value = (left_max + right_min) * 0.5  # split halfway between the halves

        next_axis = axis_round_robin(axis, dim=db.shape[1])
        root.left = kdtree_recursive_build(root.left, db, left_indices, next_axis, leaf_size)
        root.right = kdtree_recursive_build(root.right, db, right_indices, next_axis, leaf_size)
    return root
# Walk the kd-tree depth-first, printing each leaf and tracking the maximum depth.
def traverse_kdtree(root: Node, depth, max_depth):
    """Depth-first traversal that prints every leaf (left subtree first).

    ``depth`` and ``max_depth`` are one-element lists used as mutable counters:
    ``depth[0]`` is the current depth (restored on return) and ``max_depth[0]``
    is raised to the deepest level visited.
    """
    depth[0] += 1
    max_depth[0] = max(max_depth[0], depth[0])
    if not root.is_leaf():
        for child in (root.left, root.right):
            traverse_kdtree(child, depth, max_depth)
    else:
        print(root)
    depth[0] -= 1
# Public entry point: build a kd-tree over all rows of db_np.
def kdtree_construction(db_np, leaf_size):
    """Build a kd-tree over every row of ``db_np`` and return its root node.

    The root splits on axis 0; deeper levels cycle through the axes.
    ``leaf_size`` is the maximum number of points kept in a leaf.
    """
    n_points, _ = db_np.shape[0], db_np.shape[1]
    return kdtree_recursive_build(None,
                                  db_np,
                                  np.arange(n_points),
                                  axis=0,
                                  leaf_size=leaf_size)
# k-nearest-neighbour search over the kd-tree.
def kdtree_knn_search(root: Node, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    """Accumulate the nearest neighbours of ``query`` into ``result_set``.

    Leaves are scanned exhaustively. At an internal node the child on the
    query's side of the split is searched first; the other child is searched
    only when the split plane is closer than ``result_set.worstDist()``.
    Always returns False; results live in ``result_set``.
    """
    if root is None:
        return False

    if root.is_leaf():
        # Brute-force distances from the query to every point of the leaf.
        leaf_points = db[root.point_indices, :]
        dists = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for dist, point_index in zip(dists, root.point_indices):
            result_set.add_point(dist, point_index)
        return False

    if query[root.axis] <= root.value:
        near_child, far_child = root.left, root.right
    else:
        near_child, far_child = root.right, root.left

    kdtree_knn_search(near_child, db, result_set, query)
    # The far side can only hold a closer point if the split plane lies
    # within the current worst neighbour distance.
    if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
        kdtree_knn_search(far_child, db, result_set, query)
    return False
# Radius search over the kd-tree.
def kdtree_radius_search(root: Node, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    """Offer every candidate point near ``query`` to ``result_set``.

    Leaves are scanned exhaustively. At an internal node the query's side is
    searched first; the other side is searched only when the split plane is
    closer than ``result_set.worstDist()``. Always returns False.
    """
    if root is None:
        return False

    if root.is_leaf():
        leaf_points = db[root.point_indices, :]
        dists = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for dist, point_index in zip(dists, root.point_indices):
            result_set.add_point(dist, point_index)
        return False

    goes_left = query[root.axis] <= root.value
    near_child = root.left if goes_left else root.right
    far_child = root.right if goes_left else root.left

    kdtree_radius_search(near_child, db, result_set, query)
    if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
        kdtree_radius_search(far_child, db, result_set, query)
    return False
def main():
    """Build a kd-tree over random 3-D points, print its leaves and depth, and time one k-NN query."""
    # configuration
    db_size, dim = 640000, 3
    leaf_size, k = 4, 8

    db_np = np.random.rand(db_size, dim)

    # construction
    tic = time.time()
    root = kdtree_construction(db_np, leaf_size=leaf_size)
    construction_time_sum = time.time() - tic

    depth, max_depth = [0], [0]
    traverse_kdtree(root, depth, max_depth)
    print("tree max depth: %d" % max_depth[0])

    # a single k-NN query using the first point as the query
    result_set = KNNResultSet(capacity=k)
    tic = time.time()
    kdtree_knn_search(root, db_np, result_set, db_np[0, :])
    knn_time_sum = time.time() - tic
    print("buile %sms KNN %sms" % (construction_time_sum * 1000, knn_time_sum * 1000))


if __name__ == '__main__':
    main()
kdtree_new.py
# kdtree的具体实现,包括构建和查找
import random
import math
import numpy as np
import time
from result_set import KNNResultSet, RadiusNNResultSet
# Node: the building block of the kd-tree.
class Node:
    """A kd-tree node.

    Attributes are stored verbatim from the constructor: ``axis`` is the split
    dimension, ``value`` the split coordinate (``None`` marks a leaf),
    ``left``/``right`` the children and ``point_indices`` the row indices of
    the points owned by this node.
    """

    def __init__(self, axis, value, left, right, point_indices):
        self.axis, self.value = axis, value
        self.left, self.right = left, right
        self.point_indices = point_indices

    def is_leaf(self):
        """Return True when the node carries no split value."""
        return self.value is None

    def __str__(self):
        split = 'split value: leaf, ' if self.value is None else 'split value: %.2f, ' % self.value
        return ''.join(['axis %d, ' % self.axis,
                        split,
                        'point_indices: ',
                        str(self.point_indices.tolist())])
def sort_key_by_vale(key, value):
    """Sort ``value`` ascending and apply the same permutation to ``key``.

    Both arguments must be 1-D arrays of identical shape (checked with
    ``assert``). Returns ``(key_sorted, value_sorted)``.
    """
    assert key.shape == value.shape
    assert len(key.shape) == 1
    order = np.argsort(value)
    return key[order], value[order]
def axis_round_robin(axis, dim):
    """Return the next split axis: 0, 1, ..., dim-1, then back to 0."""
    return 0 if axis == dim - 1 else axis + 1
def axis_select(leaf_point):
    """Pick the split axis with the largest variance (the "max-variance" rule).

    :param leaf_point: (M, D) array of the points about to be split
    :return: int index of the coordinate with the largest variance; ties
             resolve to the lowest index (x before y before z)
    """
    # np.argmax returns the first maximum, reproducing the x > y > z
    # tie-break, and works for any number of dimensions instead of only 3.
    return int(np.argmax(np.var(leaf_point, axis=0)))
# Build the kd-tree recursively, choosing each child's split axis by max variance.
def kdtree_recursive_build(root, db, point_indices, axis, leaf_size):
    """Recursively build a kd-tree over ``db[point_indices]``.

    Each child's split axis is the coordinate with the largest variance over
    that child's points (``axis_select``).

    Args:
        root: node to fill in, or None to create a new one.
        db: point array, indexed as ``db[row, axis]``.
        point_indices: 1-D array of the rows of ``db`` that belong to this node.
        axis: dimension this node splits on.
        leaf_size: a node with at most this many points is kept as a leaf.

    Returns:
        The node for this subtree.
    """
    if root is None:
        root = Node(axis, None, None, None, point_indices)

    # Split only when the node holds more than leaf_size points and at least
    # two, so both halves are non-empty even for leaf_size < 1.
    if len(point_indices) > max(leaf_size, 1):
        # Row indices of this node's points, ordered by their split-axis coordinate.
        point_indices_sorted, _ = sort_key_by_vale(point_indices, db[point_indices, axis])

        # The first ceil(n / 2) points go left, the remainder go right.
        middle = math.ceil(point_indices_sorted.shape[0] / 2)
        left_indices = point_indices_sorted[:middle]
        right_indices = point_indices_sorted[middle:]

        # Entries of point_indices_sorted are already rows of db, so they index
        # db directly. Offsetting them (e.g. by -1) would read an unrelated point
        # and yield a split value that lets the search prune subtrees it must visit.
        left_max = db[left_indices[-1], axis]
        right_min = db[right_indices[0], axis]
        root.value = (left_max + right_min) * 0.5  # split halfway between the halves

        root.left = kdtree_recursive_build(root.left, db, left_indices,
                                           axis_select(db[left_indices]), leaf_size)
        root.right = kdtree_recursive_build(root.right, db, right_indices,
                                            axis_select(db[right_indices]), leaf_size)
    return root
# Walk the kd-tree depth-first, printing each leaf and tracking the maximum depth.
def traverse_kdtree(root: Node, depth, max_depth):
    """Depth-first traversal that prints every leaf (left subtree first).

    ``depth`` and ``max_depth`` are one-element lists used as mutable counters:
    ``depth[0]`` is the current depth (restored on return) and ``max_depth[0]``
    is raised to the deepest level visited.
    """
    depth[0] += 1
    max_depth[0] = max(max_depth[0], depth[0])
    if not root.is_leaf():
        for child in (root.left, root.right):
            traverse_kdtree(child, depth, max_depth)
    else:
        print(root)
    depth[0] -= 1
# Public entry point: build a max-variance kd-tree over all rows of db_np.
def kdtree_construction(db_np, leaf_size):
    """Build a kd-tree over every row of ``db_np`` and return its root node.

    The root's split axis is chosen by ``axis_select`` over the whole array.
    ``leaf_size`` is the maximum number of points kept in a leaf.
    """
    n_points, _ = db_np.shape[0], db_np.shape[1]
    return kdtree_recursive_build(None,
                                  db_np,
                                  np.arange(n_points),
                                  axis=axis_select(db_np),
                                  leaf_size=leaf_size)
# k-nearest-neighbour search over the kd-tree.
def kdtree_knn_search(root: Node, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    """Accumulate the nearest neighbours of ``query`` into ``result_set``.

    Leaves are scanned exhaustively. At an internal node the child on the
    query's side of the split is searched first; the other child is searched
    only when the split plane is closer than ``result_set.worstDist()``.
    Always returns False; results live in ``result_set``.
    """
    if root is None:
        return False

    if root.is_leaf():
        # Brute-force distances from the query to every point of the leaf.
        leaf_points = db[root.point_indices, :]
        dists = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for dist, point_index in zip(dists, root.point_indices):
            result_set.add_point(dist, point_index)
        return False

    if query[root.axis] <= root.value:
        near_child, far_child = root.left, root.right
    else:
        near_child, far_child = root.right, root.left

    kdtree_knn_search(near_child, db, result_set, query)
    # The far side can only hold a closer point if the split plane lies
    # within the current worst neighbour distance.
    if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
        kdtree_knn_search(far_child, db, result_set, query)
    return False
# Radius search over the kd-tree.
def kdtree_radius_search(root: Node, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    """Offer every candidate point near ``query`` to ``result_set``.

    Leaves are scanned exhaustively. At an internal node the query's side is
    searched first; the other side is searched only when the split plane is
    closer than ``result_set.worstDist()``. Always returns False.
    """
    if root is None:
        return False

    if root.is_leaf():
        leaf_points = db[root.point_indices, :]
        dists = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for dist, point_index in zip(dists, root.point_indices):
            result_set.add_point(dist, point_index)
        return False

    goes_left = query[root.axis] <= root.value
    near_child = root.left if goes_left else root.right
    far_child = root.right if goes_left else root.left

    kdtree_radius_search(near_child, db, result_set, query)
    if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
        kdtree_radius_search(far_child, db, result_set, query)
    return False
def main():
    """Build a max-variance kd-tree over random 3-D points, print its leaves and depth, and time one k-NN query."""
    # configuration
    db_size, dim = 640000, 3
    leaf_size, k = 4, 8

    db_np = np.random.rand(db_size, dim)

    # construction
    tic = time.time()
    root = kdtree_construction(db_np, leaf_size=leaf_size)
    construction_time_sum = time.time() - tic

    depth, max_depth = [0], [0]
    traverse_kdtree(root, depth, max_depth)
    print("tree max depth: %d" % max_depth[0])

    # a single k-NN query using the first point as the query
    result_set = KNNResultSet(capacity=k)
    tic = time.time()
    kdtree_knn_search(root, db_np, result_set, db_np[0, :])
    knn_time_sum = time.time() - tic
    print("buile %sms KNN %sms" % (construction_time_sum * 1000, knn_time_sum * 1000))


if __name__ == '__main__':
    main()
result_set.py
# 该文件定义了在树中查找数据所需要的数据结构,类似一个中间件
import copy
class DistIndex:
    """A (distance, index) pair that orders by distance only."""

    def __init__(self, distance, index):
        self.distance = distance  # distance from the query point
        self.index = index  # identifier of the point (tree searches pass db row indices)

    def __lt__(self, other):
        # Ordering by distance lets lists of DistIndex sort nearest-first.
        return self.distance < other.distance


class KNNResultSet:
    """Keeps the ``capacity`` closest points offered so far.

    ``dist_index_list`` always has ``capacity`` entries sorted by ascending
    distance; slots not yet filled hold placeholders with distance ``1e10``.
    ``worst_dist`` is the distance stored in the last slot and serves as the
    pruning radius for tree searches.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.count = 0
        self.worst_dist = 1e10
        self.dist_index_list = [DistIndex(self.worst_dist, 0) for _ in range(capacity)]
        self.comparison_counter = 0

    def size(self):
        return self.count

    def full(self):
        return self.count == self.capacity

    def worstDist(self):
        return self.worst_dist

    def add_point(self, dist, index):
        """Offer a candidate; keep it if it is among the ``capacity`` closest so far."""
        self.comparison_counter += 1
        if dist > self.worst_dist:
            return

        if self.count < self.capacity:
            self.count += 1

        # Insertion-sort step: shift larger entries one slot to the right.
        # Moving references is enough because the slot we finally land in gets
        # a brand-new DistIndex, so no object ends up shared by two slots.
        # This avoids a deepcopy per shift, which dominated k-NN search time.
        i = self.count - 1
        while i > 0 and self.dist_index_list[i - 1].distance > dist:
            self.dist_index_list[i] = self.dist_index_list[i - 1]
            i -= 1
        self.dist_index_list[i] = DistIndex(dist, index)
        self.worst_dist = self.dist_index_list[self.capacity - 1].distance

    def __str__(self):
        output = ''
        for i, dist_index in enumerate(self.dist_index_list):
            output += '%d - %.2f\n' % (dist_index.index, dist_index.distance)
        output += 'In total %d comparison operations.' % self.comparison_counter
        return output


class RadiusNNResultSet:
    """Collects every offered point whose distance is within ``radius``."""

    def __init__(self, radius):
        self.radius = radius
        self.count = 0
        self.worst_dist = radius
        self.dist_index_list = []  # accepted points, in the order they were offered
        self.comparison_counter = 0

    def size(self):
        return self.count

    def worstDist(self):
        # The pruning radius is fixed for a radius search.
        return self.radius

    def add_point(self, dist, index):
        """Offer a candidate; keep it if ``dist <= radius``."""
        self.comparison_counter += 1
        if dist > self.radius:
            return
        self.count += 1
        self.dist_index_list.append(DistIndex(dist, index))

    def __str__(self):
        # Print nearest-first without reordering the stored list: converting
        # to a string should not mutate the result set.
        output = ''
        for i, dist_index in enumerate(sorted(self.dist_index_list)):
            output += '%d - %.2f\n' % (dist_index.index, dist_index.distance)
        output += 'In total %d neighbors within %f.\nThere are %d comparison operations.' \
                  % (self.count, self.radius, self.comparison_counter)
        return output