TemporalSelfAttention
类实现了一个用于 BEVFormer 的时间自注意力模块。这种注意力机制基于 Deformable-Detr 的思路,旨在处理多尺度和多视角的特征,同时结合时间序列信息来增强特征表达。在 BEVFormer 中,时间自注意力模块是关键组件,它能够将当前帧和历史帧的特征进行融合和注意力计算,从而提升模型对动态场景的理解能力。
类的定义和初始化
@ATTENTION.register_module()
class TemporalSelfAttention(BaseModule):
"""An attention module used in BEVFormer based on Deformable-Detr.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_heads (int): Parallel attention heads. Default: 64.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_identity`.
Default: 0.1.
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default to True.
norm_cfg (dict): Config dict for normalization layer.
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
num_bev_queue (int): Number of BEV frames in the queue. Default is 2.
"""
def __init__(self,
embed_dims=256,
num_heads=8,
num_levels=4,
num_points=4,
num_bev_queue=2,
im2col_step=64,
dropout=0.1,
batch_first=True,
norm_cfg=None,
init_cfg=None):
参数说明
embed_dims
:嵌入维度,默认为 256。num_heads
:多头注意力机制中的头数量,默认为 8。num_levels
:特征层的数量,默认为 4。num_points
:每个查询在每个头中采样的点的数量,默认为 4。num_bev_queue
:BEV 队列的长度,表示当前帧和历史帧的数量,默认为 2。im2col_step
:im2col
操作的步长,默认为 64。dropout
:用于输入的 dropout 比率,默认为 0.1。batch_first
:表示输入张量的形状是否为(batch, n, embed_dim)
,默认为 True。norm_cfg
:归一化层的配置,默认为 None。init_cfg
:初始化配置,默认为 None。
初始化过程
- 父类初始化:调用
BaseModule
的初始化方法。 - 维度检查:确保
embed_dims
能够被num_heads
整除,否则抛出异常。 - 警告非二次幂的维度:如果每个注意力头的维度不是 2 的幂次,会发出警告,因为 CUDA 实现中处理二次幂的维度更有效。
- 初始化线性层:
sampling_offsets
:用于计算采样偏移的线性层,输入维度为embed_dims * num_bev_queue
,输出维度为num_bev_queue * num_heads * num_levels * num_points * 2
。attention_weights
:用于计算注意力权重的线性层,输入维度为embed_dims * num_bev_queue
,输出维度为num_bev_queue * num_heads * num_levels * num_points
。value_proj
:对值进行投影的线性层。output_proj
:对输出进行投影的线性层。
- 权重初始化:调用
init_weights
方法进行权重初始化。
权重初始化
def init_weights(self):
"""Default initialization for Parameters of Module."""
constant_init(self.sampling_offsets, 0.)
thetas = torch.arange(
self.num_heads,
dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init /
grid_init.abs().max(-1, keepdim=True)[0]).view(
self.num_heads, 1, 1,
2).repeat(1, self.num_levels * self.num_bev_queue, self.num_points, 1)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
self.sampling_offsets.bias.data = grid_init.view(-1)
constant_init(self.attention_weights, val=0., bias=0.)
xavier_init(self.value_proj, distribution='uniform', bias=0.)
xavier_init(self.output_proj, distribution='uniform', bias=0.)
self._is_init = True
-
sampling_offsets
初始化:- 将
sampling_offsets
的权重初始化为常数 0。 - 根据头的数量计算角度数组
thetas
,范围为[0, 2π)
。 - 将角度转换为网格的初始化值
grid_init
,用于设置采样偏移的初始偏移量。 - 对每个采样点的偏移进行标度调整。
- 将
grid_init
赋值给sampling_offsets
的偏置。
- 将
-
attention_weights
初始化:- 将
attention_weights
的权重和偏置初始化为 0。
- 将
-
value_proj
和output_proj
初始化:- 使用 Xavier 均匀分布对
value_proj
和output_proj
的权重进行初始化,偏置初始化为 0。
- 使用 Xavier 均匀分布对
前向传播
def forward(self,
query,
key=None,
value=None,
identity=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
flag='decoder',
**kwargs):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`.
identity (Tensor): The tensor used for addition, with the
same shape as `query`. Default None. If None,
`query` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
参数说明
query
:查询张量,形状为[num_query, batch_size, embed_dims]
。key
:键张量,形状为[num_key, batch_size, embed_dims]
,默认未使用。value
:值张量,形状为[num_key, batch_size, embed_dims]
,默认未使用。identity
:用于加法的张量,形状与query
相同。默认情况下,使用query
。query_pos
:查询的位置编码,默认未提供。reference_points
:参考点,形状为[batch_size, num_query, num_levels, 2]
或[N, Length_{query}, num_levels, 4]
。key_padding_mask
:查询的掩码,形状为[batch_size, num_key]
。spatial_shapes
:特征在不同层级的空间形状,形状为[num_levels, 2]
。level_start_index
:每个层级的起始索引,形状为[num_levels]
。flag
:标志,默认值为decoder
。**kwargs
:其他关键字参数。
前向传播逻辑
1. 值处理
if value is None:
assert self.batch_first
bs, len_bev, c = query.shape
value = torch.stack([query, query], 1).reshape(bs*2, len_bev, c)
- 检查
value
: 如果value
未提供(通常在某些特定情况下),则通过将query
堆叠两次来生成value
。这里假设batch_first
为 True,即批次是第一维度。 - 重塑
value
: 将堆叠后的value
重塑为[bs*2, len_bev, c]
,表示包含两个批次的value
,每个批次都是query
。
2. 初始化 identity
if identity is None:
identity = query
- 默认
identity
: 如果identity
未提供,则默认使用query
作为identity
。identity
用于残差连接中保持输入信息。
3. 加上位置编码
if query_pos is not None:
query = query + query_pos
- 位置编码: 如果提供了
query_pos
,则将其加到query
上,用于在查询中加入位置信息,使得模型能够理解输入特征的位置关系。
4. 调整批次维度
if not self.batch_first:
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
- 维度转换: 如果
batch_first
为 False,则将query
和value
的维度从[num_query, batch_size, embed_dims]
转置为[batch_size, num_query, embed_dims]
,以适应后续处理要求。
5. 检查尺寸和队列
bs, num_query, embed_dims = query.shape
_, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
assert self.num_bev_queue == 2
- 检查尺寸: 获取
query
的批次大小、查询数量和嵌入维度,并确保spatial_shapes
中的所有元素乘积之和等于num_value
。 - 检查 BEV 队列: 确保 BEV 队列的长度为 2,这是实现多帧特征融合的基础。
6. 处理 query
和 value
query = torch.cat([value[:bs], query], -1)
value = self.value_proj(value)
- 拼接
query
和value
: 将value
的前bs
个元素(即当前帧和历史帧)与query
在最后一个维度上拼接,形成一个新的查询向量。 - 投影
value
: 通过value_proj
线性层对value
进行投影,将其映射到适合注意力机制的空间。
7. 掩码处理
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0)
- 应用掩码: 如果提供了
key_padding_mask
,则使用掩码将value
中相应位置的值填充为 0,以忽略这些位置的影响。
8. 重塑 value
value = value.reshape(bs*self.num_bev_queue,
num_value, self.num_heads, -1)
- 重塑
value
: 将value
重塑为[bs*num_bev_queue, num_value, num_heads, dim_per_head]
,将维度分解为多头注意力机制的格式。
9. 计算采样偏移和注意力权重
sampling_offsets = self.sampling_offsets(query)
sampling_offsets = sampling_offsets.view(
bs, num_query, self.num_heads, self.num_bev_queue, self.num_levels, self.num_points, 2)
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_bev_queue, self.num_levels * self.num_points)
attention_weights = attention_weights.softmax(-1)
- 采样偏移: 通过
sampling_offsets
线性层计算每个查询的采样偏移,并将其重塑为[bs, num_query, num_heads, num_bev_queue, num_levels, num_points, 2]
。 - 注意力权重: 通过
attention_weights
线性层计算每个查询在每个采样点的注意力权重,并将其重塑为[bs, num_query, num_heads, num_bev_queue, num_levels * num_points]
。- 应用 Softmax:对最后一个维度应用 softmax,使得权重在
num_levels * num_points
维度上和为 1。
- 应用 Softmax:对最后一个维度应用 softmax,使得权重在
10. 调整采样偏移和注意力权重的维度
attention_weights = attention_weights.view(bs, num_query,
self.num_heads,
self.num_bev_queue,
self.num_levels,
self.num_points)
attention_weights = attention_weights.permute(0, 3, 1, 2, 4, 5)\
.reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points).contiguous()
sampling_offsets = sampling_offsets.permute(0, 3, 1, 2, 4, 5, 6)\
.reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points, 2)
- 调整
attention_weights
维度: 将attention_weights
重塑并调整为[bs*num_bev_queue, num_query, num_heads, num_levels, num_points]
的形状,方便后续操作。 - 调整
sampling_offsets
维度: 将sampling_offsets
重塑并调整为[bs*num_bev_queue, num_query, num_heads, num_levels, num_points, 2]
的形状,方便后续操作。
11. 计算采样位置
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets \
/ offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.num_points \
* reference_points[:, :, None, :, None, 2:] \
* 0.5
else:
raise ValueError(
f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
- 根据参考点计算采样位置:
- 2D 参考点:如果
reference_points
的最后一个维度为 2,则计算偏移并使用spatial_shapes
进行标准化。 - 4D 参考点:如果
reference_points
的最后一个维度为 4,则结合宽度和高度信息计算采样位置。 - 异常处理:否则抛出异常,说明
reference_points
的形状不正确。
- 2D 参考点:如果
12. 调用 Deformable Attention 函数
if torch.cuda.is_available() and value.is_cuda:
# using fp16 deformable attention is unstable because it performs many sum operations
if value.dtype == torch.float16:
MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
else:
MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights)
- 选择计算函数:
- 如果在 CUDA 上运行且
value
在 GPU 上,则选择适当的MultiScaleDeformableAttnFunction
函数(目前固定为 FP32)。 - 否则,使用 PyTorch 实现的
multi_scale_deformable_attn_pytorch
函数来计算。
- 如果在 CUDA 上运行且
- 执行注意力计算: 调用选定的函数,输入
value
、spatial_shapes
、level_start_index
、sampling_locations
和attention_weights
,计算出最终的output
。
13. 调整输出的维度
# output shape (bs*num_bev_queue, num_query, embed_dims)
# (bs*num_bev_queue, num_query, embed_dims) -> (num_query, embed_dims, bs*num_bev_queue)
output = output.permute(1, 2, 0)
# fuse history value and current value
# (num_query, embed_dims, bs*num_bev_queue) -> (num_query, embed_dims, bs, num_bev_queue)
output = output.view(num_query, embed_dims, bs, self.num_bev_queue)
output = output.mean(-1)
# (num_query, embed_dims, bs) -> (bs, num_query, embed_dims)
output = output.permute(2, 0, 1)
output = self.output_proj(output)
if not self.batch_first:
output = output.permute(1, 0, 2)
return self.dropout(output) + identity
- 调整
output
的维度: 将output
从[bs*num_bev_queue, num_query, embed_dims]
转置为[num_query, embed_dims, bs*num_bev_queue]
。 - 融合历史值和当前值:
- 将
output
重塑为[num_query, embed_dims, bs, num_bev_queue]
。 - 对
num_bev_queue
维度进行平均,融合当前帧和历史帧的特征。
- 将
- 再次调整
output
的维度:- 将
output
转置回[bs, num_query, embed_dims]
。 - 应用
output_proj
线性层进行投影。 - 如果
batch_first
为 False,再次转置为[num_query, bs, embed_dims]
。
- 将
- 应用 Dropout 和残差连接:
- 对
output
应用 Dropout。 - 返回
output
和identity
的和,即通过残差连接将输入的信息加入输出中。
- 对
总结
TemporalSelfAttention
类通过结合多尺度、多头注意力机制和时间序列信息,实现了复杂的特征融合和注意力计算。其主要功能包括:
- 处理多帧特征:通过拼接当前帧和历史帧的特征,实现时间序列信息的融合。
- 多尺度注意力:支持在多个尺度上进行注意力计算,从而捕捉多尺度的特征信息。
- 采样偏移和位置计算:通过计算采样偏移和参考点的位置,灵活地从特征图中提取信息。
- 灵活的维度处理:能够处理不同排列方式的输入张量,并在最终返回前调整到期望的维度。
- 高效的注意力计算:支持 CUDA 加速的可变形注意力计算函数,提高了计算效率。
通过这些机制,TemporalSelfAttention
模块能够有效地整合多视角和时间序列的信息,在自动驾驶和 3D 感知任务中表现出色。