TemporalSelfAttention——时空注意力

TemporalSelfAttention 类实现了一个用于 BEVFormer 的时间自注意力模块。这种注意力机制基于 Deformable-Detr 的思路,旨在处理多尺度和多视角的特征,同时结合时间序列信息来增强特征表达。在 BEVFormer 中,时间自注意力模块是关键组件,它能够将当前帧和历史帧的特征进行融合和注意力计算,从而提升模型对动态场景的理解能力。

类的定义和初始化

@ATTENTION.register_module()
class TemporalSelfAttention(BaseModule):
    """An attention module used in BEVFormer based on Deformable-Detr.

    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_heads (int): Parallel attention heads. Default: 64.
        num_levels (int): The number of feature map used in
            Attention. Default: 4.
        num_points (int): The number of sampling points for
            each query in each head. Default: 4.
        im2col_step (int): The step used in image_to_column.
            Default: 64.
        dropout (float): A Dropout layer on `inp_identity`.
            Default: 0.1.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to True.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        num_bev_queue (int): Number of BEV frames in the queue. Default is 2.
    """

    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4,
                 num_bev_queue=2,
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=True,
                 norm_cfg=None,
                 init_cfg=None):
参数说明
  • embed_dims:嵌入维度,默认为 256。
  • num_heads:多头注意力机制中的头数量,默认为 8。
  • num_levels:特征层的数量,默认为 4。
  • num_points:每个查询在每个头中采样的点的数量,默认为 4。
  • num_bev_queue:BEV 队列的长度,表示当前帧和历史帧的数量,默认为 2。
  • im2col_stepim2col 操作的步长,默认为 64。
  • dropout:用于输入的 dropout 比率,默认为 0.1。
  • batch_first:表示输入张量的形状是否为 (batch, n, embed_dim),默认为 True。
  • norm_cfg:归一化层的配置,默认为 None。
  • init_cfg:初始化配置,默认为 None。

初始化过程

  1. 父类初始化:调用 BaseModule 的初始化方法。
  2. 维度检查:确保 embed_dims 能够被 num_heads 整除,否则抛出异常。
  3. 警告非二次幂的维度:如果每个注意力头的维度不是 2 的幂次,会发出警告,因为 CUDA 实现中处理二次幂的维度更有效。
  4. 初始化线性层
    • sampling_offsets:用于计算采样偏移的线性层,输入维度为 embed_dims * num_bev_queue,输出维度为 num_bev_queue * num_heads * num_levels * num_points * 2
    • attention_weights:用于计算注意力权重的线性层,输入维度为 embed_dims * num_bev_queue,输出维度为 num_bev_queue * num_heads * num_levels * num_points
    • value_proj:对值进行投影的线性层。
    • output_proj:对输出进行投影的线性层。
  5. 权重初始化:调用 init_weights 方法进行权重初始化。

权重初始化

    def init_weights(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.)
        thetas = torch.arange(
            self.num_heads,
            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
            self.num_heads, 1, 1,
            2).repeat(1, self.num_levels * self.num_bev_queue, self.num_points, 1)

        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1

        self.sampling_offsets.bias.data = grid_init.view(-1)
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.value_proj, distribution='uniform', bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
        self._is_init = True
  1. sampling_offsets 初始化

    • sampling_offsets 的权重初始化为常数 0。
    • 根据头的数量计算角度数组 thetas,范围为 [0, 2π)
    • 将角度转换为网格的初始化值 grid_init,用于设置采样偏移的初始偏移量。
    • 对每个采样点的偏移进行标度调整。
    • grid_init 赋值给 sampling_offsets 的偏置。
  2. attention_weights 初始化

    • attention_weights 的权重和偏置初始化为 0。
  3. value_projoutput_proj 初始化

    • 使用 Xavier 均匀分布对 value_projoutput_proj 的权重进行初始化,偏置初始化为 0。

前向传播

    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                flag='decoder',
                **kwargs):
        """Forward Function of MultiScaleDeformAttention.

        Args:
            query (Tensor): Query of Transformer with shape
                (num_query, bs, embed_dims).
            key (Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_key, bs, embed_dims)`.
            identity (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None,
                `query` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, num_levels, 2),
                all elements is range in [0, 1], top-left (0,0),
                bottom-right (1, 1), including padding area.
                or (N, Length_{query}, num_levels, 4), add
                additional two dimensions is (w, h) to
                form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shape of features in
                different levels. With shape (num_levels, 2),
                last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape ``(num_levels, )`` and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].

        Returns:
             Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
参数说明
  • query:查询张量,形状为 [num_query, batch_size, embed_dims]
  • key:键张量,形状为 [num_key, batch_size, embed_dims],默认未使用。
  • value:值张量,形状为 [num_key, batch_size, embed_dims],默认未使用。
  • identity:用于加法的张量,形状与 query 相同。默认情况下,使用 query
  • query_pos:查询的位置编码,默认未提供。
  • reference_points:参考点,形状为 [batch_size, num_query, num_levels, 2][N, Length_{query}, num_levels, 4]
  • key_padding_mask:查询的掩码,形状为 [batch_size, num_key]
  • spatial_shapes:特征在不同层级的空间形状,形状为 [num_levels, 2]
  • level_start_index:每个层级的起始索引,形状为 [num_levels]
  • flag:标志,默认值为 decoder
  • **kwargs:其他关键字参数。

前向传播逻辑

1. 值处理
if value is None:
    assert self.batch_first
    bs, len_bev, c = query.shape
    value = torch.stack([query, query], 1).reshape(bs*2, len_bev, c)

  • 检查 value: 如果 value 未提供(通常在某些特定情况下),则通过将 query 堆叠两次来生成 value。这里假设 batch_first 为 True,即批次是第一维度。
  • 重塑 value: 将堆叠后的 value 重塑为 [bs*2, len_bev, c],表示包含两个批次的 value,每个批次都是 query
2. 初始化 identity
if identity is None:
    identity = query

  • 默认 identity: 如果 identity 未提供,则默认使用 query 作为 identityidentity 用于残差连接中保持输入信息。
3. 加上位置编码
if query_pos is not None:
    query = query + query_pos

  • 位置编码: 如果提供了 query_pos,则将其加到 query 上,用于在查询中加入位置信息,使得模型能够理解输入特征的位置关系。
4. 调整批次维度
if not self.batch_first:
    query = query.permute(1, 0, 2)
    value = value.permute(1, 0, 2)

  • 维度转换: 如果 batch_first 为 False,则将 queryvalue 的维度从 [num_query, batch_size, embed_dims] 转置为 [batch_size, num_query, embed_dims],以适应后续处理要求。
5. 检查尺寸和队列
bs,  num_query, embed_dims = query.shape
_, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
assert self.num_bev_queue == 2
  • 检查尺寸: 获取 query 的批次大小、查询数量和嵌入维度,并确保 spatial_shapes 中的所有元素乘积之和等于 num_value
  • 检查 BEV 队列: 确保 BEV 队列的长度为 2,这是实现多帧特征融合的基础。
6. 处理 queryvalue
query = torch.cat([value[:bs], query], -1)
value = self.value_proj(value)

  • 拼接 queryvalue: 将 value 的前 bs 个元素(即当前帧和历史帧)与 query 在最后一个维度上拼接,形成一个新的查询向量。
  • 投影 value: 通过 value_proj 线性层对 value 进行投影,将其映射到适合注意力机制的空间。
7. 掩码处理
if key_padding_mask is not None:
    value = value.masked_fill(key_padding_mask[..., None], 0.0)

  • 应用掩码: 如果提供了 key_padding_mask,则使用掩码将 value 中相应位置的值填充为 0,以忽略这些位置的影响。
8. 重塑 value
value = value.reshape(bs*self.num_bev_queue,
                      num_value, self.num_heads, -1)

  • 重塑 value: 将 value 重塑为 [bs*num_bev_queue, num_value, num_heads, dim_per_head],将维度分解为多头注意力机制的格式。
9. 计算采样偏移和注意力权重
sampling_offsets = self.sampling_offsets(query)
sampling_offsets = sampling_offsets.view(
    bs, num_query, self.num_heads,  self.num_bev_queue, self.num_levels, self.num_points, 2)
attention_weights = self.attention_weights(query).view(
    bs, num_query,  self.num_heads, self.num_bev_queue, self.num_levels * self.num_points)
attention_weights = attention_weights.softmax(-1)

  • 采样偏移: 通过 sampling_offsets 线性层计算每个查询的采样偏移,并将其重塑为 [bs, num_query, num_heads, num_bev_queue, num_levels, num_points, 2]
  • 注意力权重: 通过 attention_weights 线性层计算每个查询在每个采样点的注意力权重,并将其重塑为 [bs, num_query, num_heads, num_bev_queue, num_levels * num_points]
    • 应用 Softmax:对最后一个维度应用 softmax,使得权重在 num_levels * num_points 维度上和为 1。
10. 调整采样偏移和注意力权重的维度
attention_weights = attention_weights.view(bs, num_query,
                                           self.num_heads,
                                           self.num_bev_queue,
                                           self.num_levels,
                                           self.num_points)

attention_weights = attention_weights.permute(0, 3, 1, 2, 4, 5)\
    .reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points).contiguous()
sampling_offsets = sampling_offsets.permute(0, 3, 1, 2, 4, 5, 6)\
    .reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points, 2)

  • 调整 attention_weights 维度: 将 attention_weights 重塑并调整为 [bs*num_bev_queue, num_query, num_heads, num_levels, num_points] 的形状,方便后续操作。
  • 调整 sampling_offsets 维度: 将 sampling_offsets 重塑并调整为 [bs*num_bev_queue, num_query, num_heads, num_levels, num_points, 2] 的形状,方便后续操作。
11. 计算采样位置
if reference_points.shape[-1] == 2:
    offset_normalizer = torch.stack(
        [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
    sampling_locations = reference_points[:, :, None, :, None, :] \
        + sampling_offsets \
        / offset_normalizer[None, None, None, :, None, :]

elif reference_points.shape[-1] == 4:
    sampling_locations = reference_points[:, :, None, :, None, :2] \
        + sampling_offsets / self.num_points \
        * reference_points[:, :, None, :, None, 2:] \
        * 0.5
else:
    raise ValueError(
        f'Last dim of reference_points must be'
        f' 2 or 4, but get {reference_points.shape[-1]} instead.')

  • 根据参考点计算采样位置
    • 2D 参考点:如果 reference_points 的最后一个维度为 2,则计算偏移并使用 spatial_shapes 进行标准化。
    • 4D 参考点:如果 reference_points 的最后一个维度为 4,则结合宽度和高度信息计算采样位置。
    • 异常处理:否则抛出异常,说明 reference_points 的形状不正确。
12. 调用 Deformable Attention 函数
if torch.cuda.is_available() and value.is_cuda:

    # using fp16 deformable attention is unstable because it performs many sum operations
    if value.dtype == torch.float16:
        MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
    else:
        MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
    output = MultiScaleDeformableAttnFunction.apply(
        value, spatial_shapes, level_start_index, sampling_locations,
        attention_weights, self.im2col_step)
else:

    output = multi_scale_deformable_attn_pytorch(
        value, spatial_shapes, sampling_locations, attention_weights)

  • 选择计算函数
    • 如果在 CUDA 上运行且 value 在 GPU 上,则选择适当的 MultiScaleDeformableAttnFunction 函数(目前固定为 FP32)。
    • 否则,使用 PyTorch 实现的 multi_scale_deformable_attn_pytorch 函数来计算。
  • 执行注意力计算: 调用选定的函数,输入 valuespatial_shapeslevel_start_indexsampling_locationsattention_weights,计算出最终的 output
13. 调整输出的维度
# output shape (bs*num_bev_queue, num_query, embed_dims)
# (bs*num_bev_queue, num_query, embed_dims) -> (num_query, embed_dims, bs*num_bev_queue)
output = output.permute(1, 2, 0)

# fuse history value and current value
# (num_query, embed_dims, bs*num_bev_queue) -> (num_query, embed_dims, bs, num_bev_queue)
output = output.view(num_query, embed_dims, bs, self.num_bev_queue)
output = output.mean(-1)

# (num_query, embed_dims, bs) -> (bs, num_query, embed_dims)
output = output.permute(2, 0, 1)

output = self.output_proj(output)

if not self.batch_first:
    output = output.permute(1, 0, 2)

return self.dropout(output) + identity

  • 调整 output 的维度: 将 output[bs*num_bev_queue, num_query, embed_dims] 转置为 [num_query, embed_dims, bs*num_bev_queue]
  • 融合历史值和当前值
    • output 重塑为 [num_query, embed_dims, bs, num_bev_queue]
    • num_bev_queue 维度进行平均,融合当前帧和历史帧的特征。
  • 再次调整 output 的维度
    • output 转置回 [bs, num_query, embed_dims]
    • 应用 output_proj 线性层进行投影。
    • 如果 batch_first 为 False,再次转置为 [num_query, bs, embed_dims]
  • 应用 Dropout 和残差连接
    • output 应用 Dropout。
    • 返回 outputidentity 的和,即通过残差连接将输入的信息加入输出中。

总结

TemporalSelfAttention 类通过结合多尺度、多头注意力机制和时间序列信息,实现了复杂的特征融合和注意力计算。其主要功能包括:

  1. 处理多帧特征:通过拼接当前帧和历史帧的特征,实现时间序列信息的融合。
  2. 多尺度注意力:支持在多个尺度上进行注意力计算,从而捕捉多尺度的特征信息。
  3. 采样偏移和位置计算:通过计算采样偏移和参考点的位置,灵活地从特征图中提取信息。
  4. 灵活的维度处理:能够处理不同排列方式的输入张量,并在最终返回前调整到期望的维度。
  5. 高效的注意力计算:支持 CUDA 加速的可变形注意力计算函数,提高了计算效率。

通过这些机制,TemporalSelfAttention 模块能够有效地整合多视角和时间序列的信息,在自动驾驶和 3D 感知任务中表现出色。

  • 24
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值