K近邻算法 (KNN)和最远点采样(FPS)实现--python+pytorch

本文介绍了如何使用Python和PyTorch实现K近邻(KNN)算法,以及最远点采样(FPS)方法。首先,通过生成点集,展示了KNN算法的详细步骤,计算每个点到某点的距离并找到最近的K个点。接着,给出了一个完整的KNN示例,包括计算距离和寻找近邻点下标。此外,还提供了远点采样的实现,通过不断选取最远的点进行采样,生成指定数量的采样点集。
摘要由CSDN通过智能技术生成

KNN算法实现--python+pytorch

K近邻算法 (KNN)

主要思路:计算每个点和某点的距离,取距离最短的K个点的下标即可。
下面是个完整示例,代码复制即可运行

import torch
import time

#生成点集
def coordinate_gen(n):
    """
    生成n个三位点
    return tensor
    dim:n*3
    """
    xyz = torch.rand(size=(n,3))
    return xyz

#计算运行时间
def time_cost(f):

    def run_time(*args,**kwargs):
        start = time.time()
        res = f(*args,**kwargs)
        run_times = time.time()-start
        print("程序执行时间:%.6f s"%(run_times))
        return res

    return run_time

@time_cost
def knn(xyz,xyzs,k=3):
    """
    xyz:key point
    xyzs:all points
    找某点的k近邻个点
    return 近邻点的下标列表
    """
    idx = [0]*k
    distance = torch.sum((xyzs[:,:3]-xyz[:,:3])**2,dim=-1)
    for i in range(k):
        idx[i] = torch.argmin(distance,dim=0)
        distance[int(torch.argmin(distance,dim=0))] = float('inf')
    idx = [int(i) for i in idx]
    return idx

if __name__ == "__main__":
    print('-' * 20, '测试开始', '-' * 20)
    N,k = map(int,input("输入生成点数 和 k的值:").split())
    xyzs = coordinate_gen(N)
    xyz = torch.rand(size=(1,3))
    print("生成的点如下:\n",xyzs,"\n随机生成key point:",xyz)
    print(knn(xyz,xyzs=xyzs,k=k))
    print('-'*20,'测试结束','-'*20)

最远点采样(FPS)

def farthest_point_sample(data,npoints):
    """
    Args:
        data:输入的tensor张量,排列顺序 N,D
        Npoints: 需要的采样点

    Returns:data->采样点集组成的tensor,每行是一个采样点
    """
    N,D = data.shape #N是点数,D是维度
    xyz = data[:,:3] #只需要坐标
    centroids = torch.zeros(size=(npoints,)) #最终的采样点index
    dictance = torch.ones(size=(N,))*1e10 #距离列表,一开始设置的足够大,保证第一轮肯定能更新dictance
    farthest = torch.randint(low=0,high=N,size=(1,)) #随机选一个采样点的index
    for i in range(npoints):
        centroids[i] = farthest
        centroid = xyz[farthest,:]
        dict = ((xyz-centroid)**2).sum(dim=-1)
        mask = dict < dictance
        dictance[mask] = dict[mask]
        farthest = torch.argmax(dictance,dim=-1)
    print(centroids.type(torch.long))
    data= data[centroids.type(torch.long)]
    return data
  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值