SpatialCrossAttention——空间注意力

1、SpatialCrossAttention

SpatialCrossAttention 是一个用于 BEVFormer(Bird's Eye View Former)的注意力模块,用于将多个摄像头的数据进行融合。

类定义和初始化
@ATTENTION.register_module()
class SpatialCrossAttention(BaseModule):
    def __init__(self,
                 embed_dims=256,
                 num_cams=6,
                 pc_range=None,
                 dropout=0.1,
                 init_cfg=None,
                 batch_first=False,
                 deformable_attention=dict(
                     type='MSDeformableAttention3D',
                     embed_dims=256,
                     num_levels=4),
                 **kwargs
                 ):
        super(SpatialCrossAttention, self).__init__(init_cfg)
        self.init_cfg = init_cfg
        self.dropout = nn.Dropout(dropout)
        self.pc_range = pc_range
        self.fp16_enabled = False
        self.deformable_attention = build_attention(deformable_attention)
        self.embed_dims = embed_dims
        self.num_cams = num_cams
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.batch_first = batch_first
        self.init_weight()

  • embed_dims:嵌入维度,默认值为 256。
  • num_cams:摄像头数量,默认值为 6。
  • pc_range:点云范围,默认值为 None。
  • dropout:dropout 比例,默认值为 0.1。
  • init_cfg:初始化配置,默认值为 None。
  • batch_first:batch 维度是否在第一位,默认值为 False。
  • deformable_attention:可变形注意力配置,默认值为 MSDeformableAttention3D

初始化过程中,创建了 dropout 层、deformable attention 模块和输出投影层,并调用了 init_weight 方法初始化权重。

初始化权重方法 (init_weight)
    def init_weight(self):
        """Default initialization for Parameters of Module."""
        xavier_init(self.output_proj, distribution='uniform', bias=0.)

使用 Xavier 初始化方法对输出投影层 output_proj 进行初始化。

前向传播方法 (forward)
    @force_fp32(apply_to=('query', 'key', 'value', 'query_pos', 'reference_points_cam'))
    def forward(self,
                query,
                key,
                value,
                residual=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                reference_points_cam=None,
                bev_mask=None,
                level_start_index=None,
                flag='encoder',
                **kwargs):

参数说明
  • query:查询张量,形状为 (num_query, bs, embed_dims)
  • key:键张量,形状为 (num_key, bs, embed_dims)
  • value:值张量,形状为 (num_key, bs, embed_dims)
  • residual:用于残差连接的张量,与 query 形状相同,默认值为 None。
  • query_pos:查询张量的位置信息编码,默认值为 None。
  • key_padding_mask:键填充掩码,默认值为 None。
  • reference_points:归一化的参考点,形状为 (bs, num_query, 4)(N, Length_{query}, num_levels, 4)
  • reference_points_cam:每个摄像头的参考点,形状为 (num_cams, bs, num_query, 4)
  • bev_mask:BEV (Bird's Eye View) 掩码。
  • spatial_shapes:特征在不同层级的空间形状,形状为 (num_levels, 2)
  • level_start_index:每个层级的起始索引,形状为 (num_levels,)
  • flag:标志位,用于区分编码器和解码器,默认值为 encoder
前向传播逻辑
  1. 处理 keyvalue
 
        if key is None:
            key = query
        if value is None:
            value = key

  1. 处理 residualquery_pos
 
        if residual is None:
            inp_residual = query
            slots = torch.zeros_like(query)
        if query_pos is not None:
            query = query + query_pos

  • 如果 residual 为 None,则将 inp_residual 设置为 query,并创建一个与 query 形状相同的零张量 slots
  • 如果提供了 query_pos,则将其加到 query 上。
  1. 重塑 queryreference_points_cam
 
        bs, num_query, _ = query.size()

        D = reference_points_cam.size(3)
        indexes = []
        for i, mask_per_img in enumerate(bev_mask):
            index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1)
            indexes.append(index_query_per_img)
        max_len = max([len(each) for each in indexes])

        queries_rebatch = query.new_zeros(
            [bs, self.num_cams, max_len, self.embed_dims])
        reference_points_rebatch = reference_points_cam.new_zeros(
            [bs, self.num_cams, max_len, D, 2])
        
        for j in range(bs):
            for i, reference_points_per_img in enumerate(reference_points_cam):   
                index_query_per_img = indexes[i]
                queries_rebatch[j, i, :len(index_query_per_img)] = query[j, index_query_per_img]
                reference_points_rebatch[j, i, :len(index_query_per_img)] = reference_points_per_img[j, index_query_per_img]

  • 获取 query 的批次大小和查询数量。
  • 获取参考点的维度 D
  • 根据 BEV 掩码 bev_mask 获取每张图像的查询索引 indexes
  • 找到所有图像中查询索引的最大长度 max_len
  • 初始化 queries_rebatchreference_points_rebatch 张量,以适应多摄像头的数据格式。
  • 重新组织 queryreference_points_cam 以适应新格式。
  1. 多尺度变形注意力
 
        num_cams, l, bs, embed_dims = key.shape

        key = key.permute(2, 0, 1, 3).reshape(
            bs * self.num_cams, l, self.embed_dims)
        value = value.permute(2, 0, 1, 3).reshape(
            bs * self.num_cams, l, self.embed_dims)

        queries = self.deformable_attention(query=queries_rebatch.view(bs*self.num_cams, max_len, self.embed_dims),
                                            key=key,
                                            value=value,
                                            reference_points=reference_points_rebatch.view(
                                                bs*self.num_cams, max_len, D, 2),
                                            spatial_shapes=spatial_shapes,
                                            level_start_index=level_start_index, 
                                            flag=flag, **kwargs)

  • keyvalue 进行维度转换,使其适应多摄像头数据格式。
  • 调用 deformable_attention 进行多尺度变形注意力计算。
  1. 恢复原始形状并进行残差连接
 
        queries = queries.view(bs, self.num_cams, max_len, self.embed_dims)
        queries = queries.sum(1).view(bs, max_len, self.embed_dims)

        for j in range(bs):
            for i, mask_per_img in enumerate(bev_mask):
                index_query_per_img = indexes[i]
                slots[j, index_query_per_img] = queries[j, :len(index_query_per_img)]
                
        slots = self.dropout(slots)
        return inp_residual + self.output_proj(slots)

  • queries 重塑回原始形状。
  • 将所有摄像头的查询结果进行求和。
  • 将结果填充到 slots 张量中。
  • 经过 dropout 层后,进行残差连接并返回最终结果。

这个注意力模块通过跨摄像头的变形注意力机制,有效地融合了多摄像头数据,并提升了 BEVFormer 在处理 3D 数据时的性能。

2、MSDeformableAttention3D

MSDeformableAttention3D 类实现了多尺度变形注意力机制,旨在处理 3D 空间中的特征数据。

类定义和初始化
@ATTENTION.register_module()
class MSDeformableAttention3D(BaseModule):
    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=8,
                 im2col_step=64,
                 dropout=0.1,
                 init_cfg=None,
                 batch_first=False,
                 norm_cfg=None,
                 **kwargs):
        super(MSDeformableAttention3D, self).__init__(init_cfg)
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.num_points = num_points
        self.im2col_step = im2col_step
        self.batch_first = batch_first
        self.norm_cfg = norm_cfg

        self.sampling_offsets = nn.Linear(embed_dims, num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(embed_dims, num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(embed_dims, embed_dims)
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.dropout = nn.Dropout(dropout)
        self.init_weight()
        self.fp16_enabled = False

  • embed_dims:嵌入维度,默认值为 256。
  • num_heads:注意力头的数量,默认值为 8。
  • num_levels:特征层级的数量,默认值为 4。
  • num_points:每个注意力头中每个参考点的采样点数量,默认值为 8。
  • im2col_stepim2col 操作的步长,默认值为 64。
  • dropout:dropout 比例,默认值为 0.1。
  • init_cfg:初始化配置,默认值为 None。
  • batch_first:batch 维度是否在第一位,默认值为 False。
  • norm_cfg:归一化配置,默认值为 None。

初始化过程中,创建了采样偏移量、注意力权重、值投影层和输出投影层,并调用了 init_weight 方法初始化权重。

初始化权重方法 (init_weight)
    def init_weight(self):
        constant_init(self.sampling_offsets, 0.)
        thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.num_heads, 1, 1, 2).repeat(1, self.num_levels, self.num_points, 1)
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_init(self.attention_weights, 0.)
        xavier_init(self.value_proj, distribution='uniform')
        xavier_init(self.output_proj, distribution='uniform')

  • 使用 constant_init 对采样偏移量 sampling_offsets 进行初始化。
  • 使用预定义的角度 thetasgrid_init 来初始化采样偏移量的偏置。
  • 使用 constant_init 对注意力权重 attention_weights 进行初始化。
  • 使用 Xavier 初始化方法对值投影层 value_proj 和输出投影层 output_proj 进行初始化。
前向传播方法 (forward)
    @force_fp32(apply_to=('query', 'key', 'value', 'reference_points'))
    def forward(self,
                query,
                key=None,
                value=None,
                residual=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):

参数说明
  • query:查询张量,形状为 (num_query, bs, embed_dims)
  • key:键张量,默认值为 None。
  • value:值张量,默认值为 None。
  • residual:用于残差连接的张量,与 query 形状相同,默认值为 None。
  • query_pos:查询张量的位置信息编码,默认值为 None。
  • key_padding_mask:键填充掩码,默认值为 None。
  • reference_points:归一化的参考点,形状为 (bs, num_query, num_levels, 2)
  • spatial_shapes:特征在不同层级的空间形状,形状为 (num_levels, 2)
  • level_start_index:每个层级的起始索引,形状为 (num_levels,)
前向传播逻辑
  1. 处理 keyvalue
        if key is None:
            key = query
        if value is None:
            value = key

如果 keyvalue 为 None,则将它们设置为 query。 

2.处理 residualquery_pos

 
        if residual is None:
            inp_residual = query
        if query_pos is not None:
            query = query + query_pos

  • 如果 residual 为 None,则将 inp_residual 设置为 query
  • 如果提供了 query_pos,则将其加到 query 上。

3.重塑 queryreference_points

        bs, num_query, _ = query.size()
        _, num_value, _ = value.size()

        value = self.value_proj(value)
        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)

        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points)
        attention_weights = attention_weights.softmax(-1)
  • 获取 query 的批次大小和查询数量。
  • 获取 value 的数量。
  • value 进行投影。
  • 如果提供了 key_padding_mask,则对 value 进行掩码填充。
  • 计算采样偏移量 sampling_offsets 和注意力权重 attention_weights,并对注意力权重进行 softmax 归一化。

4、多尺度变形注意力计算

        if reference_points.size(-1) == 2:
            offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.size(-1) == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(f'Last dim of reference_points must be'
                             f' 2 or 4, but get {reference_points.size(-1)} instead.')

  • 如果 reference_points 的最后一个维度是 2,则计算采样位置 sampling_locations
  • 如果 reference_points 的最后一个维度是 4,则计算采样位置 sampling_locations
  • 否则抛出错误。

5、调用多尺度变形注意力函数

        if torch.cuda.is_available() and value.is_cuda:
            if not self.training and self.fp16_enabled:
                output = MultiScaleDeformableAttnFunction_fp16.apply(
                    value, spatial_shapes, level_start_index, sampling_locations, attention_weights, self.im2col_step)
            else:
                output = MultiScaleDeformableAttnFunction_fp32.apply(
                    value, spatial_shapes, level_start_index, sampling_locations, attention_weights, self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, level_start_index, sampling_locations, attention_weights, self.im2col_step)
  • 根据是否在 CUDA 上运行以及是否启用了 fp16,调用不同的多尺度变形注意力函数。

6、返回结果

        output = self.output_proj(output)
        return inp_residual + self.dropout(output)

  • 对输出进行投影,并加上残差连接,返回最终结果。

这个 MSDeformableAttention3D 类通过多尺度变形注意力机制,能够高效处理 3D 特征数据,从而提升模型在处理 3D 空间任务时的性能。

  • 5
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值