Transformerv2 code(仅为个人理解)

1、ResNetFusion

ResNetFusion 类是一个基于 ResNet 的特征融合模块,主要用于将多通道的输入特征融合成单通道输出。它继承自 BaseModule,并且使用了一些特定于 MMDetection 和 MMCV 的组件。下面将详细介绍这个类的各个部分。

1.1 类的定义和初始化

class ResNetFusion(BaseModule):
    def __init__(self, in_channels, out_channels, inter_channels, num_layer, norm_cfg=dict(type='SyncBN'),
                 with_cp=False):
        super(ResNetFusion, self).__init__()
        layers = []
        self.inter_channels = inter_channels
        for i in range(num_layer):
            if i == 0:
                if inter_channels == in_channels:
                    layers.append(BasicBlock(in_channels, inter_channels, stride=1, norm_cfg=norm_cfg))
                else:
                    downsample = nn.Sequential(
                        build_conv_layer(None, in_channels, inter_channels, 3, stride=1, padding=1, dilation=1,
                                         bias=False),
                        build_norm_layer(norm_cfg, inter_channels)[1])
                    layers.append(
                        BasicBlock(in_channels, inter_channels, stride=1, norm_cfg=norm_cfg, downsample=downsample))
            else:
                layers.append(BasicBlock(inter_channels, inter_channels, stride=1, norm_cfg=norm_cfg))
        self.layers = nn.Sequential(*layers)
        self.layer_norm = nn.Sequential(
                nn.Linear(inter_channels, out_channels),
                nn.LayerNorm(out_channels))
        self.with_cp = with_cp

1.1.1 参数说明:

  • in_channels:输入特征的通道数。
  • out_channels:输出特征的通道数。
  • inter_channels:中间层的通道数。
  • num_layer:ResNet 基本块的层数。
  • norm_cfg:归一化层的配置,默认为同步批归一化(SyncBN)。
  • with_cp:是否使用检查点来节省内存。

1.1.2 初始化过程:

  1. 定义中间通道数: 存储 inter_channels

  2. 构建 ResNet 层: 根据 num_layer 构建多个 ResNet 基本块(BasicBlock),第一个块可能包含降采样(downsample),用于将 in_channels 转换为 inter_channels

  3. 定义层归一化: 使用线性层和层归一化层将中间特征转换为输出特征。

  4. 存储检查点标志: 存储是否使用检查点(with_cp)的标志。

1.2 前向传播函数

def forward(self, x):
    x = torch.cat(x, 1).contiguous()
    # x should be [1, in_channels, bev_h, bev_w]
    for lid, layer in enumerate(self.layers):
        if self.with_cp and x.requires_grad:
            x = checkpoint.checkpoint(layer, x)
        else:
            x = layer(x)
    x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)  # nchw -> n(hw)c
    x = self.layer_norm(x)
    return x

1.2.1 前向传播过程:

  1. 特征拼接: 将输入的多个特征在通道维度上拼接起来。输入 x 应该是一个包含多个特征图的列表,通过 torch.cat 函数将它们拼接成一个张量。

  2. 遍历 ResNet 层: 通过循环遍历 self.layers 中的每一层 BasicBlock,逐层处理特征图。如果设置了 with_cp 并且输入张量需要梯度,则使用检查点来节省内存。

  3. 特征重塑: 将输出特征重塑为 [batch_size, num_channels, height * width] 的形状,并进行转置,使其形状变为 [batch_size, height * width, num_channels]

  4. 层归一化: 使用 self.layer_norm 对特征进行线性变换和层归一化。

  5. 返回结果: 返回处理后的特征。

1.2.2 使用示例

假设有一组输入特征图,使用这个 ResNetFusion 模块来融合这些特征:

# 假设输入特征图是一个包含多个张量的列表
input_features = [torch.randn(1, 64, 128, 128), torch.randn(1, 64, 128, 128)]

# 创建 ResNetFusion 实例
resnet_fusion = ResNetFusion(in_channels=128, out_channels=256, inter_channels=128, num_layer=3)

# 前向传播
output = resnet_fusion(input_features)
print(output.shape)  # 输出特征的形状

1.3 总结

ResNetFusion 类实现了一个基于 ResNet 的特征融合模块,利用多个 BasicBlock 层来处理和融合输入特征,并通过线性变换和层归一化将其转换为输出特征。该模块支持检查点功能,可以在训练过程中节省内存。通过这种设计,ResNetFusion 可以灵活地处理和融合多通道的输入特征,适用于各种计算机视觉任务中的特征提取和融合操作。

2、PerceptionTransformerBEVEncoder

PerceptionTransformerBEVEncoder 类是一个高级的特征编码器,主要用于处理多相机视角下的特征,并将它们融合到一个鸟瞰图(BEV,Bird's Eye View)中。这类编码器在自动驾驶和其他需要多视角融合的计算机视觉任务中非常常见。

2.1 类定义和初始化

@TRANSFORMER.register_module()
class PerceptionTransformerBEVEncoder(BaseModule):
    def __init__(self,
                 num_feature_levels=4,
                 num_cams=6,
                 two_stage_num_proposals=300,
                 encoder=None,
                 embed_dims=256,
                 use_cams_embeds=True,
                 rotate_center=[100, 100],
                 **kwargs):
        super(PerceptionTransformerBEVEncoder, self).__init__(**kwargs)
        self.encoder = build_transformer_layer_sequence(encoder)
        self.embed_dims = embed_dims
        self.num_feature_levels = num_feature_levels
        self.num_cams = num_cams
        self.fp16_enabled = False

        self.use_cams_embeds = use_cams_embeds

        self.two_stage_num_proposals = two_stage_num_proposals
        self.rotate_center = rotate_center
        """Initialize layers of the Detr3DTransformer."""
        self.level_embeds = nn.Parameter(torch.Tensor(self.num_feature_levels, self.embed_dims))
        if self.use_cams_embeds:
            self.cams_embeds = nn.Parameter(torch.Tensor(self.num_cams, self.embed_dims))

2.1.1 参数说明:

  • num_feature_levels:特征层级的数量,通常指不同分辨率的特征图数量。
  • num_cams:相机的数量。
  • two_stage_num_proposals:两个阶段中的提案数量,通常用于目标检测中的初始提案数量。
  • encoder:编码器配置,用于构建具体的编码器层序列。
  • embed_dims:嵌入维度,特征在编码器中的表示维度。
  • use_cams_embeds:是否使用相机嵌入信息。
  • rotate_center:旋转中心,用于图像的旋转操作。
  • kwargs:其他可能的参数。

2.1.2 初始化过程:

  1. 初始化父类: 使用 super 调用 BaseModule 的初始化方法,传递其他参数(**kwargs)。

  2. 构建编码器: 使用 build_transformer_layer_sequence 函数根据 encoder 配置构建一个 Transformer 编码器序列。

  3. 设置嵌入维度: 存储 embed_dimsnum_feature_levelsnum_cams 等参数,用于后续操作。

  4. 是否启用相机嵌入: 根据 use_cams_embeds 参数决定是否初始化相机嵌入。

  5. 初始化嵌入参数

    • level_embeds:用于不同特征层级的嵌入。
    • cams_embeds:用于不同相机的嵌入(如果 use_cams_embedsTrue)。

2.2 权重初始化

def init_weights(self):
    """Initialize the transformer weights."""
    for p in self.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    for m in self.modules():
        if isinstance(m, MSDeformableAttention3D) or isinstance(m, TemporalSelfAttention) \
                or isinstance(m, CustomMSDeformableAttention):
            try:
                m.init_weight()
            except AttributeError:
                m.init_weights()
    normal_(self.level_embeds)
    if self.use_cams_embeds:
        normal_(self.cams_embeds)

2.2.1 初始化过程:

  1. 参数初始化: 对所有参数进行初始化,如果参数的维度大于 1(例如,矩阵),则使用 Xavier 均匀分布进行初始化。

  2. 模块权重初始化: 对于特定类型的模块(MSDeformableAttention3DTemporalSelfAttentionCustomMSDeformableAttention),调用其 init_weightinit_weights 方法进行权重初始化。

  3. 嵌入初始化: 使用标准正态分布(normal_)对 level_embedscams_embeds 进行初始化。

2.3 前向传播函数

def forward(self,
            mlvl_feats,
            bev_queries,
            bev_h,
            bev_w,
            grid_length=[0.512, 0.512],
            bev_pos=None,
            prev_bev=None,
            **kwargs):
    """
    obtain bev features.
    """
    bs = mlvl_feats[0].size(0)
    bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1)
    bev_pos = bev_pos.flatten(2).permute(2, 0, 1)

    feat_flatten = []
    spatial_shapes = []
    for lvl, feat in enumerate(mlvl_feats):
        bs, num_cam, c, h, w = feat.shape
        spatial_shape = (h, w)
        feat = feat.flatten(3).permute(1, 0, 3, 2)
        if self.use_cams_embeds:
            feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
        feat = feat + self.level_embeds[None, None, lvl:lvl + 1, :].to(feat.dtype)
        spatial_shapes.append(spatial_shape)
        feat_flatten.append(feat)

    feat_flatten = torch.cat(feat_flatten, 2)
    spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=bev_pos.device)
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))

    feat_flatten = feat_flatten.permute(0, 2, 1, 3)  # (num_cam, H*W, bs, embed_dims)

    bev_embed = self.encoder(bev_queries,
                             feat_flatten,
                             feat_flatten,
                             bev_h=bev_h,
                             bev_w=bev_w,
                             bev_pos=bev_pos,
                             spatial_shapes=spatial_shapes,
                             level_start_index=level_start_index,
                             prev_bev=None,
                             shift=bev_queries.new_tensor([0, 0]).unsqueeze(0),
                             **kwargs)
    # rotate current bev to final aligned
    prev_bev = bev_embed
    if 'aug_param' in kwargs['img_metas'][0] and 'GlobalRotScaleTransImage_param' in kwargs['img_metas'][0][
        'aug_param']:
        rot_angle, scale_ratio, flip_dx, flip_dy, bda_mat, only_gt = kwargs['img_metas'][0]['aug_param'][
            'GlobalRotScaleTransImage_param']
        prev_bev = prev_bev.reshape(bs, bev_h, bev_w, -1).permute(0, 3, 1, 2)  # bchw
        if only_gt:
            # rot angle
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, bev_h - 0.5, bev_h, dtype=bev_queries.dtype, device=bev_queries.device),
                torch.linspace(0.5, bev_w - 0.5, bev_w, dtype=bev_queries.dtype, device=bev_queries.device))
            ref_y = (ref_y / bev_h)
            ref_x = (ref_x / bev_w)
            grid = torch.stack((ref_x, ref_y), -1)
            grid_shift = grid * 2.0 - 1.0
            grid_shift = grid_shift.unsqueeze(0).unsqueeze(-1)
            bda_mat = bda_mat[:2, :2].to(grid_shift).view(1, 1, 1, 2, 2).repeat(grid_shift.shape[0],
                                                                                grid_shift.shape[1],
                                                                                grid_shift.shape[2], 1, 1)
            grid_shift = torch.matmul(bda_mat, grid_shift).squeeze(-1)
            prev_bev = torch.nn.functional.grid_sample(prev_bev, grid_shift, align_corners=False)
        prev_bev = prev_bev.reshape(bs, -1, bev_h * bev_w)
        prev_bev = prev_bev.permute(0, 2, 1)
    return prev_bev

2.3.1 参数说明:

  • mlvl_feats:多层次特征图的列表。
  • bev_queries:BEV 查询特征。
  • bev_h:BEV 特征图的高度。
  • bev_w:BEV 特征图的宽度。
  • grid_length:BEV 网格的长度。
  • bev_pos:BEV 位置编码。
  • prev_bev:上一个时间步的 BEV 特征。
  • **kwargs:其他关键字参数。

2.3.2 前向传播过程:

  1. 初始化

    • 获取批次大小 bs
    • bev_queries 进行维度扩展和重复,使其与批次大小匹配。
    • bev_pos 展平并进行维度转换。
  2. 处理多层次特征图

    • 遍历 mlvl_feats 中的每个特征图,对其进行展平和转置。
    • 添加摄像机嵌入和特征层嵌入。
    • 将处理后的特征添加到 feat_flatten 列表,并记录每个特征图的空间形状。
  3. 特征展平和空间形状索引

    • 将所有处理后的特征拼接成一个大张量 feat_flatten
    • 将空间形状转换为张量,并计算每层特征的起始索引 level_start_index
  4. 调用编码器

    • 使用 self.encoderbev_queries 进行编码,生成 bev_embed
  5. 旋转和对齐 BEV 特征

    • 检查是否有图像增强参数,如果有,则对 BEV 特征进行相应的旋转和对齐。
    • 使用网格采样 torch.nn.functional.grid_sample 对特征进行调整。
  6. 返回结果

    • 返回调整后的 BEV 特征 prev_bev

2.4 forward() 详细解释

前向传播函数 forwardPerceptionTransformerBEVEncoder 类的核心部分,它将多视角的特征图(mlvl_feats)和 BEV 查询(bev_queries)处理成统一的 BEV 表示(prev_bev)。

def forward(self,
            mlvl_feats,
            bev_queries,
            bev_h,
            bev_w,
            grid_length=[0.512, 0.512],
            bev_pos=None,
            prev_bev=None,
            **kwargs):
    """
    obtain bev features.
    """

2.4.1 参数详细说明:

  • mlvl_feats:多层次的特征图列表,每个元素代表不同分辨率的特征图,维度为 [batch_size, num_cams, channels, height, width]。这些特征图来自不同的相机视角。
  • bev_queries:BEV(Bird's Eye View)查询特征,维度为 [num_queries, embed_dims]。这些查询特征是用于从特征图中提取信息的。
  • bev_hbev_w:BEV 特征图的高度和宽度。
  • grid_length:网格的尺寸,表示每个 BEV 单元格的大小,默认值为 [0.512, 0.512]
  • bev_pos:BEV 位置编码,维度为 [batch_size, embed_dims, bev_h, bev_w],用于在 BEV 特征图上添加位置信息。
  • prev_bev:前一个时间步的 BEV 特征,主要用于时间序列的 BEV 特征融合。
  • **kwargs:其他关键字参数,通常用于传递额外的信息,如图像的元数据(img_metas)等。

2.4.2 前向传播步骤

1. 初始化 BEV 查询和 BEV 位置编码
bs = mlvl_feats[0].size(0)  # 获取批次大小
bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1)  # 扩展并重复 bev_queries,使其与批次大小匹配
bev_pos = bev_pos.flatten(2).permute(2, 0, 1)  # 调整 bev_pos 的形状,使其适应后续操作
  • 获取批次大小:从 mlvl_feats 的第一个元素中提取批次大小 bs。假设 mlvl_feats[0] 的维度为 [batch_size, num_cams, channels, height, width],则 bs = mlvl_feats[0].size(0) 提取批次大小。

  • 扩展和复制 BEV 查询:将 bev_queries 的形状从 [num_queries, embed_dims] 扩展到 [num_queries, 1, embed_dims],然后在批次维度上复制,使其形状变为 [num_queries, batch_size, embed_dims]

  • 调整 BEV 位置编码:将 bev_pos 展平(将高度和宽度展平成一个维度),使其形状变为 [batch_size, embed_dims, bev_h * bev_w],然后转换维度顺序,使其形状变为 [bev_h * bev_w, batch_size, embed_dims]

2. 处理多层次的特征图
feat_flatten = []
spatial_shapes = []
for lvl, feat in enumerate(mlvl_feats):
    bs, num_cam, c, h, w = feat.shape  # 获取特征图的维度信息
    spatial_shape = (h, w)  # 记录特征图的空间形状
    feat = feat.flatten(3).permute(1, 0, 3, 2)  # 将特征图展平成 [num_cam, batch_size, h*w, channels]
    if self.use_cams_embeds:
        feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)  # 添加摄像机嵌入
    feat = feat + self.level_embeds[None, None, lvl:lvl + 1, :].to(feat.dtype)  # 添加特征层嵌入
    spatial_shapes.append(spatial_shape)  # 将空间形状添加到列表
    feat_flatten.append(feat)  # 将展平的特征图添加到列表
  • 初始化列表:创建 feat_flattenspatial_shapes 两个列表,用于存储展平后的特征图和空间形状。
  • 遍历特征层:对于每个层级的特征图,提取其维度信息并展平到第三维度 [h * w],然后转置,使其变为 [num_cam, batch_size, h * w, channels]
  • 添加摄像机嵌入:如果 use_cams_embedsTrue,将摄像机嵌入添加到展平的特征图中。
  • 添加特征层嵌入:将特征层嵌入添加到展平的特征图中。
  • 存储结果:将每个特征层的空间形状和展平的特征图分别存储到 spatial_shapesfeat_flatten 列表中。
3. 合并特征图和计算空间索引
feat_flatten = torch.cat(feat_flatten, 2)  # 将展平的特征图在第二个维度(h*w)上拼接
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=bev_pos.device)  # 转换空间形状为张量
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
  • 拼接特征图:将 feat_flatten 列表中的展平特征图在 h*w 维度上拼接,形成一个大的特征图张量。
  • 转换空间形状:将 spatial_shapes 列表转换为张量,以便在后续计算中使用。
  • 计算起始索引:计算每个特征层在拼接后的大特征图中的起始索引,用于定位特定层的特征。
4. 准备输入给编码器
feat_flatten = feat_flatten.permute(0, 2, 1, 3)  # 将展平的特征图重新排列维度,使其变为 [num_cam, H*W, batch_size, embed_dims]
5. 调用编码器进行特征融合
bev_embed = self.encoder(bev_queries,
                         feat_flatten,
                         feat_flatten,
                         bev_h=bev_h,
                         bev_w=bev_w,
                         bev_pos=bev_pos,
                         spatial_shapes=spatial_shapes,
                         level_start_index=level_start_index,
                         prev_bev=None,
                         shift=bev_queries.new_tensor([0, 0]).unsqueeze(0),
                         **kwargs)
  • 调用编码器:使用 self.encoderbev_queries 进行编码,融合 feat_flatten 中的特征,生成 BEV 嵌入 bev_embed
  • 传递参数:将处理后的 bev_queries、展平的特征图 feat_flatten、BEV 尺寸 bev_hbev_w、BEV 位置编码 bev_pos、空间形状 spatial_shapes、层级起始索引 level_start_index 等传递给编码器。
6. 对齐和旋转 BEV 特征
prev_bev = bev_embed
if 'aug_param' in kwargs['img_metas'][0] and 'GlobalRotScaleTransImage_param' in kwargs['img_metas'][0]['aug_param']:
    rot_angle, scale_ratio, flip_dx, flip_dy, bda_mat, only_gt = kwargs['img_metas'][0]['aug_param'][
        'GlobalRotScaleTransImage_param']
    prev_bev = prev_bev.reshape(bs, bev_h, bev_w, -1).permute(0, 3, 1, 2)  # bchw
    if only_gt:
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, bev_h - 0.5, bev_h, dtype=bev_queries.dtype, device=bev_queries.device),
            torch.linspace(0.5, bev_w - 0.5, bev_w, dtype=bev_queries.dtype, device=bev_queries.device))
        ref_y = (ref_y / bev_h)
        ref_x = (ref_x / bev_w)
        grid = torch.stack((ref_x, ref_y), -1)
        grid_shift = grid * 2.0 - 1.0
        grid_shift = grid_shift.unsqueeze(0).unsqueeze(-1)
        bda_mat = bda_mat[:2, :2].to(grid_shift).view(1, 1, 1, 2, 2).repeat(grid_shift.shape[0],
                                                                            grid_shift.shape[1],
                                                                            grid_shift.shape[2], 1, 1)
        grid_shift = torch.matmul(bda_mat, grid_shift).squeeze(-1)
        prev_bev = torch.nn.functional.grid_sample(prev_bev, grid_shift, align_corners=False)
    prev_bev = prev_bev.reshape(bs, -1, bev_h * bev_w)
    prev_bev = prev_bev.permute(0, 2, 1)
  1. 检查增强参数:如果存在图像增强参数 aug_param,则进行相应的处理。
  2. 提取旋转和缩放信息
    • 旋转角度 (rot_angle)缩放比例 (scale_ratio):用于调整特征图。
    • 翻转信息 (flip_dxflip_dy):用于水平和垂直翻转。
    • 变换矩阵 (bda_mat):用于几何变换。
    • 仅调整 Ground Truth (only_gt):指示是否只调整 Ground Truth。
  3. 调整特征形状:将 prev_bev 重塑为 [batch_size, embed_dims, bev_h, bev_w] 并转置为 [batch_size, embed_dims, bev_h, bev_w]
  4. 创建网格
    • 使用 torch.meshgrid 创建一个二维网格,用于定义采样点的位置。
    • 调整网格坐标范围到 [-1, 1],以适应 grid_sample 函数的要求。
  5. 应用旋转和缩放
    • 使用变换矩阵 bda_mat 对网格进行线性变换。
    • 使用 torch.nn.functional.grid_sample 根据调整后的网格对特征图进行采样,完成旋转和缩放的变换。
  6. 恢复特征形状:将 prev_bev 重塑回原始的形状 [batch_size, bev_h * bev_w, embed_dims]
7. 返回结果
return prev_bev

最后,返回经过编码和调整的 BEV 特征 prev_bev

2.4.3 总结

forward 方法通过以下几个步骤将多层次的多视角特征图处理成一个统一的 BEV 特征表示:

  1. 初始化 BEV 查询和位置:将 BEV 查询扩展到批次大小,将 BEV 位置展平并调整维度。
  2. 处理多层次特征图:展平每层的特征图,添加摄像机和层级嵌入。
  3. 合并特征图和计算空间索引:将特征图在第二个维度上拼接,计算每层特征在拼接后的大特征图中的起始索引。
  4. 准备输入给编码器:调整特征图维度以适应编码器的输入要求。
  5. 调用编码器进行特征融合:通过编码器对 bev_queries 进行编码,融合特征图。
  6. 对齐和旋转 BEV 特征:检查并应用增强参数,对 BEV 特征进行对齐和旋转。
  7. 返回最终的 BEV 特征:返回经过处理的 BEV 特征 prev_bev

这个过程使得来自多个摄像头的视角信息能够有效地整合到一个统一的 BEV 特征图中,为后续的检测和识别提供一致的特征表示。

2.5 类的作用和应用场景

PerceptionTransformerBEVEncoder 类的主要作用是在多视角特征图和 BEV 查询之间进行特征融合和编码,将摄像机视角的特征图转换成统一的 BEV 表示。这在自动驾驶和 3D 物体检测中非常有用,因为它能够整合来自多个摄像机的视角信息,为后续的检测和识别提供更统一的特征表示。

2.5.1 实际使用示例

假设有一组多视角特征图,使用 PerceptionTransformerBEVEncoder 来生成 BEV 特征:

# 假设输入特征图和 BEV 查询是下面这样的张量
mlvl_feats = [torch.randn(2, 6, 256, 50, 50), torch.randn(2, 6, 256, 25, 25)]
bev_queries = torch.randn(200, 256)
bev_h = 50
bev_w = 50
bev_pos = torch.randn(1, 256, 50, 50)

# 创建 PerceptionTransformerBEVEncoder 实例
encoder = PerceptionTransformerBEVEncoder(
    encoder=dict(
        type='TransformerEncoder',
        num_layers=6,
        transformerlayers=dict(
            type='BaseTransformerLayer',
            attn_cfgs=dict(
                type='MultiheadAttention', embed_dims=256, num_heads=8),
            ffn_cfgs=dict(
                type='FFN', embed_dims=256, feedforward_channels=1024),
            operation_order=('self_attn', 'norm', 'ffn', 'norm'))
    ),
    num_feature_levels=2,
    num_cams=6,
    embed_dims=256
)

# 前向传播
output = encoder(mlvl_feats, bev_queries, bev_h, bev_w, bev_pos=bev_pos)
print(output.shape)  # 输出 BEV 特征的形状

在这个示例中创建了一个 PerceptionTransformerBEVEncoder 实例,并使用其处理一组多视角特征图和 BEV 查询,生成 BEV 特征。

3、PerceptionTransformerV2

PerceptionTransformerV2 类继承自 PerceptionTransformerBEVEncoder,它实现了一个多帧、多视角的 BEV(鸟瞰视图)特征融合和 Transformer 编码器,用于 3D 目标检测任务。这个类扩展了基本的 BEV 编码功能,增加了对时间序列的支持和额外的解码步骤。下面是对这个类及其方法的详细解释。

3.1 类的定义和初始化

@TRANSFORMER.register_module()
class PerceptionTransformerV2(PerceptionTransformerBEVEncoder):
    """Implements the Detr3D transformer.
    Args:
        as_two_stage (bool): Generate query from encoder features.
            Default: False.
        num_feature_levels (int): Number of feature maps from FPN:
            Default: 4.
        two_stage_num_proposals (int): Number of proposals when set
            `as_two_stage` as True. Default: 300.
    """
3.1.1 参数说明
  • num_feature_levels:特征层的数量。
  • num_cams:摄像机的数量。
  • two_stage_num_proposals:两阶段检测中提案的数量。
  • encoder:编码器配置。
  • embed_dims:嵌入维度。
  • use_cams_embeds:是否使用摄像机嵌入。
  • rotate_center:旋转中心。
  • frames:用于时间序列信息的帧索引,默认为 (0,),表示当前帧。
  • decoder:解码器配置。
  • num_fusion:用于融合的层数。
  • inter_channels:中间通道数。
  • **kwargs:其他关键字参数。

3.2 初始化函数

def __init__(self,
             num_feature_levels=4,
             num_cams=6,
             two_stage_num_proposals=300,
             encoder=None,
             embed_dims=256,
             use_cams_embeds=True,
             rotate_center=[100, 100],
             frames=(0,),
             decoder=None,
             num_fusion=3,
             inter_channels=None,
             **kwargs):
    super(PerceptionTransformerV2, self).__init__(num_feature_levels, num_cams, two_stage_num_proposals, encoder,
                                                  embed_dims, use_cams_embeds, rotate_center,
                                                  **kwargs)
    self.decoder = build_transformer_layer_sequence(decoder)
    self.reference_points = nn.Linear(self.embed_dims, 3)
    self.frames = frames
    if len(self.frames) > 1:
        self.fusion = ResNetFusion(len(self.frames) * self.embed_dims, self.embed_dims,
                                   inter_channels if inter_channels is not None else len(
                                       self.frames) * self.embed_dims,
                                   num_fusion)
初始化过程
  1. 调用父类的初始化:首先调用 PerceptionTransformerBEVEncoder 的初始化方法,设置基本参数和层。
  2. 初始化解码器:使用 build_transformer_layer_sequence 方法构建解码器层序列。
  3. 初始化参考点:定义一个线性层 reference_points,用于生成参考点的初始值。
  4. 设置时间序列帧:存储帧索引 frames
  5. 初始化特征融合模块:如果 frames 的长度大于 1,创建一个 ResNetFusion 实例用于多帧特征融合。

3.3 初始化权重方法

def init_weights(self):
    """Initialize the transformer weights."""
    super().init_weights()
    for p in self.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    for m in self.modules():
        if isinstance(m, MSDeformableAttention3D) or isinstance(m, TemporalSelfAttention) \
                or isinstance(m, CustomMSDeformableAttention):
            try:
                m.init_weight()
            except AttributeError:
                m.init_weights()
    xavier_init(self.reference_points, distribution='uniform', bias=0.)
权重初始化过程
  1. 调用父类的初始化:首先调用父类的 init_weights 方法初始化基本权重。
  2. 遍历所有参数:对所有参数进行 Xavier 均匀分布初始化。
  3. 初始化特定模块:对于特定类型的模块(如 MSDeformableAttention3D),调用它们的 init_weight 方法。
  4. 初始化参考点:使用 Xavier 均匀分布初始化 reference_points

3.4 获取 BEV 特征的方法

def get_bev_features(
        self,
        mlvl_feats,
        bev_queries,
        bev_h,
        bev_w,
        grid_length=[0.512, 0.512],
        bev_pos=None,
        prev_bev=None,
        **kwargs):
    return super().forward(
        mlvl_feats,
        bev_queries,
        bev_h,
        bev_w,
        grid_length,
        bev_pos,
        prev_bev,
        **kwargs
    )
作用

这个方法只是简单地调用父类 PerceptionTransformerBEVEncoderforward 方法来获取 BEV 特征。它将参数传递给父类方法,然后返回生成的 BEV 特征。

3.5 前向传播方法

def forward(self,
            mlvl_feats,
            bev_queries,
            object_query_embed,
            bev_h,
            bev_w,
            grid_length=[0.512, 0.512],
            bev_pos=None,
            reg_branches=None,
            cls_branches=None,
            prev_bev=None,
            **kwargs):
    """
    Forward function for `Detr3DTransformer`.
    """
    bev_embed = self.get_bev_features(
        mlvl_feats,
        bev_queries,
        bev_h,
        bev_w,
        grid_length=grid_length,
        bev_pos=bev_pos,
        prev_bev=None,
        **kwargs)  # bev_embed shape: bs, bev_h*bev_w, embed_dims
3.5.1 作用
1、获取 BEV 特征

 调用 get_bev_features 方法获取 BEV 特征,bev_embed 的形状为 [batch_size, bev_h * bev_w, embed_dims]

    if len(self.frames) > 1:
        cur_ind = list(self.frames).index(0)
        assert prev_bev[cur_ind] is None and len(prev_bev) == len(self.frames)
        prev_bev[cur_ind] = bev_embed

        # fill prev frame feature 
        for i in range(1, cur_ind + 1):
            if prev_bev[cur_ind - i] is None:
                prev_bev[cur_ind - i] = prev_bev[cur_ind - i + 1].detach()

        # fill next frame feature 
        for i in range(cur_ind + 1, len(self.frames)):
            if prev_bev[i] is None:
                prev_bev[i] = prev_bev[i - 1].detach()
        bev_embed = [x.reshape(x.shape[0], bev_h, bev_w, x.shape[-1]).permute(0, 3, 1, 2).contiguous() for x in
                     prev_bev]
        bev_embed = self.fusion(bev_embed)
2、多帧特征融合
  • 当前帧索引:找到 frames 中当前帧 0 的索引。
  • 设置当前帧的 BEV 特征:将 prev_bev 中当前帧的 BEV 特征设置为 bev_embed
  • 填充前帧特征:填充前帧的 BEV 特征,如果为空,则使用后一帧的特征进行填充。
  • 填充后帧特征:填充后帧的 BEV 特征,如果为空,则使用前一帧的特征进行填充。
  • 调整 BEV 特征的形状:将 prev_bev 中每个帧的 BEV 特征调整为 [batch_size, embed_dims, bev_h, bev_w] 并转换为 [batch_size, embed_dims, bev_h, bev_w]
  • 融合多帧特征:使用 self.fusion 对多帧特征进行融合,得到最终的 BEV 特征。
    bs = mlvl_feats[0].size(0)
    query_pos, query = torch.split(
        object_query_embed, self.embed_dims, dim=1)
    query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
    query = query.unsqueeze(0).expand(bs, -1, -1)
    reference_points = self.reference_points(query_pos)
    reference_points = reference_points.sigmoid()
    init_reference_out = reference_points
3、准备解码器输入
  1. 获取批次大小: 从 mlvl_feats 的第一个特征图中获取批次大小 bs

  2. 拆分查询位置和查询向量: 使用 torch.splitobject_query_embed 拆分成两个部分,分别为查询位置 query_pos 和查询向量 query。每个部分的维度为 [num_queries, embed_dims],其中 embed_dims 是嵌入维度。

  3. 扩展查询位置和查询向量: 将 query_posquery 在第一个维度上扩展,复制到批次大小 bs。扩展后,它们的形状变为 [batch_size, num_queries, embed_dims]

  4. 生成参考点: 通过 self.reference_points(一个线性层)计算查询位置的参考点 reference_points,并使用 sigmoid 函数将其缩放到 [0, 1] 区间。参考点的形状变为 [batch_size, num_queries, 3]

  5. 初始参考点输出: 将 reference_points 存储为初始参考点输出 init_reference_out,用于后续解码器处理。

4、继续处理 BEV 特征和查询
    query = query.permute(1, 0, 2)
    query_pos = query_pos.permute(1, 0, 2)
    bev_embed = bev_embed.permute(1, 0, 2)

调整维度:

queryquery_posbev_embed 的维度从 [batch_size, num_queries, embed_dims][batch_size, bev_h * bev_w, embed_dims] 转置为 [num_queries, batch_size, embed_dims][bev_h * bev_w, batch_size, embed_dims]。这样可以适应解码器的输入要求。

5、调用解码器进行特征解码 
    inter_states, inter_references = self.decoder(
        query=query,
        key=None,
        value=bev_embed,
        query_pos=query_pos,
        reference_points=reference_points,
        reg_branches=reg_branches,
        cls_branches=cls_branches,
        spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
        level_start_index=torch.tensor([0], device=query.device),
        **kwargs)

调用解码器: 使用 self.decoderquery 进行解码。解码器将 query 作为输入,并利用 bev_embed 作为值 (value) 进行特征解码。

  • query:解码器的查询输入,形状为 [num_queries, batch_size, embed_dims]
  • key:解码器的键输入,这里未使用(为 None)。
  • value:解码器的值输入,这里是 BEV 特征 bev_embed,形状为 [bev_h * bev_w, batch_size, embed_dims]
  • query_pos:查询的位置编码。
  • reference_points:查询的参考点。
  • reg_branches:回归分支,解码器层的回归头。
  • cls_branches:分类分支,解码器层的分类头。
  • spatial_shapes:空间形状,这里是 [bev_h, bev_w]
  • level_start_index:层级起始索引,这里是 [0]
  • **kwargs:其他关键字参数。

解码器输出: 解码器返回两个值:

  • inter_states:解码器的中间状态输出,形状为 [num_dec_layers, batch_size, num_queries, embed_dims][1, batch_size, num_queries, embed_dims],取决于解码器是否返回中间状态。
  • inter_references:解码器中参考点的内部值,形状为 [num_dec_layers, batch_size, num_queries, embed_dims]
6、返回解码器的输出
    inter_references_out = inter_references

    return bev_embed, inter_states, init_reference_out, inter_references_out
  • 存储内部参考点输出: 将 inter_references 存储为 inter_references_out,用于返回。

  • 返回解码器的输出: 返回以下几个部分:

  • bev_embed:解码后的 BEV 特征,形状为 [bev_h * bev_w, batch_size, embed_dims]
  • inter_states:解码器的中间状态输出。
  • init_reference_out:初始参考点的输出。
  • inter_references_out:解码器内部参考点的输出。

3.6 总结

PerceptionTransformerV2 类继承自 PerceptionTransformerBEVEncoder,在其基础上进行了扩展,添加了时间序列融合的能力和一个解码器。前向传播过程包含以下几个步骤:

  1. 获取 BEV 特征: 使用 get_bev_features 方法从多视角特征中提取 BEV 特征。

  2. 多帧特征融合: 如果定义了多个帧索引,使用 ResNetFusion 模块将多帧的 BEV 特征融合在一起。

  3. 准备解码器输入: 拆分和扩展查询向量,并生成参考点。

  4. 调用解码器: 使用解码器对查询进行特征解码,结合 BEV 特征进行进一步处理。

  5. 返回解码器输出: 返回解码后的 BEV 特征、中间状态、初始参考点和内部参考点。

通过这些步骤,PerceptionTransformerV2 可以将来自多个摄像头和时间帧的特征融合到统一的 BEV 表示中,并进行解码,以便进行后续的目标检测或其他任务。

  • 12
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值