3D感知(kdtree)

对于代码这一块:

class Node:
    def __init__(self, left, right, point_indices):
         self.left = left
         self.right = right
         self.point_indices = point_indices

节点的数据结构为 左右节点 以及分割点。

# 构造树
# 按照某一个维度轮流切,切的位置取决于某个维度如(x)的排序后的中值
def make_kdtree(points,dim,i=0):
    if len(points)>1:
        points=points[np.argsort(points[:,i]),:]
        i=(i+1)%dim
        half=len(points)//2
        return(make_kdtree(points[: half,:],dim,i),
               make_kdtree(points[half+1 :,:],dim,i),
                points[half,:])
    elif len(points)==1:
        return (None,None,points[0,:])

这里要注意建树过程中,建树中叶子节点的判定条件。我这里是最后只是一个点了,就是叶子节点。

建树完成后,可以执行最近点搜索:

def get_nearst(kd_node,point,dim,dist_func,return_distance=False,i=0,best=None):
    if kd_node:
        dist=dist_func(point,kd_node[2])
        dx=kd_node[2][i]-point[i]

        if not best:
            best=[dist,kd_node[2]]
        elif dist<best[0]:
            best[0],best[1]=dist,kd_node[2]
        i=(i+1)%dim

        get_nearst(kd_node[dx<0],point,dim,dist_func,return_distance,i,best)

        if dx**2<best[0]:
            get_nearst(kd_node[dx>=0],point,dim,dist_func,return_distance,i,best)

    return best if return_distance else best[1]

找最近的点的过程:计算该点到节点分割点的距离,以及该点到分割平面的距离(这个是判断是否搜索另外一个子空间节点的条件)

best这个数组就是存储节点到分割点距离,以及节点到分割平面的距离。

整个流程就是:首先计算当前节点的best ,搜索该节点在的子节点 ,判断该点到平面距离和点到分割点的距离,决定是否要搜索下一个子节点。再重复递归下去。

下面就是k 近邻的搜索了:

def get_knn(kd_node,point,k,dim,dist_func,return_distance=False,i=0,heap=None):
    is_root= not heap
    if is_root:
        heap=[]
    if kd_node:
        dist=dist_func(point,np.array(kd_node[2]))
        dx=kd_node[2][i]-point[i]
        if len(heap)<k:
            heapq.heappush(heap,(-dist,kd_node[2].tolist()))
        elif dist<-heap[0][0]:
            heapq.heappushpop(heap,(-dist,kd_node[2].tolist()))
        i=(i+1)%dim
        get_knn(kd_node[dx<0],point,k,dim,dist_func,return_distance,i,heap)
        if dx**2<-heap[0][0] or len(heap)<k:
            get_knn(kd_node[dx >= 0], point, k, dim, dist_func, return_distance, i, heap)

    if is_root:
        idx=np.argsort([-h[0] for h in heap])
        neighors=[(-heap[n][0],np.array(heap[n][1])) for n in idx]
        return neighors if return_distance else [n[1] for n in neighors]

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值