三维点云学习(2)五种算法比较

三维点云学习(2)五种算法比较

代码参考来自 黎老师github
本次测试包含五种算法比较:

octree

    print("octree --------------")
    #时间统计
    construction_time_sum = 0
    knn_time_sum = 0
    radius_time_sum = 0
    brute_time_sum = 0

    #construction
    begin_t = time.time()
    root = octree.octree_construction(db_np, leaf_size, min_extent)    #构建octree
    construction_time_sum += time.time() - begin_t    #统计构建时间

    #Octree KNNsearch
    begin_t = time.time()
    for i in range(len(db_np)):    #len(db_np)
        query = db_np[i,:]    #对每一个点进行KNN搜索
        result_set = KNNResultSet(capacity=k)
        octree.octree_knn_search(root, db_np, result_set, query)
    knn_time_sum += time.time() - begin_t
    print("Octree: build %.3fms, knn %.3fms" % (construction_time_sum * 1000, knn_time_sum * 1000))

spatial_kdtree

调用spatial的kdtree进行临近点查找

    #spatial.KDTree
    construction_time_sum = 0
    knn_time_sum = 0
    #construction
    begin_t = time.time()
    tree = spatial.KDTree(db_np)
    construction_time_sum += time.time() - begin_t
    #search
    begin_t = time.time()

    tree.query(db_np,k=8)
    knn_time_sum += time.time() - begin_t
    print("Kdtree_spatial: build %.3fms, knn %.3fms" % (construction_time_sum * 1000, knn_time_sum * 1000))

origin_kdtree

Origin为使用老师的git kdtree.py代码,空间建轴方式为xyz顺序建轴

    construction_time_sum = 0
    knn_time_sum = 0
    radius_time_sum = 0
    brute_time_sum = 0
    result_set = KNNResultSet(capacity=k)

    #origin KDtree  顺序建轴
    begin_t = time.time()
    root = kdtree.kdtree_construction(db_np, leaf_size)
    construction_time_sum += time.time() - begin_t

    begin_t = time.time()
    for i in range(len(db_np)):   #len(db_np)
        query = db_np[i,:]

        kdtree.kdtree_knn_search(root, db_np, result_set, query)
    knn_time_sum += time.time() - begin_t
    print("Kdtree_Origin: build %.3fms, knn %.3fms" % (construction_time_sum * 1000, knn_time_sum * 1000))

new_kdtree

new为使用方差建轴
关于方差建轴一些理论的讨论和见解如下:
kdtree划分空间维度选择使用“最大方差法”的好处
附上老师的答疑截图
在这里插入图片描述

    #new KNN   方差建轴
    construction_time_sum = 0
    knn_time_sum = 0
    radius_time_sum = 0
    brute_time_sum = 0
    result_set = KNNResultSet(capacity=k)
    begin_t = time.time()
    root = kdtree_new.kdtree_construction(db_np, leaf_size)
    construction_time_sum += time.time() - begin_t

    begin_t = time.time()
    for i in range(len(db_np)):    #len(db_np)
        query = db_np[i,:]

        kdtree_new.kdtree_knn_search(root, db_np, result_set, query)
    knn_time_sum += time.time() - begin_t
    print("Kdtree_New: build %.3fms, knn %.3fms" % (construction_time_sum * 1000, knn_time_sum * 1000))

方差建轴的具体实现函数如下

def axis_select(leaf_point):
    """Pick the split axis as the dimension with the largest variance.

    :param leaf_point: array of shape (N, D) holding the points to be split
    :return: index (plain ``int``) of the axis whose coordinates have the
             largest variance; on ties the lowest axis index wins
    """
    # Per-axis variance of the points.
    arr_var = np.var(leaf_point, axis=0)
    # argmax works for any number of dimensions, whereas comparing
    # arr_var[0..2] by hand only handled exactly three axes.
    return int(np.argmax(arr_var))

调用
在这里插入图片描述
在这里插入图片描述

Brute

暴力排序:

    print("Brute --------------")
    brute_time_sum = 0
    brute_time_sum = 0
    begin_t = time.time()
    for i in range(1000):
        query = db_np[i, :]
        diff = np.linalg.norm(np.expand_dims(query, 0) - db_np, axis=1)
        nn_idx = np.argsort(diff)
        nn_dist = diff[nn_idx]
    brute_time_sum += time.time() - begin_t
    print("1000 points for : brute %.3fms" % (brute_time_sum * 1000))

测试结果

对数据集中12W个点对前1000个点分别进行8-NN搜寻结果如下:

在这里插入图片描述

对数据集中12W个点对每点分别进行8-NN搜寻结果如下:

在这里插入图片描述

完整代码:

benchmark.py

# 对数据集中的点云,批量执行构建树和查找,包括kdtree和octree,并评测其运行时间

import random
import math
import numpy as np
import time
import os
import struct

import octree as octree
import kdtree as kdtree
import kdtree_new as kdtree_new
from result_set import KNNResultSet, RadiusNNResultSet

from scipy import spatial

def read_velodyne_bin(path):
    """Load the XYZ coordinates stored in a Velodyne ``.bin`` file.

    The file is decoded as consecutive records of four native float32 values
    (struct format ``'ffff'``); only the first three values of each record are
    kept and the fourth is discarded.

    :param path: path of the binary file to read
    :return: float32 array of shape (N, 3); shape (0, 3) for an empty file
    :raises struct.error: if the file size is not a whole number of records
    """
    record_format = 'ffff'
    record_size = struct.calcsize(record_format)

    with open(path, 'rb') as f:
        content = f.read()

    # Keep the exception type that struct.iter_unpack raised for bad sizes.
    if len(content) % record_size != 0:
        raise struct.error(
            'file size %d is not a multiple of the %d-byte record size'
            % (len(content), record_size))

    # Decode the whole buffer at once instead of appending point by point.
    records = np.frombuffer(content, dtype=np.float32).reshape(-1, 4)
    # Copy so the result is a writable, contiguous array that does not keep
    # the raw file buffer alive.
    return np.array(records[:, :3], dtype=np.float32)

def main():
    """Benchmark tree construction and k-NN search on one point cloud.

    Reads ``data/000000.bin``, then for the octree, SciPy's ``KDTree``, the
    round-robin kd-tree (``kdtree``) and the variance-axis kd-tree
    (``kdtree_new``) measures the wall-clock time to build the structure and
    to run a k-NN query for every point of the cloud. A brute-force
    sort-by-distance baseline is timed on at most the first 1000 points.
    All timings are printed in milliseconds.
    """
    # configuration
    leaf_size = 32          # at most 32 points per leaf
    min_extent = 0.0001     # minimum octant size handed to the octree builder
    k = 8                   # number of nearest neighbours per query
    brute_queries = 1000    # brute force is only timed on this many points

    # load the point cloud
    filename = 'data/000000.bin'
    db_np = read_velodyne_bin(filename)

    print("octree --------------")
    begin_t = time.time()
    root = octree.octree_construction(db_np, leaf_size, min_extent)
    construction_time_sum = time.time() - begin_t

    begin_t = time.time()
    for query in db_np:
        # A fresh result set per query so neighbours do not leak between queries.
        result_set = KNNResultSet(capacity=k)
        octree.octree_knn_search(root, db_np, result_set, query)
    knn_time_sum = time.time() - begin_t
    print("Octree: build %.3fms, knn %.3fms" % (construction_time_sum * 1000, knn_time_sum * 1000))

    print("kdtree --------------")
    # scipy.spatial.KDTree
    begin_t = time.time()
    tree = spatial.KDTree(db_np, leaf_size)
    construction_time_sum = time.time() - begin_t

    # Query every point with the configured k, so the figure is comparable
    # with the per-point loops used for the other trees.
    begin_t = time.time()
    tree.query(x=db_np, k=k)
    knn_time_sum = time.time() - begin_t
    print("Kdtree_spatial: build %.3fms, knn %.3fms" % (construction_time_sum * 1000, knn_time_sum * 1000))

    # original kd-tree: split axes chosen in round-robin order
    begin_t = time.time()
    root = kdtree.kdtree_construction(db_np, leaf_size)
    construction_time_sum = time.time() - begin_t

    begin_t = time.time()
    for query in db_np:
        result_set = KNNResultSet(capacity=k)
        kdtree.kdtree_knn_search(root, db_np, result_set, query)
    knn_time_sum = time.time() - begin_t
    print("Kdtree_Origin: build %.3fms, knn %.3fms" % (construction_time_sum * 1000, knn_time_sum * 1000))

    # new kd-tree: split axis chosen by largest variance
    begin_t = time.time()
    root = kdtree_new.kdtree_construction(db_np, leaf_size)
    construction_time_sum = time.time() - begin_t

    begin_t = time.time()
    for query in db_np:
        result_set = KNNResultSet(capacity=k)
        kdtree_new.kdtree_knn_search(root, db_np, result_set, query)
    knn_time_sum = time.time() - begin_t
    print("Kdtree_New: build %.3fms, knn %.3fms" % (construction_time_sum * 1000, knn_time_sum * 1000))

    print("Brute --------------")
    begin_t = time.time()
    # Clamp to the cloud size so a small input file does not raise IndexError.
    for i in range(min(brute_queries, len(db_np))):
        query = db_np[i, :]
        # query has shape (3,), db_np has shape (N, 3): expand to (1, 3) to broadcast.
        diff = np.linalg.norm(np.expand_dims(query, 0) - db_np, axis=1)
        # The full sort and gather are the work being timed; results are discarded.
        nn_idx = np.argsort(diff)
        nn_dist = diff[nn_idx]
    brute_time_sum = time.time() - begin_t
    print("1000 points for : brute %.3fms" % (brute_time_sum*1000))


if __name__ == '__main__':
    main()

 

kdtree.py

# kdtree的具体实现,包括构建和查找

import random
import math
import numpy as np
import time
from result_set import KNNResultSet, RadiusNNResultSet

class Node:
    """Basic building block of the kd-tree.

    A node whose ``value`` is ``None`` is treated as a leaf.
    """

    def __init__(self, axis, value, left, right, point_indices):
        self.axis = axis                    # axis index associated with this node
        self.value = value                  # split value; None marks a leaf
        self.left = left                    # left child (or None)
        self.right = right                  # right child (or None)
        self.point_indices = point_indices  # indices of the points held by this node

    def is_leaf(self):
        """Return True when no split value is set on this node."""
        return self.value is None

    def __str__(self):
        """Return a one-line description: axis, split value and point indices."""
        axis_text = 'axis %d, ' % self.axis
        if self.value is None:
            split_text = 'split value: leaf, '
        else:
            split_text = 'split value: %.2f, ' % self.value
        indices_text = str(self.point_indices.tolist())
        return ''.join([axis_text, split_text, 'point_indices: ', indices_text])

def sort_key_by_vale(key, value):
    """Sort ``value`` in ascending order and reorder ``key`` the same way.

    :param key: 1-D array of keys
    :param value: 1-D array of values, same shape as ``key``
    :return: ``(key_sorted, value_sorted)``, both permuted by the ascending
             order of ``value``
    :raises AssertionError: if the shapes differ or the arrays are not 1-D
    """
    assert key.shape == value.shape
    assert len(key.shape) == 1
    order = np.argsort(value)  # permutation that sorts value ascending
    return key[order], value[order]


def axis_round_robin(axis, dim):
    """Return the axis that follows ``axis``, wrapping to 0 after ``dim - 1``."""
    return 0 if axis == dim - 1 else axis + 1


def kdtree_recursive_build(root, db, point_indices, axis, leaf_size):
    """Recursively build the kd-tree for the points in ``point_indices``.

    :param root: node to build under; a new node is created when it is None
    :param db: point cloud, indexed as ``db[point_index, axis]``
    :param point_indices: indices into ``db`` of the points under this node
    :param axis: axis this node is split on
    :param leaf_size: nodes holding at most this many points stay leaves
    :return: the (possibly newly created) node
    """
    if root is None:
        root = Node(axis, None, None, None, point_indices)

    # Split only while the node holds more points than a leaf may contain.
    if len(point_indices) > leaf_size:
        # Order the point indices by their coordinate along the split axis.
        point_indices_sorted, _ = sort_key_by_vale(point_indices, db[point_indices, axis])

        # The left child takes the lower half, rounded up.
        left_count = math.ceil(point_indices_sorted.shape[0] / 2)
        left_max_idx = point_indices_sorted[left_count - 1]   # largest point on the left
        right_min_idx = point_indices_sorted[left_count]      # smallest point on the right

        # Use the point index itself. The previous `db[left_max_idx - 1, axis]`
        # read the coordinate of an unrelated point, which put the split value
        # in the wrong place.
        left_point_value = db[left_max_idx, axis]
        right_point_value = db[right_min_idx, axis]

        # Split halfway between the two points adjacent to the cut.
        root.value = (right_point_value + left_point_value) * 0.5

        # Both children split on the next axis in round-robin order.
        next_axis = axis_round_robin(axis, dim=db.shape[1])
        # smaller values go to the left child
        root.left = kdtree_recursive_build(root.left,
                                           db,
                                           point_indices_sorted[0:left_count],
                                           next_axis,
                                           leaf_size)
        # larger values go to the right child
        root.right = kdtree_recursive_build(root.right,
                                            db,
                                            point_indices_sorted[left_count:],
                                            next_axis,
                                            leaf_size)
    return root


def traverse_kdtree(root: Node, depth, max_depth):
    """Walk the kd-tree depth-first, printing leaves and tracking the depth.

    :param root: node to start from
    :param depth: one-element list holding the current depth; incremented on
                  entry and restored on exit
    :param max_depth: one-element list updated in place with the deepest
                      level reached
    """
    depth[0] += 1
    max_depth[0] = max(max_depth[0], depth[0])

    if root.is_leaf():
        print(root)
    else:
        for child in (root.left, root.right):
            traverse_kdtree(child, depth, max_depth)

    depth[0] -= 1

def kdtree_construction(db_np, leaf_size):
    """Build a kd-tree over all points of ``db_np``.

    Public entry point wrapping ``kdtree_recursive_build``; the root is split
    on axis 0.

    :param db_np: point array, one point per row
    :param leaf_size: scalar forwarded to ``kdtree_recursive_build``
    :return: root node of the finished tree
    """
    num_points, _num_dims = db_np.shape[0], db_np.shape[1]
    all_indices = np.arange(num_points)

    # Start from an empty root; the recursive builder creates the nodes.
    return kdtree_recursive_build(None, db_np, all_indices, axis=0, leaf_size=leaf_size)


def kdtree_knn_search(root: Node, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    """Search the kd-tree for the nearest neighbours of ``query``.

    Candidates are pushed into ``result_set`` through ``add_point``; the
    caller reads the neighbours from ``result_set`` afterwards.

    :param root: node to search from
    :param db: point array the tree was built on
    :param result_set: collector for the candidates
    :param query: coordinates of the query point
    :return: always False
    """
    if root is None:
        return False

    if root.is_leaf():
        # Leaf: brute-force compare the query with every point stored here.
        distances = np.linalg.norm(np.expand_dims(query, 0) - db[root.point_indices, :], axis=1)
        for dist, point_idx in zip(distances, root.point_indices):
            result_set.add_point(dist, point_idx)
        return False

    # Descend first into the side of the split that contains the query.
    if query[root.axis] <= root.value:
        near_child, far_child = root.left, root.right
    else:
        near_child, far_child = root.right, root.left
    kdtree_knn_search(near_child, db, result_set, query)

    # Visit the other side only if the split plane is closer than the
    # current worst distance kept by the result set.
    if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
        kdtree_knn_search(far_child, db, result_set, query)

    return False

def kdtree_radius_search(root: Node, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    """Search the kd-tree for the neighbours of ``query`` within a radius.

    Candidates are pushed into ``result_set`` through ``add_point``; the
    pruning distance comes from ``result_set.worstDist()``.

    :param root: node to search from
    :param db: point array the tree was built on
    :param result_set: collector for the candidates
    :param query: coordinates of the query point
    :return: always False
    """
    if root is None:
        return False

    if root.is_leaf():
        # Leaf: brute-force compare the query with every point stored here.
        distances = np.linalg.norm(np.expand_dims(query, 0) - db[root.point_indices, :], axis=1)
        for dist, point_idx in zip(distances, root.point_indices):
            result_set.add_point(dist, point_idx)
        return False

    # Descend first into the side of the split that contains the query.
    if query[root.axis] <= root.value:
        near_child, far_child = root.left, root.right
    else:
        near_child, far_child = root.right, root.left
    kdtree_radius_search(near_child, db, result_set, query)

    # Visit the other side only if the split plane lies within the search distance.
    if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
        kdtree_radius_search(far_child, db, result_set, query)

    return False



def main():
    """Build a kd-tree on random points, print its leaves and time one k-NN query."""
    # configuration
    db_size = 640000
    dim = 3
    leaf_size = 4
    k = 8

    db_np = np.random.rand(db_size, dim)

    # construction
    build_start = time.time()
    root = kdtree_construction(db_np, leaf_size=leaf_size)
    build_seconds = time.time() - build_start

    # Print every leaf and report the deepest level reached.
    depth, max_depth = [0], [0]
    traverse_kdtree(root, depth, max_depth)
    print("tree max depth: %d" % max_depth[0])

    # k-NN search for the first point of the data set
    result_set = KNNResultSet(capacity=k)
    search_start = time.time()
    kdtree_knn_search(root, db_np, result_set, db_np[0, :])
    search_seconds = time.time() - search_start
    print("buile  %sms KNN  %sms" %(build_seconds*1000, search_seconds*1000))


if __name__ == '__main__':
    main()

 

kdtree_new.py

# kdtree的具体实现,包括构建和查找

import random
import math
import numpy as np
import time
from result_set import KNNResultSet, RadiusNNResultSet

class Node:
    """Basic building block of the kd-tree.

    A node whose ``value`` is ``None`` is treated as a leaf.
    """

    def __init__(self, axis, value, left, right, point_indices):
        self.axis = axis                    # axis index associated with this node
        self.value = value                  # split value; None marks a leaf
        self.left = left                    # left child (or None)
        self.right = right                  # right child (or None)
        self.point_indices = point_indices  # indices of the points held by this node

    def is_leaf(self):
        """Return True when no split value is set on this node."""
        return self.value is None

    def __str__(self):
        """Return a one-line description: axis, split value and point indices."""
        axis_text = 'axis %d, ' % self.axis
        if self.value is None:
            split_text = 'split value: leaf, '
        else:
            split_text = 'split value: %.2f, ' % self.value
        indices_text = str(self.point_indices.tolist())
        return ''.join([axis_text, split_text, 'point_indices: ', indices_text])

def sort_key_by_vale(key, value):
    """Sort ``value`` in ascending order and reorder ``key`` the same way.

    :param key: 1-D array of keys
    :param value: 1-D array of values, same shape as ``key``
    :return: ``(key_sorted, value_sorted)``, both permuted by the ascending
             order of ``value``
    :raises AssertionError: if the shapes differ or the arrays are not 1-D
    """
    assert key.shape == value.shape
    assert len(key.shape) == 1
    order = np.argsort(value)  # permutation that sorts value ascending
    return key[order], value[order]


def axis_round_robin(axis, dim):
    """Return the axis that follows ``axis``, wrapping to 0 after ``dim - 1``."""
    return 0 if axis == dim - 1 else axis + 1

def axis_select(leaf_point):
    """Pick the split axis as the dimension with the largest variance.

    :param leaf_point: array of shape (N, D) holding the points to be split
    :return: index (plain ``int``) of the axis whose coordinates have the
             largest variance; on ties the lowest axis index wins
    """
    # Per-axis variance of the points.
    arr_var = np.var(leaf_point, axis=0)
    # argmax works for any number of dimensions, whereas comparing
    # arr_var[0..2] by hand only handled exactly three axes.
    return int(np.argmax(arr_var))





def kdtree_recursive_build(root, db, point_indices, axis, leaf_size):
    """Recursively build the kd-tree for the points in ``point_indices``.

    Each child is split along the axis on which its own points have the
    largest variance (see ``axis_select``).

    :param root: node to build under; a new node is created when it is None
    :param db: point cloud, indexed as ``db[point_index, axis]``
    :param point_indices: indices into ``db`` of the points under this node
    :param axis: axis this node is split on
    :param leaf_size: nodes holding at most this many points stay leaves
    :return: the (possibly newly created) node
    """
    if root is None:
        root = Node(axis, None, None, None, point_indices)

    # Split only while the node holds more points than a leaf may contain.
    if len(point_indices) > leaf_size:
        # Order the point indices by their coordinate along the split axis.
        point_indices_sorted, _ = sort_key_by_vale(point_indices, db[point_indices, axis])

        # The left child takes the lower half, rounded up.
        left_count = math.ceil(point_indices_sorted.shape[0] / 2)
        left_max_idx = point_indices_sorted[left_count - 1]   # largest point on the left
        right_min_idx = point_indices_sorted[left_count]      # smallest point on the right

        # Use the point index itself. The previous `db[left_max_idx - 1, axis]`
        # read the coordinate of an unrelated point, which put the split value
        # in the wrong place.
        left_point_value = db[left_max_idx, axis]
        right_point_value = db[right_min_idx, axis]

        # Split halfway between the two points adjacent to the cut.
        root.value = (right_point_value + left_point_value) * 0.5

        left_indices = point_indices_sorted[0:left_count]
        right_indices = point_indices_sorted[left_count:]
        # smaller values go to the left child
        root.left = kdtree_recursive_build(root.left,
                                           db,
                                           left_indices,
                                           axis_select(db[left_indices]),
                                           leaf_size)
        # larger values go to the right child
        root.right = kdtree_recursive_build(root.right,
                                            db,
                                            right_indices,
                                            axis_select(db[right_indices]),
                                            leaf_size)
    return root


def traverse_kdtree(root: Node, depth, max_depth):
    """Walk the kd-tree depth-first, printing leaves and tracking the depth.

    :param root: node to start from
    :param depth: one-element list holding the current depth; incremented on
                  entry and restored on exit
    :param max_depth: one-element list updated in place with the deepest
                      level reached
    """
    depth[0] += 1
    max_depth[0] = max(max_depth[0], depth[0])

    if root.is_leaf():
        print(root)
    else:
        for child in (root.left, root.right):
            traverse_kdtree(child, depth, max_depth)

    depth[0] -= 1

def kdtree_construction(db_np, leaf_size):
    """Build a kd-tree over all points of ``db_np``.

    Public entry point wrapping ``kdtree_recursive_build``; the root's split
    axis is chosen by ``axis_select`` over the whole data set.

    :param db_np: point array, one point per row
    :param leaf_size: scalar forwarded to ``kdtree_recursive_build``
    :return: root node of the finished tree
    """
    num_points, _num_dims = db_np.shape[0], db_np.shape[1]
    all_indices = np.arange(num_points)

    # Start from an empty root; the recursive builder creates the nodes.
    return kdtree_recursive_build(None,
                                  db_np,
                                  all_indices,
                                  axis=axis_select(db_np),
                                  leaf_size=leaf_size)


def kdtree_knn_search(root: Node, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    """Search the kd-tree for the nearest neighbours of ``query``.

    Candidates are pushed into ``result_set`` through ``add_point``; the
    caller reads the neighbours from ``result_set`` afterwards.

    :param root: node to search from
    :param db: point array the tree was built on
    :param result_set: collector for the candidates
    :param query: coordinates of the query point
    :return: always False
    """
    if root is None:
        return False

    if root.is_leaf():
        # Leaf: brute-force compare the query with every point stored here.
        distances = np.linalg.norm(np.expand_dims(query, 0) - db[root.point_indices, :], axis=1)
        for dist, point_idx in zip(distances, root.point_indices):
            result_set.add_point(dist, point_idx)
        return False

    # Descend first into the side of the split that contains the query.
    if query[root.axis] <= root.value:
        near_child, far_child = root.left, root.right
    else:
        near_child, far_child = root.right, root.left
    kdtree_knn_search(near_child, db, result_set, query)

    # Visit the other side only if the split plane is closer than the
    # current worst distance kept by the result set.
    if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
        kdtree_knn_search(far_child, db, result_set, query)

    return False

def kdtree_radius_search(root: Node, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    """Search the kd-tree for the neighbours of ``query`` within a radius.

    Candidates are pushed into ``result_set`` through ``add_point``; the
    pruning distance comes from ``result_set.worstDist()``.

    :param root: node to search from
    :param db: point array the tree was built on
    :param result_set: collector for the candidates
    :param query: coordinates of the query point
    :return: always False
    """
    if root is None:
        return False

    if root.is_leaf():
        # Leaf: brute-force compare the query with every point stored here.
        distances = np.linalg.norm(np.expand_dims(query, 0) - db[root.point_indices, :], axis=1)
        for dist, point_idx in zip(distances, root.point_indices):
            result_set.add_point(dist, point_idx)
        return False

    # Descend first into the side of the split that contains the query.
    if query[root.axis] <= root.value:
        near_child, far_child = root.left, root.right
    else:
        near_child, far_child = root.right, root.left
    kdtree_radius_search(near_child, db, result_set, query)

    # Visit the other side only if the split plane lies within the search distance.
    if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
        kdtree_radius_search(far_child, db, result_set, query)

    return False



def main():
    """Build a kd-tree on random points, print its leaves and time one k-NN query."""
    # configuration
    db_size = 640000
    dim = 3
    leaf_size = 4
    k = 8

    db_np = np.random.rand(db_size, dim)

    # construction
    build_start = time.time()
    root = kdtree_construction(db_np, leaf_size=leaf_size)
    build_seconds = time.time() - build_start

    # Print every leaf and report the deepest level reached.
    depth, max_depth = [0], [0]
    traverse_kdtree(root, depth, max_depth)
    print("tree max depth: %d" % max_depth[0])

    # k-NN search for the first point of the data set
    result_set = KNNResultSet(capacity=k)
    search_start = time.time()
    kdtree_knn_search(root, db_np, result_set, db_np[0, :])
    search_seconds = time.time() - search_start
    print("buile  %sms KNN  %sms" %(build_seconds*1000, search_seconds*1000))


if __name__ == '__main__':
    main()

result_set.py

# 该文件定义了在树中查找数据所需要的数据结构,类似一个中间件

import copy


class DistIndex:
    """A (distance, index) pair that orders by distance."""

    def __init__(self, distance, index):
        self.distance, self.index = distance, index

    def __lt__(self, other):
        """Compare by distance only, so lists of pairs sort nearest first."""
        return self.distance < other.distance


class KNNResultSet:
    """Fixed-capacity collector for the ``capacity`` nearest candidates.

    ``dist_index_list`` always holds ``capacity`` entries, sorted by ascending
    distance over the first ``count`` slots; unused slots keep the initial
    sentinel distance of 1e10. ``worst_dist`` mirrors the distance in the last
    slot and is the pruning bound exposed through ``worstDist()``.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.count = 0
        self.worst_dist = 1e10
        # Pre-fill every slot with a sentinel entry.
        self.dist_index_list = [DistIndex(self.worst_dist, 0) for _ in range(capacity)]

        self.comparison_counter = 0

    def size(self):
        """Return how many real candidates are currently stored."""
        return self.count

    def full(self):
        """Return True once ``capacity`` candidates have been stored."""
        return self.count == self.capacity

    def worstDist(self):
        """Return the distance held in the last slot."""
        return self.worst_dist

    def add_point(self, dist, index):
        """Offer a candidate; keep it only if it is not worse than the worst kept.

        :param dist: distance of the candidate to the query
        :param index: index of the candidate point
        """
        self.comparison_counter += 1
        if dist > self.worst_dist:
            return

        if self.count < self.capacity:
            self.count += 1

        # Insertion sort step: shift larger entries one slot to the right.
        # Moving the references is enough because the vacated slot is always
        # overwritten below, so no deepcopy is needed.
        entries = self.dist_index_list
        i = self.count - 1
        while i > 0 and entries[i - 1].distance > dist:
            entries[i] = entries[i - 1]
            i -= 1

        entries[i] = DistIndex(dist, index)
        self.worst_dist = entries[self.capacity - 1].distance

    def __str__(self):
        """List every slot as ``index - distance`` plus the comparison count."""
        output = ''
        for dist_index in self.dist_index_list:
            output += '%d - %.2f\n' % (dist_index.index, dist_index.distance)
        output += 'In total %d comparison operations.' % self.comparison_counter
        return output


class RadiusNNResultSet:
    """Collector for every candidate whose distance does not exceed ``radius``."""

    def __init__(self, radius):
        self.radius = radius
        self.count = 0
        self.worst_dist = radius
        self.dist_index_list = []

        self.comparison_counter = 0

    def size(self):
        """Return the number of stored neighbours."""
        return self.count

    def worstDist(self):
        """Return the search radius, which is the fixed pruning distance."""
        return self.radius

    def add_point(self, dist, index):
        """Store the candidate unless its distance is greater than the radius.

        :param dist: distance of the candidate to the query
        :param index: index of the candidate point
        """
        self.comparison_counter += 1
        if dist > self.radius:
            return

        self.count += 1
        self.dist_index_list.append(DistIndex(dist, index))

    def __str__(self):
        """Sort the stored neighbours in place and list them with a summary."""
        self.dist_index_list.sort()
        lines = ['%d - %.2f\n' % (entry.index, entry.distance)
                 for entry in self.dist_index_list]
        summary = 'In total %d neighbors within %f.\nThere are %d comparison operations.' \
                  % (self.count, self.radius, self.comparison_counter)
        return ''.join(lines) + summary

 

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值