PointNet++改进策略 :模块改进 | OE Unit | PointSIFT,结合方向信息提升模型精度

Pasted image 20240830100720

  • 论文:PointSIFT: A SIFT-like Network Module for 3D Point Cloud Semantic Segmentation
  • 来源:ECCV 2020
  • 机构:清华大学 & 上海交通大学
  • 论文:http://arxiv.org/abs/1807.00652
  • 代码:
    • https://github.com/MVIG-SJTU/pointSIFT/
    • https://github.com/lelouedec/3DNetworksPytorch

Pasted image 20240830100712

  • 灵感来源:来着SIFT方法,该方法是提取2D图像中图像特征的方法,来生成类似图像中形状的描述
  • 方法实现:设计一个网络PointSIFT,能够提取物体中方向的信息同时能适应不同尺度的物体
  • 实验结果:在ScanNet数据集上语义分割任务中,相比PointNet++网络IOU提升2.3%Acc提升3.22%

网络整体

整体网络如下,PointSIFT对PointNet++中的MLP层进行了改进。
Pasted image 20240830101209

PointSIFT模块

通过堆叠多个 OE 单元,使得网络能够适应不同尺度的形状特征。在 PointSIFT 模块中堆叠多个 OE 单元,每个单元的感受野逐渐增大,以捕捉更大范围内的局部特征。使用shortcuts将各级 OE 单元的输出连接起来,然后通过一个点卷积(point-wise convolution)将多尺度特征融合,最终输出具有多尺度感知能力的特征表示。

Pasted image 20240830100754

Orientation-Encoding Unit

OE单元是一个方向编码单元来描述八个关键方向。

  • 其中主要两个步骤:
    • 八邻域搜索(Stacked 8-Neighborhood Search)对于每个输入点,PointSIFT 首先按照坐标轴将空间划分为八个象限,并在每个象限中找到距离最近的点作为邻居。如果某个象限内没有点,则将输入点自身复制为其最近邻。
      方向编码卷积(Orientation-Encoding Convolution, OEC):在八邻域点上进行三阶段的卷积操作(沿 X, Y, Z 轴),将这些点的特征进行融合。具体来说,卷积操作将 2×2×2 立方体中的特征依次沿各轴进行卷积,最终输出包含方向信息的特征表示。

Pasted image 20240830101158

代码实现

我在文章顶端中,有俩个PointSIFT实现版本分别是pytorch和tensorflow实现,大家根据需求迁移到自己的项目中,下面我讲以pytorch版本网络设计对代码进行注释,具体的实现细节可以参考上面的链接中

class PointSIFT(nn.Module):
    def __init__(self, nb_classes):
        super(PointSIFT, self).__init__()

        self.num_classes = nb_classes

        # 第一个 PointSIFT 残差模块,用于提取局部特征,半径为 0.1,输出通道为 64。
        self.pointsift_res_m3 = PointSIFT_res_module(radius=0.1, output_channel=64, merge='concat')

        # 第一个 PointNet 下采样模块,采样 1024 个点,半径为 0.1,32 个邻居点,输出特征维度为 128。
        self.pointnet_sa_m3 = Pointnet_SA_module(npoint=1024, radius=0.1, nsample=32, in_channel=64, mlp=[64, 128], group_all=False)

        # 第二个 PointSIFT 残差模块,半径为 0.2,输出通道为 128。
        self.pointsift_res_m4 = PointSIFT_res_module(radius=0.2, output_channel=128, extra_input_channel=128)

        # 第二个 PointNet 下采样模块,采样 256 个点,半径为 0.2,32 个邻居点,输出特征维度为 256。
        self.pointnet_sa_m4 = Pointnet_SA_module(npoint=256, radius=0.2, nsample=32, in_channel=128, mlp=[128, 256], group_all=False)

        # 第三个 PointSIFT 残差模块,第一个子模块,半径为 0.2,输出通道为 256。
        self.pointsift_res_m5_1 = PointSIFT_res_module(radius=0.2, output_channel=256, extra_input_channel=256)

        # 第三个 PointSIFT 残差模块,第二个子模块,半径为 0.2,输出通道为 512。
        self.pointsift_res_m5_2 = PointSIFT_res_module(radius=0.2, output_channel=512, extra_input_channel=256, same_dim=True)

        # 1D 卷积层,用于合并和处理 PointSIFT 残差模块的输出,输入通道 768,输出通道 512。
        self.conv1 = conv1d_bn(768, 512, 1, stride=1, activation='none')

        # 第四个 PointNet 下采样模块,采样 64 个点,半径为 0.2,32 个邻居点,输出特征维度为 512。
        self.pointnet_sa_m6 = Pointnet_SA_module(npoint=64, radius=0.2, nsample=32, in_channel=512, mlp=[512, 512], group_all=False)

        # 第一个 PointNet 上采样模块,将较低分辨率特征恢复到较高分辨率,输入和输出通道均为 512。
        self.pointnet_fp_m0 = Pointnet_fp_module([512, 512], [512, 512])

        # 一系列 PointSIFT 模块,半径为 0.5,输出通道均为 512。
        self.pointsift_m0 = PointSIFT_module(radius=0.5, output_channel=512, extra_input_channel=512)
        self.pointsift_m1 = PointSIFT_module(radius=0.5, output_channel=512, extra_input_channel=512)
        self.pointsift_m2 = PointSIFT_module(radius=0.5, output_channel=512, extra_input_channel=512)

        # 1D 卷积层,用于合并 PointSIFT 模块的输出,输入通道 512,输出通道 512。
        self.conv2 = conv1d_bn(512, 512, 1, stride=1, activation='none')

        # 第二个 PointNet 上采样模块,将较低分辨率特征恢复到较高分辨率,输入和输出通道均为 256。
        self.pointnet_fp_m1 = Pointnet_fp_module([256, 256], [256, 256])

        # 一系列 PointSIFT 模块,半径为 0.25,输出通道均为 256。
        self.pointsift_m3 = PointSIFT_module(radius=0.25, output_channel=256, extra_input_channel=256)
        self.pointsift_m4 = PointSIFT_module(radius=0.25, output_channel=256, extra_input_channel=256)

        # 1D 卷积层,用于合并 PointSIFT 模块的输出,输入通道 256,输出通道 256。
        self.conv3 = conv1d_bn(256, 256, 1, stride=1, activation='none')

        # 第三个 PointNet 上采样模块,将较低分辨率特征恢复到较高分辨率,输入和输出通道均为 128。
        self.pointnet_fp_m2 = Pointnet_fp_module([128, 128, 128], [128, 128, 128])

        # 第五个 PointSIFT 模块,半径为 0.1,输出通道为 128。
        self.pointsift_m5 = PointSIFT_module(radius=0.1, output_channel=128, extra_input_channel=128)

        ### 全连接层

        # 1D 卷积层,用于分类任务的全连接层,输入通道 128,输出通道 128。
        self.conv_fc = conv1d_bn(128, 128, 1, stride=1, activation='none')

        # Dropout 层,防止过拟合,保留率为 0.5。
        self.drop_fc = nn.Dropout(p=0.5)

        # 最终的分类层,1D 卷积层,输出通道数为 2(假设分类为 2 类)。
        self.conv2_fc = conv1d_bn(128, 2, 1, stride=1, activation='none')

    def forward(self, xyz, points=None):
        """
        Input:
            xyz: 输入的点云数据,形状为 (B * N * 3),B 为批次大小,N 为点的数量,3 为坐标维度。
            points: 附加的点特征,默认为 None。

        """
        B = xyz.size()[0]

        # 第一层 PointSIFT 残差模块 + PointNet 下采样模块
        l3_xyz, l3_points = self.pointsift_res_m3(xyz, points)
        c3_xyz, c3_points = self.pointnet_sa_m3(l3_xyz, l3_points)

        # 第二层 PointSIFT 残差模块 + PointNet 下采样模块
        l4_xyz, l4_points = self.pointsift_res_m4(c3_xyz, c3_points)
        c4_xyz, c4_points = self.pointnet_sa_m4(l4_xyz, l4_points)

        # 第三层 PointSIFT 残差模块
        l5_xyz, l5_points = self.pointsift_res_m5_1(c4_xyz, c4_points)
        l5_2_xyz, l5_2_points = self.pointsift_res_m5_2(l5_xyz, l5_points)

        # 将第三层 PointSIFT 残差模块的两个输出拼接在一起
        l2_cat_points = torch.cat([l5_points, l5_2_points], dim=2)

        # 通过 1D 卷积层处理拼接后的特征
        fc_l2_points = self.conv1(l2_cat_points.permute(0, 2, 1)).permute(0, 2, 1)

        # 第四层 PointNet 下采样模块
        l3b_xyz, l3b_points = self.pointnet_sa_m6(l5_2_xyz, fc_l2_points)

        # 第一层 PointNet 上采样模块
        l2_points = self.pointnet_fp_m0(c4_xyz, l3b_xyz, c4_points, l3_points).permute(0, 2, 1)


        _, l2_points_1 = self.pointsift_m0(c4_xyz, l2_points)
        _, l2_points_2 = self.pointsift_m1(c4_xyz, l2_points)
        _, l2_points_3 = self.pointsift_m2(c4_xyz, l2_points)

        # 将 PointSIFT 模块的输出拼接在一起并通过 1D 卷积层处理
        l2_points = torch.cat([l2_points_1, l2_points_2, l2_points_3], dim=-1)
        l2_points = self.conv2(l2_points)

        # 第二层 PointNet 上采样模块
        l1_points = self.pointnet_fp_m1(c3_xyz, c4_xyz, c3_points, l2_points).permute(0, 2, 1)


        _, l1_points_1 = self.points

        _, l1_points_1 = self.pointsift_m3(c3_xyz, l1_points)
        _, l1_points_2 = self.pointsift_m4(c3_xyz, l1_points)

        # 将 PointSIFT 模块的输出拼接在一起
        l1_points = torch.cat([l1_points_1, l1_points_2], dim=-1)

        # 通过 1D 卷积层处理拼接后的特征
        l0_points = self.conv3(l1_points)

        # 第三层 PointNet 上采样模块
        l0_points = self.pointnet_fp_m2(l3_xyz, c3_xyz, l3_points, l0_points).permute(0, 2, 1)

        # 最后一层 PointSIFT 模块
        _, l0_points_1 = self.pointsift_m5(l3_xyz, l0_points)

        # 全连接层用于分类
        net = self.conv_fc(l0_points_1)
        net = self.drop_fc(net)
        net = self.conv2_fc(net)

        return net

    @staticmethod
    def get_loss(input, target):
        """
        损失函数,计算输入与目标之间的交叉熵损失。
        """
        classify_loss = nn.CrossEntropyLoss()
        loss = classify_loss(input, target)
        return loss

    def initialize_weights(self):
        """
        权重初始化方法,对模型中的卷积层、批量归一化层和全连接层进行初始化。
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


如何改进PointNet++

这篇论文中贡献了两个模块分别是PointSIFT和Orientation-Encoding Unit,这两个模块能够直接替换PointNet++中MLP层,或者跟其他模块在进行组合形成新的模块进行创新模块。

  • 8
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值