FPS(最远点采样)算法代码详解

pointnet的详细代码https://github.com/yanx27/Pointnet_Pointnet2_pytorch

今天主要讲讲代码里mask的主要作用

1、   i 为第几次球心 
        每次的球心 cen(i)   
        每个距离为d(i,j)      i为第几次球心,j为所有点的第j个点
        每次cen(i)与所有点求距离得到的dist一维数组  dist(i)  

                
2、distance矩阵第一次更新会全部刷新为dist(1),因为是记录更短的距离;第二次更新,所有点与cen(2)求距离后,dist(2)dist(1)逐点比较,保留d(i,j)更短的。
(这一步可以理解为,若某点 x(j) 与 cen(1) 的 d(1,j) d(2,j) 更短,就保留 d(1,j) ;若 d(2,j ) d(1,j) 更短,就保留 d(2,j)
以此类推,distance每次刷新,都是为了保留更短距离,在更短的距离里提取最大值点作为下一次的球心cen(i)

3、注意的是,每次求距离,cen(i)都会有x(j)重合,此时重合d(i,j)就等于0,直接刷新distance在该点的位置(可以看作是该点已标注为球心)


4、实际上就是每个点选取距离他最近的球心cen(i),最终输出的distance矩阵是每个点到最近球心cen(i)的距离

def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    # 每个bacth的随机球心
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        # npoint个球心索引
        centroids[:, i] = farthest
        # 每个球心位置
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        # 勾股定理(毕达哥拉斯定理、三高定理)我称它为牛逼三高哥斯拉定理
        dist = torch.sum((xyz - centroid) ** 2, -1)
        "掩码的作用是采样距离更短的点"
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids
torch.manual_seed(3407)
xyz = torch.randn(1, 512, 3)
npoint = 128
centriods = farthest_point_sample(xyz, npoint)
# 打印输出所有秋心索引
print(centriods)
tensor([[444, 294, 243,  99, 423, 488, 250, 494, 333,  15, 286, 306, 129,  67,
         132, 332,   9, 112, 460, 271, 421,  79, 220, 121, 230,  12, 222, 464,
         316, 262, 120, 285, 298, 367, 233, 255, 317, 147, 497, 185, 409, 190,
         412, 303, 202, 217, 273, 443, 325,  42, 414, 462, 184, 216, 283, 390,
           4, 291, 245, 371, 503, 471, 145,  52, 223, 211,  51, 164, 295,  68,
         314, 287, 483, 224, 365, 388, 457, 165, 486, 476, 274,  92, 463, 205,
          10, 482,  16, 320, 411, 109,  77, 410,  41, 351, 123, 134, 207,  70,
          24, 455, 467, 375, 127, 369, 334, 304,  73, 509, 310, 116, 176, 117,
         153, 253,   3, 282, 401, 432, 236, 296, 427, 126,  93, 359, 248, 382,
         194, 292]])

这样做的目的是为了使采样数据更加均匀,防止只在初始点附近和第二次点附近疯狂采样

这是我的一组对照实验

    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    "将初始比较距离改为0"
    distance = torch.zeros(B, N).to(device)
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        "改为采样更远的距离"
        mask = dist > distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids

同样的输入,同样的随机种子,比前一种方法做对比

tensor([[444, 294,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,
          15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,
          15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,
          15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,
          15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,
          15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,
          15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,
          15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,
          15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,  15,
          15,  15]])

  • 8
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值