通过pointnet、pointnet++ 了解、熟悉pytorch小总结

1.在pointnet++里面,需要对点云数据进行sampling 和grouping,就是利用最远距离筛选三维点云的点(降采样),然后对这些采样的点,搜索每个点邻域内K个点,这就要求在计算出这些采样邻域点的K个索引后,利用这些索引获得新的数据矩阵,考虑到batch,假设batch数据为data(B*N*3),每个点云采样n个点,每一个采样点搜索邻域内k个点,生产的新的数据应用是    (B,n,k,3).。假设 已经计算出采样点的索引sample_index(B,n)已及每个采样点的k个邻域点索引 grouped_idx(B,n,k)

一开始直接用data[:,grouped_idx,:] 求,但这个是错的,因为data[:,grouped_idx,:]中第一维取的是(B,),第二维取的是grouped_idx是(B,n,k),这样会导致第一维B中的每一个都取了B*n*k个,生成的数据就是(B,B,n,k),而不是对应grouped_idx中的B,一一对应,即第一维中的B与第二维的B不能自动对应。要想得到想要的结果,需要把第一维的切片的size 与第二维的一样,切片的第一维对应的是batchsize 因此切片的索引都是batch,即要repeat,可以通过以下方法获得:

points = torch.rand(5,6,3)
print(points)
idx = np.random.randint(0,6,size=30).reshape(5,2,3)
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).view(view_shape).repeat(repeat_shape)
print(batch_indices)
new_points = points[batch_indices, idx, :]

2.算一个点与其他点的距离 等,尽量用矩阵计算 方式,比如norm为2的欧氏距离可以用以下方式计算:

def square_distance(src, dst):
    """
    Calculate Euclid distance between each two points.

    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist

我自己复现时直接调用了open3d里的kd树搜索找出邻近点,而不算完与所有点的距离。

3.最远点采样,巧妙的地方是利用蒙板实现,起位置直接用序号作为该点的索引,如下:

def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids

4.在pointnet++ 多个group 点进行mlp时,刚开始傻乎乎的用循环进行多个pointnet计算,其实多个group就可以组成2d的mlp,及kenerl为1的conv2d就行了,毕竟这些group的mlp是共享的。

5总之,矩阵的切片取数,mask的使用可以替代for循环,多用多熟悉!

6.pytorch几个高级选择函数(如gather)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值