1、SpatialCrossAttention
SpatialCrossAttention
是一个用于 BEVFormer(Bird's Eye View Former)的注意力模块,用于将多个摄像头的数据进行融合。
类定义和初始化
@ATTENTION.register_module()
class SpatialCrossAttention(BaseModule):
def __init__(self,
embed_dims=256,
num_cams=6,
pc_range=None,
dropout=0.1,
init_cfg=None,
batch_first=False,
deformable_attention=dict(
type='MSDeformableAttention3D',
embed_dims=256,
num_levels=4),
**kwargs
):
super(SpatialCrossAttention, self).__init__(init_cfg)
self.init_cfg = init_cfg
self.dropout = nn.Dropout(dropout)
self.pc_range = pc_range
self.fp16_enabled = False
self.deformable_attention = build_attention(deformable_attention)
self.embed_dims = embed_dims
self.num_cams = num_cams
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.batch_first = batch_first
self.init_weight()
embed_dims
:嵌入维度,默认值为 256。num_cams
:摄像头数量,默认值为 6。pc_range
:点云范围,默认值为 None。dropout
:dropout 比例,默认值为 0.1。init_cfg
:初始化配置,默认值为 None。batch_first
:batch 维度是否在第一位,默认值为 False。deformable_attention
:可变形注意力配置,默认值为MSDeformableAttention3D
。
初始化过程中,创建了 dropout 层、deformable attention 模块和输出投影层,并调用了 init_weight
方法初始化权重。
初始化权重方法 (init_weight
)
def init_weight(self):
"""Default initialization for Parameters of Module."""
xavier_init(self.output_proj, distribution='uniform', bias=0.)
使用 Xavier 初始化方法对输出投影层 output_proj
进行初始化。
前向传播方法 (forward
)
@force_fp32(apply_to=('query', 'key', 'value', 'query_pos', 'reference_points_cam'))
def forward(self,
query,
key,
value,
residual=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
reference_points_cam=None,
bev_mask=None,
level_start_index=None,
flag='encoder',
**kwargs):
参数说明
query
:查询张量,形状为(num_query, bs, embed_dims)
。key
:键张量,形状为(num_key, bs, embed_dims)
。value
:值张量,形状为(num_key, bs, embed_dims)
。residual
:用于残差连接的张量,与query
形状相同,默认值为 None。query_pos
:查询张量的位置信息编码,默认值为 None。key_padding_mask
:键填充掩码,默认值为 None。reference_points
:归一化的参考点,形状为(bs, num_query, 4)
或(N, Length_{query}, num_levels, 4)
。reference_points_cam
:每个摄像头的参考点,形状为(num_cams, bs, num_query, 4)
。bev_mask
:BEV (Bird's Eye View) 掩码。spatial_shapes
:特征在不同层级的空间形状,形状为(num_levels, 2)
。level_start_index
:每个层级的起始索引,形状为(num_levels,)
。flag
:标志位,用于区分编码器和解码器,默认值为encoder
。
前向传播逻辑
- 处理
key
和value
if key is None:
key = query
if value is None:
value = key
- 处理
residual
和query_pos
if residual is None:
inp_residual = query
slots = torch.zeros_like(query)
if query_pos is not None:
query = query + query_pos
- 如果
residual
为 None,则将inp_residual
设置为query
,并创建一个与query
形状相同的零张量slots
。 - 如果提供了
query_pos
,则将其加到query
上。
- 重塑
query
和reference_points_cam
bs, num_query, _ = query.size()
D = reference_points_cam.size(3)
indexes = []
for i, mask_per_img in enumerate(bev_mask):
index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1)
indexes.append(index_query_per_img)
max_len = max([len(each) for each in indexes])
queries_rebatch = query.new_zeros(
[bs, self.num_cams, max_len, self.embed_dims])
reference_points_rebatch = reference_points_cam.new_zeros(
[bs, self.num_cams, max_len, D, 2])
for j in range(bs):
for i, reference_points_per_img in enumerate(reference_points_cam):
index_query_per_img = indexes[i]
queries_rebatch[j, i, :len(index_query_per_img)] = query[j, index_query_per_img]
reference_points_rebatch[j, i, :len(index_query_per_img)] = reference_points_per_img[j, index_query_per_img]
- 获取
query
的批次大小和查询数量。 - 获取参考点的维度
D
。 - 根据 BEV 掩码
bev_mask
获取每张图像的查询索引indexes
。 - 找到所有图像中查询索引的最大长度
max_len
。 - 初始化
queries_rebatch
和reference_points_rebatch
张量,以适应多摄像头的数据格式。 - 重新组织
query
和reference_points_cam
以适应新格式。
- 多尺度变形注意力
num_cams, l, bs, embed_dims = key.shape
key = key.permute(2, 0, 1, 3).reshape(
bs * self.num_cams, l, self.embed_dims)
value = value.permute(2, 0, 1, 3).reshape(
bs * self.num_cams, l, self.embed_dims)
queries = self.deformable_attention(query=queries_rebatch.view(bs*self.num_cams, max_len, self.embed_dims),
key=key,
value=value,
reference_points=reference_points_rebatch.view(
bs*self.num_cams, max_len, D, 2),
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
flag=flag, **kwargs)
- 将
key
和value
进行维度转换,使其适应多摄像头数据格式。 - 调用
deformable_attention
进行多尺度变形注意力计算。
- 恢复原始形状并进行残差连接
queries = queries.view(bs, self.num_cams, max_len, self.embed_dims)
queries = queries.sum(1).view(bs, max_len, self.embed_dims)
for j in range(bs):
for i, mask_per_img in enumerate(bev_mask):
index_query_per_img = indexes[i]
slots[j, index_query_per_img] = queries[j, :len(index_query_per_img)]
slots = self.dropout(slots)
return inp_residual + self.output_proj(slots)
- 将
queries
重塑回原始形状。 - 将所有摄像头的查询结果进行求和。
- 将结果填充到
slots
张量中。 - 经过 dropout 层后,进行残差连接并返回最终结果。
这个注意力模块通过跨摄像头的变形注意力机制,有效地融合了多摄像头数据,并提升了 BEVFormer 在处理 3D 数据时的性能。
2、MSDeformableAttention3D
MSDeformableAttention3D
类实现了多尺度变形注意力机制,旨在处理 3D 空间中的特征数据。
类定义和初始化
@ATTENTION.register_module()
class MSDeformableAttention3D(BaseModule):
def __init__(self,
embed_dims=256,
num_heads=8,
num_levels=4,
num_points=8,
im2col_step=64,
dropout=0.1,
init_cfg=None,
batch_first=False,
norm_cfg=None,
**kwargs):
super(MSDeformableAttention3D, self).__init__(init_cfg)
self.embed_dims = embed_dims
self.num_heads = num_heads
self.num_levels = num_levels
self.num_points = num_points
self.im2col_step = im2col_step
self.batch_first = batch_first
self.norm_cfg = norm_cfg
self.sampling_offsets = nn.Linear(embed_dims, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dims, num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.dropout = nn.Dropout(dropout)
self.init_weight()
self.fp16_enabled = False
embed_dims
:嵌入维度,默认值为 256。num_heads
:注意力头的数量,默认值为 8。num_levels
:特征层级的数量,默认值为 4。num_points
:每个注意力头中每个参考点的采样点数量,默认值为 8。im2col_step
:im2col
操作的步长,默认值为 64。dropout
:dropout 比例,默认值为 0.1。init_cfg
:初始化配置,默认值为 None。batch_first
:batch 维度是否在第一位,默认值为 False。norm_cfg
:归一化配置,默认值为 None。
初始化过程中,创建了采样偏移量、注意力权重、值投影层和输出投影层,并调用了 init_weight
方法初始化权重。
初始化权重方法 (init_weight
)
def init_weight(self):
constant_init(self.sampling_offsets, 0.)
thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.num_heads, 1, 1, 2).repeat(1, self.num_levels, self.num_points, 1)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_init(self.attention_weights, 0.)
xavier_init(self.value_proj, distribution='uniform')
xavier_init(self.output_proj, distribution='uniform')
- 使用
constant_init
对采样偏移量sampling_offsets
进行初始化。 - 使用预定义的角度
thetas
和grid_init
来初始化采样偏移量的偏置。 - 使用
constant_init
对注意力权重attention_weights
进行初始化。 - 使用 Xavier 初始化方法对值投影层
value_proj
和输出投影层output_proj
进行初始化。
前向传播方法 (forward
)
@force_fp32(apply_to=('query', 'key', 'value', 'reference_points'))
def forward(self,
query,
key=None,
value=None,
residual=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
参数说明
query
:查询张量,形状为(num_query, bs, embed_dims)
。key
:键张量,默认值为 None。value
:值张量,默认值为 None。residual
:用于残差连接的张量,与query
形状相同,默认值为 None。query_pos
:查询张量的位置信息编码,默认值为 None。key_padding_mask
:键填充掩码,默认值为 None。reference_points
:归一化的参考点,形状为(bs, num_query, num_levels, 2)
。spatial_shapes
:特征在不同层级的空间形状,形状为(num_levels, 2)
。level_start_index
:每个层级的起始索引,形状为(num_levels,)
。
前向传播逻辑
- 处理
key
和value
if key is None:
key = query
if value is None:
value = key
如果 key
和 value
为 None,则将它们设置为 query
。
2.处理 residual
和 query_pos
if residual is None:
inp_residual = query
if query_pos is not None:
query = query + query_pos
- 如果
residual
为 None,则将inp_residual
设置为query
。 - 如果提供了
query_pos
,则将其加到query
上。
3.重塑 query
和 reference_points
bs, num_query, _ = query.size()
_, num_value, _ = value.size()
value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0)
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points)
attention_weights = attention_weights.softmax(-1)
- 获取
query
的批次大小和查询数量。 - 获取
value
的数量。 - 对
value
进行投影。 - 如果提供了
key_padding_mask
,则对value
进行掩码填充。 - 计算采样偏移量
sampling_offsets
和注意力权重attention_weights
,并对注意力权重进行 softmax 归一化。
4、多尺度变形注意力计算
if reference_points.size(-1) == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
elif reference_points.size(-1) == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.num_points * reference_points[:, :, None, :, None, 2:] \
* 0.5
else:
raise ValueError(f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.size(-1)} instead.')
- 如果
reference_points
的最后一个维度是 2,则计算采样位置sampling_locations
。 - 如果
reference_points
的最后一个维度是 4,则计算采样位置sampling_locations
。 - 否则抛出错误。
5、调用多尺度变形注意力函数
if torch.cuda.is_available() and value.is_cuda:
if not self.training and self.fp16_enabled:
output = MultiScaleDeformableAttnFunction_fp16.apply(
value, spatial_shapes, level_start_index, sampling_locations, attention_weights, self.im2col_step)
else:
output = MultiScaleDeformableAttnFunction_fp32.apply(
value, spatial_shapes, level_start_index, sampling_locations, attention_weights, self.im2col_step)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, level_start_index, sampling_locations, attention_weights, self.im2col_step)
- 根据是否在 CUDA 上运行以及是否启用了 fp16,调用不同的多尺度变形注意力函数。
6、返回结果
output = self.output_proj(output)
return inp_residual + self.dropout(output)
- 对输出进行投影,并加上残差连接,返回最终结果。
这个 MSDeformableAttention3D
类通过多尺度变形注意力机制,能够高效处理 3D 特征数据,从而提升模型在处理 3D 空间任务时的性能。