【代码】最远点采样

代码

返回采样点索引

def farthest_point_sample1(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.min(distance, dist)
        farthest = torch.max(distance, -1)[1]
    return centroids

返回采样点坐标

def farthest_point_sample(xyz, npoint):
    """
    :param xyz: pointcloud data, [N,3]
    :param npoint: number of samples
    :return:
        result: sampled pointcloud points
    """
    device = xyz.device
    N, C = xyz.shape
    centroids = torch.zeros(npoint).to(device)
    distance = torch.ones(N).to(device) * 1e10
    farthest = torch.randint(0, N, ()).to(device)
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        centroid = centroid.to(device)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        dist = dist.to(device)
        mask = dist < distance
        mask = mask.to(device)
        dist = dist.float()
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    result = xyz[centroids.long()]
    return result
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值