PointNet++改进策略 :模块改进 | Transformer | Point Transformerv3, 使用Transformer架构引入到3D点云任务中提升模型精度

Pasted image 20240914081958

  • 论文题目:Point Transformer V3: Simpler, Faster, Stronger
  • 发布期刊:CVPR
  • 通讯地址:1香港大学 2上海人工智能实验室 3香港中文大学(深圳) 4北京大学 5麻省理工学院
  • 代码地址: https://github.com/Pointcept/PointTransformerV3

介绍

Pasted image 20240914082444
这篇论文的主要内容是介绍Point Transformer V3 (PTv3),一个用于3D点云处理的模型,目标是解决精度与效率之间的权衡问题。PTv3在设计上强调简化和提升效率,优先考虑模型的扩展能力,而不是通过复杂的机制来提升局部的精度。论文的贡献主要体现在以下几个方面:

  1. 更强的性能:PTv3在多个室内和室外的3D感知任务上取得了最新的最优性能。
  2. 更广的感受野:通过简化设计,PTv3将感受野从16个点扩展到1024个点,同时保持了高效性。
  3. 更快的速度:PTv3在处理速度上有显著提升,比前代模型快3倍,并且在推理过程中减少了内存消耗。
  4. 低内存消耗:与PTv2相比,PTv3的内存消耗减少了10倍,使其更适合资源受限的环境。

Pasted image 20240914082502
Pasted image 20240914082508

核心思想

Point Transformer V3 (PTv3) 的核心思想是通过简化设计和提升模型的扩展能力,解决3D点云处理中精度与效率之间的权衡问题。它提出模型性能更大程度上取决于模型的规模扩展(scaling),而非复杂的设计细节。具体来说,PTv3 的核心思想可以归纳为以下几个方面:
Pasted image 20240914082711

优先考虑效率与扩展性

PTv3 优先简化模型中的复杂机制,以实现更好的扩展性。论文提出,与其在局部设计上过度追求精度,不如通过减少计算复杂度和内存消耗来扩大感受野和模型的规模。这种方法允许模型能够处理更大的点云数据,达到更好的全局性能。

舍弃复杂机制,采用高效的替代方案

PTv3 去掉了在前代模型中占用大量计算时间和内存的复杂机制,比如:

  • KNN 近邻搜索:占用大量计算时间,PTv3 使用了序列化邻域映射替代了精确的近邻搜索。
  • 相对位置编码:这在以前的模型中占用了大量计算资源,PTv3 则采用了更简单的稀疏卷积预处理层来替代。
    Pasted image 20240914083341
    图 补丁分组。 (a) 根据从特定序列化模式导出的顺序对点云进行重新排序。 (b) 通过借用相邻补丁的点来填充点云序列,以确保它可以被指定的补丁大小整除。
    Pasted image 20240914083350
    图 补丁交互。 (a) 标准斑块分组,具有规则的、非移位的排列; (b) 平移扩张,其中点按规则间隔分组,产生扩张效果; © Shift Patch,采用类似于移位窗口方法的移位机制; (d) Shift Order,其中不同的序列化模式被循环分配给连续的注意力层; (d) 洗牌顺序,序列化模式的序列在输入到注意层之前被随机化。

序列化点云处理

PTv3 引入了点云序列化的概念,将原本无序的3D点云数据转换为有序的结构。通过利用空间填充曲线(如 Z-order 和 Hilbert 曲线),它将点云数据转换为一维序列,进而优化处理过程。这种方法有效地保留了空间邻近性,同时大幅提高了处理效率。
Pasted image 20240914082521
图 补丁交互。 (a) 标准斑块分组,具有规则的、非移位的排列; (b) 平移扩张,其中点按规则间隔分组,产生扩张效果; © Shift Patch,采用类似于移位窗口方法的移位机制; (d) Shift Order,其中不同的序列化模式被循环分配给连续的注意力层; (d) 洗牌顺序,序列化模式的序列在输入到注意层之前被随机化。

扩大感受野

PTv3 强调模型在大规模数据下的表现,成功将感受野从16个点扩展到1024个点,这极大地提升了模型的全局感知能力。而通过简化设计,它不仅能够扩大感知范围,还能保持较高的效率。

性能与效率并重

PTv3 在保持高效性的同时,通过扩展感受野和采用更高效的注意力机制,取得了在多个3D感知任务上的最佳性能。相比前代模型,PTv3 的推理速度提高了约3倍,内存消耗减少了10倍,这使得它能够在实际应用中实现更快的处理和更广泛的部署。

核心代码实现讲解

Point Transformer V3 的核心思想是通过简化3D点云处理中的复杂机制,使用序列化的方式来增强模型的效率和扩展性。这一设计基于点云数据的序列化处理,同时采用稀疏卷积注意力机制来有效提升模型的处理速度、内存效率以及感知能力。以下是核心思想结合代码的讲解:

序列化点云处理

Point Transformer V3 通过将点云数据序列化(serialization)转变成有序的结构来处理点云。传统的点云数据无序,无法直接应用卷积等操作,而通过将点云数据按照特定的空间曲线(如Z-order或Hilbert曲线)排序,可以将其转化为有序结构,提升处理效率。

代码实现

  • Point 类中的 serialization 方法中,通过计算点的 grid_coord(网格坐标),并基于输入参数 order(如 z-order 或 Hilbert)生成点云的 serialized_codeserialized_orderserialized_inverse,实现点云的序列化处理。
def serialization(self, order="z", depth=None, shuffle_orders=False):
    # 根据点云的坐标和批次信息进行序列化编码
    code = [
        encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order
    ]
    # 序列化顺序和反序列化
    self["serialized_code"] = code
    self["serialized_order"] = order
    self["serialized_inverse"] = inverse

稀疏卷积和特征稀疏化

由于点云数据的稀疏性,Point Transformer V3 使用了稀疏卷积SparseConv),该方法允许仅对有效的点进行卷积运算,而不是对整个三维空间进行操作。这样可以大幅降低计算复杂度,提高内存效率。

代码实现

  • Point 类中的 sparsify 方法负责将点云特征和稀疏卷积张量(SparseConvTensor)关联起来。稀疏卷积操作通过生成 sparse_shapesparse_conv_feat 完成。
def sparsify(self, pad=96):
    sparse_conv_feat = spconv.SparseConvTensor(
        features=self.feat,
        indices=torch.cat([self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1).contiguous(),
        spatial_shape=sparse_shape,
        batch_size=self.batch[-1].tolist() + 1,
    )
    self["sparse_conv_feat"] = sparse_conv_feat

自注意力机制

Point Transformer V3 使用了序列化注意力机制SerializedAttention)来处理点云中的局部和全局关系。该注意力机制通过将点云数据划分为多个片段(patch),对每个片段执行注意力操作,同时利用序列化的顺序来实现快速的局部关系建模。

代码实现

  • SerializedAttention 类中,模型根据输入点云的特征计算 qkv 三个矩阵(query、key、value),并通过点的相对位置关系来进行加权。
  • 如果启用了 flash_attn,则可以进一步加速注意力操作。
def forward(self, point):
    qkv = self.qkv(point.feat)[order]
    if not self.enable_flash:
        q, k, v = qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
        attn = self.softmax(attn)
        feat = (attn @ v).transpose(1, 2).reshape(-1, C)
    else:
        feat = flash_attn.flash_attn_varlen_qkvpacked_func(...)
    return point

模块化与层次化的设计

Point Transformer V3 通过模块化的设计(如PointModulePointSequential),使模型可以灵活地构建多层次的编码器和解码器结构,逐步提取不同尺度下的点云特征。这一设计让模型能够在复杂的三维场景中进行多级处理,并且可以很好地适应不同任务。

代码实现

  • 通过PointSequential类实现模块的序列化执行,并支持自动加载不同的模块(如稀疏卷积、注意力机制等)。
  • PointTransformerV3 类中,定义了编码器和解码器的结构,编码器主要用于提取高层次特征,解码器则将这些特征逐步还原到原始空间中。
class PointTransformerV3(PointModule):
    def __init__(self, in_channels=6, ...):
        self.enc = PointSequential()
        for s in range(self.num_stages):
            enc = PointSequential()
            if s > 0:
                enc.add(SerializedPooling(...))
            for i in range(enc_depths[s]):
                enc.add(Block(...))
            self.enc.add(module=enc, name=f"enc{s}")

如何改进PointNet++

Point Transformer V3 (PTv3) 的核心思想应用到 PointNet++ 中,可以帮助改进后者的局部特征学习、全局建模能力以及处理效率。以下是几种改进方向以及实现步骤,详细讲述如何将 PTv3 的技术融入到 PointNet++ 中。

改进局部特征学习:引入自注意力机制

PointNet++ 在特征提取时,使用了局部的特征聚合方法,比如基于点的邻域进行特征学习。你可以通过引入 PTv3 的自注意力机制,替换 PointNet++ 中的简单特征聚合操作,从而改进局部特征的学习能力。

实现步骤:
  1. 在每个局部区域中,使用 PTv3 的 Serialized Attention 替换 PointNet++ 中的 MLP 或简单卷积操作。
  2. 序列化每个局部区域中的点(如基于 Hilbert 曲线),然后在这些有序点中计算 Query、Key、Value,执行局部的注意力机制,以捕捉点之间的关系。
  3. 对每个区域生成的特征进行注意力加权,并使用投影层重新组合特征。
示例代码片段:
# 局部区域内特征提取替换为自注意力
class LocalAttentionFeatureExtractor(nn.Module):
    def __init__(self, in_channels, out_channels, patch_size, num_heads):
        super().__init__()
        self.attn = SerializedAttention(
            channels=in_channels,
            num_heads=num_heads,
            patch_size=patch_size,
            qkv_bias=True
        )

    def forward(self, point_cloud):
        point_cloud = self.attn(point_cloud)
        return point_cloud

解释

  • 这个 LocalAttentionFeatureExtractor 类使用 PTv3 的 SerializedAttention 模块来提取局部区域内的特征,相比 PointNet++ 原有的 MLP 操作,能够更好地捕捉点的局部关系。

扩展感受野:扩大感受野和全局特征聚合

PointNet++ 的局部特征提取机制(通过球形区域搜索或 KNN 选择邻域)在全局建模时可能不够充分。通过引入 PTv3 的 全局自注意力机制,可以实现更大的感受野,让模型更好地捕捉全局信息。

实现步骤:
  1. 在多层次特征提取的过程中,使用 PTv3 的 Multi-scale Attention,结合 PointNet++ 的多层级结构,将不同尺度的局部区域特征合并,并对整个点云进行全局自注意力计算。
  2. 通过扩大感受野,让模型能够学习到全局的上下文信息,而不是仅限于局部的邻域。
示例代码片段:
class GlobalFeatureExtractor(nn.Module):
    def __init__(self, in_channels, num_heads, patch_size):
        super().__init__()
        self.global_attention = SerializedAttention(
            channels=in_channels,
            num_heads=num_heads,
            patch_size=patch_size,
            enable_flash=True
        )

    def forward(self, point_cloud):
        point_cloud = self.global_attention(point_cloud)
        return point_cloud

解释

  • 通过 GlobalFeatureExtractor 类,你可以将全局自注意力机制应用于整个点云,扩展模型的感受野,并且通过 Flash Attention 提升计算效率。

提升效率:使用稀疏卷积(Sparse Convolution)和序列化处理

PointNet++ 在点云处理时,通常会用全局的 MLP 或基于 KNN 的操作,这些操作可能在大规模点云上效率不高。通过引入 PTv3 的 稀疏卷积(Sparse Convolution)序列化处理,可以显著提升计算效率,减少内存消耗。

实现步骤:
  1. 在 PointNet++ 的每一层采样阶段,使用 PTv3 的 sparsify 方法将点云转化为稀疏卷积张量,减少冗余计算。
  2. 使用 PTv3 的序列化方法,将点云数据通过空间填充曲线进行排序,并应用稀疏卷积和注意力机制以优化计算。
示例代码片段:
class SparseFeatureExtractor(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.sparse_conv = spconv.SparseConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1
        )

    def forward(self, point_cloud):
        point_cloud.sparsify()
        point_cloud = self.sparse_conv(point_cloud.sparse_conv_feat)
        return point_cloud

解释

  • 使用 SparseFeatureExtractor 类将点云数据转换为稀疏卷积形式,并进行高效的卷积操作。相比 PointNet++ 的全连接操作,稀疏卷积在大规模点云上具有更好的效率。

结合多尺度特征:融合局部和全局特征

通过结合 局部自注意力全局自注意力,可以让模型同时学习到局部特征的细节和全局的上下文信息,从而提升对复杂几何结构的感知能力。

实现步骤:
  1. 在 PointNet++ 的多层特征提取模块中,使用 PTv3 的局部注意力模块提取邻域内特征,再使用全局自注意力模块获取全局信息。
  2. 将多尺度的局部特征和全局特征融合起来,并在模型的后续层进行进一步的处理。
示例代码片段:
class MultiScaleFeatureFusion(nn.Module):
    def __init__(self, local_in_channels, global_in_channels, out_channels):
        super().__init__()
        self.local_attention = LocalAttentionFeatureExtractor(local_in_channels, out_channels, patch_size=128, num_heads=4)
        self.global_attention = GlobalFeatureExtractor(global_in_channels, num_heads=8, patch_size=1024)

    def forward(self, local_point_cloud, global_point_cloud):
        local_features = self.local_attention(local_point_cloud)
        global_features = self.global_attention(global_point_cloud)
        fused_features = torch.cat([local_features.feat, global_features.feat], dim=1)
        return fused_features

解释

  • MultiScaleFeatureFusion 类实现了局部特征和全局特征的融合。通过将局部特征和全局特征拼接后,模型能够更好地理解点云的全貌以及细节信息。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值