PointNet论文解读及代码详解

一、论文解读

        PointNet是第一个直接以三维点云作为输入的深度学习网络,点云的特点如下:

        # 无序:无论点的顺序如何打乱,点云传达的内容不变。

        # 组合表达:某个点需要结合附近的点才能表达局部信息,仅看单个点无任何意义。

        # 变换无关:将点云中的所有点同时进行平移、旋转,不改变点云传达的信息。

        PointNet在结构设计时考虑了点云的上述特点,因此在输入点云中点的顺序发生改变时,其输出一定不会改变。此外在少量点缺失或者错误时,也能保证结果的稳定性。

        PointNet 有两个网络,一是分类,二是分割,分类即输入点云,输出其表达的物体类别。分割即输入点云,将其表达的物体的不同部件分开。

1.1 分类网络

        分类网络的结构如图1所示:

图1. PointNet网络结构

        下面介绍每一步数据如何进行运算。

        从第一个n×3到第二个n×3矩阵的计算过程如图2所示,该层主要使用input transformation这个3×3的矩阵对输入点云进行三维变换,比较简单。input transformation矩阵中数值需要通过一个T-Net进行学习,这个T-Net名为STN3d(代码中作者写的名字,STN 为 Spatial Transformer Network 的简写,即空间变换网络)。

图2. PointNet第一层网络计算过程

         T-Net(STN3d)的结构如图3所示。

图3. STN3d 结构图

         Shared MLP 为1×1的卷积层,其计算过程如图4所示,图中以Shared MLP(3, m) 为例,输入通道数为3,输出通道数为m,因为输出通道为m个,所以有m个卷积核,每个卷积核对应其中一个通道,图中输入的尺寸为(1,3,n,1),括号内四个数依次表示批量、通道、行、列。因此输出的尺寸为(1,m,n,1)。图中

d_1=a_1\times x_1+b_1\times y_1+c_1\times z_1\\d_2=a_1\times x_2+b_1\times y_2+c_1\times z_2\\ ...\\ d_n=a_1\times x_n+b_1\times y_n+c_1\times z_n

图4. Shared MLP 结构图

         分类网络的第二层计算过程如图5所示,Shared MLP的计算参考图4。

图5. PointNet第二层计算过程

         分类网络第三层计算过程如图6所示,是将n×64的矩阵乘一个64×64的矩阵,该矩阵由T-Net(STNkd)生成,STNkd的结构和STN3d类似,不过多了个正则化项,使输出倾向于得到正交矩阵,加速收敛。

图6. PointNet第三层计算过程

 分类网络第四层计算过程如图7所示,和第二层类似。

图7. PointNet第四层计算过程

         第五层为最大池化层,第六层为全连接层,其计算过程如图8所示。输出结果为一个长度为k的向量,向量中的每个数代表点云所表示的物体属于某类的概率。

图8. PointNet第五、六层计算过程

 1.2 分割网络

二、代码解析

2.1 基础

(使用类的继承是为了将自定义对象作为参数传入pytorch函数)

  • 继承nn.Module类

自定义类继承nn.Module类后,只需要重写__init__forward函数即可。

其中__init__函数用于定义有哪些层,forward函数定义前向传播过程。

例如类PointNetCls继承nn.Module,并重写了__init__forward

class PointNetCls(nn.Module):

    def __init__(self, k=2, feature_transform=False):

    def forward(self, x):
则若有一个PointNetCls类的对象classfier=PointNetCls()

那么classfier(x)等价于classfier.forward(x)

  • 继承data.Dataset类

在继承data.Dataset类后,需要重写__init____getitem____len__函数。

例如类ModelData继承nn.Module,并重写了__init____getitem____len__函数。

class ShapeNetDataset(data.Dataset):

    def __init__(self):

    def __getitem__(self, index):

    def __len__(self):

则若有一个ModelData类的对象datas=ModelData()

那么datas[i]等价于datas.__getitem__(i)len(datas)=datas.__len__()

2.1 分类网络

       

2.2 分割网络

  • 3
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值