import torch
import torch.nn.functional as F

# During temporal self-attention (TSA):
# value: [8 (bs*2), 20000 (num_query), 8 (heads), 32 (dim/head)] -- 8 attention
#   heads with 32 dims each; history BEV features stacked with the current queries
# spatial_shapes: shape (1, 2), stores the (height, width) of the BEV map
# sampling_locations: (bs*self.num_bev_queue, num_query, self.num_heads,
#   self.num_levels(1), self.num_points, 2) -- normalized sampling coordinates
#   for both the history BEV and the current BEV
# attention_weights: weights of the sampling points,
#   (bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points)
def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
                                        sampling_locations, attention_weights):
    """CPU version of multi-scale deformable attention.

    Args:
        value (Tensor): The value has shape
            (bs, num_keys, num_heads, embed_dims//num_heads)
        value_spatial_shapes (Tensor): Spatial shape of
            each feature map, has shape (num_levels, 2),
            the last dimension 2 represents (h, w)
        sampling_locations (Tensor): The location of sampling points,
            has shape
            (bs, num_queries, num_heads, num_levels, num_points, 2),
            the last dimension 2 represents (x, y).
        attention_weights (Tensor): The weight of sampling points used
            when calculating the attention, has shape
            (bs, num_queries, num_heads, num_levels, num_points).

    Returns:
        Tensor: has shape (bs, num_queries, embed_dims)
    """
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ =\
        sampling_locations.shape
    # In TSA, value has only a single level.
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
                             dim=1)
    # Map the coordinates from (0, 1) to (-1, 1), the range grid_sample expects
    # (e.g. 0.0 -> -1.0, 0.5 -> 0.0, 1.0 -> 1.0).
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # [(bs*2(bev_queue)*num_heads), 32 (dim/head), bev_h, bev_w]
        # A BEV grid of num_query (= bev_h * bev_w) cells, each holding a
        # 32-dim (dim/head) feature
        value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(
            bs * num_heads, embed_dims, H_, W_)
        # (bs*self.num_bev_queue*self.num_heads, num_query, self.num_points, 2)
        # flatten(0, 1) merges dims 0 and 1
        # A BEV grid of num_query cells; each cell has self.num_points sampling
        # points, each stored as a 2-D (x, y) coordinate
        sampling_grid_l_ = sampling_grids[:, :, :,
                                          level].transpose(1, 2).flatten(0, 1)
        # grid_sample samples features from value_l_ at the positions stored in
        # sampling_grid_l_, so each sampling point receives a feature vector.
        # value_l_ is a 2-D feature map of shape (bev_h, bev_w); sampling_grid_l_
        # holds 2-D coordinate points.
        # Given an input of shape (N, C, H_in, W_in) and a grid of shape
        # (N, H_out, W_out, 2), the output has shape (N, C, H_out, W_out)
        # (see the shape sketch after this function).
        # Hence sampling_value_l_ has shape
        # (bs*self.num_bev_queue*self.num_heads, 32 (dim/head), num_query, self.num_points)
        # sampling_value_l_ is the sampled feature.
        sampling_value_l_ = F.grid_sample(
            value_l_,
            sampling_grid_l_,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # attention_weights holds the weight of each sampled feature, i.e. the
    # attention scores.
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points)
    # stack adds a level dimension, then flatten(-2) merges it back:
    # (bs*self.num_bev_queue*self.num_heads, 32 (dim/head), num_query, 1*self.num_points)
    # After weights * sampled features:
    # (bs*self.num_bev_queue*self.num_heads, 32 (dim/head), num_query, self.num_points)
    # sum(-1) removes the last dimension:
    # (bs*self.num_bev_queue*self.num_heads, 32 (dim/head), num_query)
    # i.e. a weighted sum over all sampling points, giving the final BEV features
    # (history and current, not yet fused; fusion happens after returning)
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
              attention_weights).sum(-1).view(bs, num_heads * embed_dims,
                                              num_queries)
    # (bs*self.num_bev_queue, num_query, self.num_heads*32 (dim/head))
    return output.transpose(1, 2).contiguous()
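
# A minimal sketch (toy sizes, all assumed for illustration) of the grid_sample
# shape contract relied on above: an input of shape (N, C, H_in, W_in) sampled
# at a grid of shape (N, H_out, W_out, 2) with coordinates in (-1, 1) yields an
# output of shape (N, C, H_out, W_out).
def _grid_sample_shape_demo():
    feat = torch.arange(16.0).reshape(1, 1, 4, 4)  # (N=1, C=1, H_in=4, W_in=4)
    # One output row of 3 sampling points; each grid entry stores (x, y) in (-1, 1).
    grid = torch.tensor([[[[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]]]])  # (1, 1, 3, 2)
    out = F.grid_sample(feat, grid, mode='bilinear',
                        padding_mode='zeros', align_corners=False)
    print(out.shape)  # torch.Size([1, 1, 1, 3]) -> (N, C, H_out, W_out)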
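
# A minimal sketch of calling the function end to end. The sizes below are
# illustrative stand-ins for the BEVFormer TSA case in the comments above:
# the batch dimension already includes num_bev_queue, there is a single level,
# and H * W equals num_query.
def _tsa_toy_call():
    bs_q, num_heads, head_dim = 4, 8, 32      # assumed: bs * num_bev_queue = 4
    H, W = 10, 10                             # toy BEV grid; num_query = H * W
    num_query, num_levels, num_points = H * W, 1, 4

    value = torch.rand(bs_q, num_query, num_heads, head_dim)
    spatial_shapes = torch.tensor([[H, W]])   # (num_levels, 2)
    locs = torch.rand(bs_q, num_query, num_heads, num_levels, num_points, 2)
    weights = torch.rand(bs_q, num_query, num_heads, num_levels, num_points)
    weights = weights / weights.sum(-1, keepdim=True)  # normalize over points

    out = multi_scale_deformable_attn_pytorch(value, spatial_shapes, locs, weights)
    print(out.shape)  # torch.Size([4, 100, 256]) -> (bs*num_bev_queue, num_query, embed_dims)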