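BEVFormer Open-Source Algorithm Explained Line by Line (II): Decoder and Det Parts

Preface:

Plenty of material is available for understanding the overall BEVFormer framework, but detailed walkthroughs of its code are scarce. The goal of this series is therefore to combine implementation details with the tensor-shape transformations so that readers gain a more intuitive understanding of the algorithm.

In this series we analyze the public BEVFormer code (the open-source implementation) line by line, so that readers can understand how BEVFormer works through its code, master the algorithm details, and use the framework to develop their own perception algorithms. At the end of the series, the author will also point out, for Horizon Robotics users, the changes that Horizon's reference algorithm makes on top of the open-source version and the reasoning behind them, as a reference for deployment.

The public code is well organized and builds its models through registries, so the call relationships between modules can be read directly from the config files in configs/bevformer. We use bevformer_tiny.py as the example. The Encoder part has already been published in "BEVFormer Open-Source Algorithm Explained Line by Line (I): Encoder Part"; this article focuses on BEVFormer's Decoder and Det parts.

Most of the analysis is written as comments in the code.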

1 PerceptionTransformer

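What it does:

  • Passes the bev_embed output by the encoder into the decoder;
  • Splits the query_embedding defined in BEVFormer along the channel dimension into query_pos and query, which have the same number of channels, and passes both into the decoder;
  • Feeds query_pos through the linear layer self.reference_points to generate reference_points, which are passed into the decoder. Inside the decoder, CustomMSDeformableAttention uses these reference_points as the base sampling locations for fusing bev_embed, a role similar to that of region proposals in two-stage object detectors;
  • Returns inter_states and inter_references, which cls_branches and reg_branches use to predict object classes and bboxes.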

Walkthrough:

#See "BEVFormer Open-Source Algorithm Explained Line by Line (I): Encoder Part"; this produces bev_embed
#In the decoder, CustomMSDeformableAttention fuses bev_embed with the object queries
bev_embed = self.get_bev_features(
    mlvl_feats,
    bev_queries,
    bev_h,
    bev_w,
    grid_length=grid_length,
    bev_pos=bev_pos,
    prev_bev=prev_bev,
    **kwargs)  # bev_embed shape: bs, bev_h*bev_w, embed_dims

bs = mlvl_feats[0].size(0)
#object_query_embed:torch.Size([900, 512])
#query_pos:torch.Size([900, 256]) 
#query:torch.Size([900, 256])
query_pos, query = torch.split(
    object_query_embed, self.embed_dims, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
query = query.unsqueeze(0).expand(bs, -1, -1)
#reference_points:torch.Size([1, 900, 3])
reference_points = self.reference_points(query_pos)
reference_points = reference_points.sigmoid()
init_reference_out = reference_points

#query:torch.Size([900, 1, 256])
query = query.permute(1, 0, 2)
#query_pos:torch.Size([900, 1, 256])
query_pos = query_pos.permute(1, 0, 2)
#bev_embed:torch.Size([50*50, 1, 256]) 
bev_embed = bev_embed.permute(1, 0, 2)

#Enter the decoder module
inter_states, inter_references = self.decoder(
    query=query,
    key=None,
    value=bev_embed,
    query_pos=query_pos,
    reference_points=reference_points,
    reg_branches=reg_branches,
    cls_branches=cls_branches,
    spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
    level_start_index=torch.tensor([0], device=query.device),
    **kwargs)
#Return inter_states and inter_references,
#which are later fed to cls_branches and reg_branches to predict object classes and bboxes
inter_references_out = inter_references

return bev_embed, inter_states, init_reference_out, inter_references_out
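To make the shape bookkeeping above concrete, here is a minimal NumPy sketch that mirrors the split / expand / linear + sigmoid / permute steps with random stand-in tensors. It is not the BEVFormer code: the arrays are hypothetical placeholders for self.query_embedding.weight and for self.reference_points, which is assumed here to be a Linear(embed_dims, 3) layer (consistent with the [1, 900, 3] shape noted above). NumPy is used only to keep the example dependency-light.

```python
import numpy as np

# Sizes assumed from the bevformer_tiny walkthrough above:
# 900 object queries, embed_dims = 256, batch size 1.
num_query, embed_dims, bs = 900, 256, 1
rng = np.random.default_rng(0)

# Hypothetical stand-in for self.query_embedding.weight: [900, 2 * 256]
object_query_embed = rng.standard_normal((num_query, 2 * embed_dims)).astype(np.float32)

# torch.split(..., embed_dims, dim=1): first half -> query_pos, second half -> query
query_pos = object_query_embed[:, :embed_dims]
query = object_query_embed[:, embed_dims:]

# unsqueeze(0).expand(bs, -1, -1): add a batch axis without copying data
query_pos = np.broadcast_to(query_pos[None], (bs, num_query, embed_dims))
query = np.broadcast_to(query[None], (bs, num_query, embed_dims))

# Hypothetical weights standing in for self.reference_points (assumed Linear(embed_dims, 3))
W = (rng.standard_normal((3, embed_dims)) * 0.01).astype(np.float32)
b = np.zeros(3, dtype=np.float32)
logits = query_pos @ W.T + b                       # [1, 900, 3]
reference_points = 1.0 / (1.0 + np.exp(-logits))   # sigmoid -> normalized (x, y, z) in (0, 1)

# permute(1, 0, 2): the decoder works sequence-first, (num_query, bs, embed_dims)
query_sf = query.transpose(1, 0, 2)
query_pos_sf = query_pos.transpose(1, 0, 2)
```

Note that the sigmoid keeps every reference point inside the normalized BEV range, which is why the decoder later works with inverse_sigmoid when refining them.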

2 DetectionTransformerDecoder

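What it does:

  • Loops through 6 identical DetrTransformerDecoderLayers. Each DetrTransformerDecoderLayer consists of ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'), and each layer produces an output and updated reference_points;
  • After all 6 DetrTransformerDecoderLayers have run, returns the outputs and reference_points of all 6 layers.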
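The operation order above describes a post-norm layer: each attention or FFN sublayer adds its own residual (the mmcv modules return identity + dropout(out)), and a norm follows it. The NumPy sketch below shows only that control flow. The sublayers are hypothetical stand-ins that preserve shapes and have no learned behavior, LayerNorm omits its learnable affine parameters, and the 6 layers share weights purely for brevity.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # LayerNorm over channels; learnable gamma/beta omitted for brevity
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def run_decoder_layer(x, modules, operation_order):
    """Post-norm layer: every non-'norm' op adds its residual, then a 'norm' follows."""
    for op in operation_order:
        if op == 'norm':
            x = layer_norm(x)
        else:
            x = x + modules[op](x)  # residual, like identity + dropout(out) in the mmcv modules
    return x

rng = np.random.default_rng(0)
W_ffn = rng.standard_normal((256, 256)) * 0.02
# Hypothetical stand-ins for the real sublayers (shapes only)
modules = {
    'self_attn': lambda x: 0.1 * x.mean(axis=0, keepdims=True),  # mixes info across the 900 queries
    'cross_attn': lambda x: np.full_like(x, 0.5),                # would sample bev_embed in BEVFormer
    'ffn': lambda x: np.maximum(x @ W_ffn, 0.0),                 # a single Linear + ReLU
}
order = ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')

output = rng.standard_normal((900, 1, 256))  # (num_query, bs, embed_dims)
for _ in range(6):                           # 6 stacked layers
    output = run_decoder_layer(output, modules, order)
```

Because a norm closes every layer, each query vector leaving a layer has roughly zero mean and unit variance across its 256 channels.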

Walkthrough:

#output:torch.Size([900, 1, 256])
output = query
intermediate = []
intermediate_reference_points = []
#Loop through the 6 identical DetrTransformerDecoderLayer modules
for lid, layer in enumerate(self.layers):
    #reference_points_input:torch.Size([1, 900, 1, 2])
    #These reference_points serve as the base sampling locations for fusing bev_embed in CustomMSDeformableAttention
    reference_points_input = reference_points[..., :2].unsqueeze(
        2)  # BS NUM_QUERY NUM_LEVEL 2
    #Run one DetrTransformerDecoderLayer
    output = layer(
        output,
        *args,
        reference_points=reference_points_input,
        key_padding_mask=key_padding_mask,
        **kwargs)
    #output:torch.Size([1, 900, 256])
    output = output.permute(1, 0, 2)

    if reg_branches is not None:
        #tmp:torch.Size([1, 900, 10])
        tmp = reg_branches[lid](output)

        assert reference_points.shape[-1] == 3
        #new_reference_points:torch.Size([1, 900, 3])
        new_reference_points = torch.zeros_like(reference_points)
        new_reference_points[..., :2] = tmp[
            ..., :2] + inverse_sigmoid(reference_points[..., :2])
        new_reference_points[..., 2:3] = tmp[
            ..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])

        new_reference_points = new_reference_points.sigmoid()

        reference_points = new_reference_points.detach()

    output = output.permute(1, 0, 2)
    if self.return_intermediate:
        intermediate.append(output)
        intermediate_reference_points.append(reference_points)

#After all 6 DetrTransformerDecoderLayers have run, return the output and reference_points of every layer.
if self.return_intermediate:
    return torch.stack(intermediate), torch.stack(
        intermediate_reference_points)

return output, reference_points

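The structure of the reference_points generated by the highlighted code is shown in the figure below, where inverse_sigmoid(pt_reference_points) is simply the output of the reference_points linear layer, Linear(query_pos).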
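The refinement in the loop above works in logit space: the previous reference point is mapped back with inverse_sigmoid, the regression output tmp is added as an offset, and sigmoid maps the result back into (0, 1). The NumPy sketch below reproduces just that step. Its inverse_sigmoid is a re-implementation with the same clamp-then-log form as the mmdet helper BEVFormer imports (eps=1e-5 is assumed), and the tmp values are made up so the arithmetic is easy to check.

```python
import numpy as np

def inverse_sigmoid(x, eps=1e-5):
    # log(x / (1 - x)) with clamping to avoid log(0) and division by zero
    x = np.clip(x, 0.0, 1.0)
    return np.log(np.clip(x, eps, None) / np.clip(1.0 - x, eps, None))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def refine_reference_points(reference_points, tmp):
    """reference_points: [bs, Q, 3] in (0, 1); tmp: [bs, Q, 10] output of reg_branches[lid]."""
    new = np.zeros_like(reference_points)
    new[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points[..., :2])    # x, y
    new[..., 2:3] = tmp[..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])  # z is channel 4 of tmp
    # The PyTorch code additionally calls .detach(), so no gradient flows between layers here.
    return sigmoid(new)

reference_points = np.full((1, 900, 3), 0.5)  # inverse_sigmoid(0.5) == 0
tmp = np.zeros((1, 900, 10))
tmp[..., 0] = np.log(3.0)                     # sigmoid(log 3) = 0.75, so x moves from 0.5 to 0.75
refined = refine_reference_points(reference_points, tmp)
```

Adding offsets in logit space rather than directly in (0, 1) means no offset, however large, can push a reference point outside the normalized range.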

2.1 MultiheadAttention

What it does:

  • Multi-head self-attention over the object queries, as shown in the figure below.

Walkthrough:

embed_dim = 256
kdim = embed_dim
vdim = embed_dim
qkv_same_embed_dim = kdim == embed_dim and vdim == embed_dim  # True
num_heads = 8
dropout = 0.1
batch_first = False
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
factory_kwargs = {'device': 'cuda', 'dtype': None}
in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
bias_k = bias_v = None
add_zero_attn = False
out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=True, **factory_kwargs)
attn_mask = attn_mask  # None

if batch_first:
    query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

if not qkv_same_embed_dim:
    # attn_output, attn_output_weights = F.multi_head_attention_forward(
    #     query, key, value, self.embed_dim, self.num_heads,
    #     self.in_proj_weight, self.in_proj_bias,
    #     self.bias_k, self.bias_v, self.add_zero_attn,
    #     self.dropout, self.out_proj.weight, self.out_proj.bias,
    #     training=self.training,
    #     key_padding_mask=key_padding_mask, need_weights=need_weights,
    #     attn_mask=attn_mask, use_separate_proj_weight=True,
    #     q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
    #     v_proj_weight=self.v_proj_weight)
    pass
else:
    # attn_output, attn_output_weights = F.multi_head_attention_forward(
    #      query, key, value, embed_dim, num_heads, in_proj_weight, in_proj_bias,
    #      bias_k, bias_v, add_zero_attn, dropout, out_proj.weight, out_proj.bias,
    #      training=True, key_padding_mask=None, need_weights=True, attn_mask=attn_mask)
    # -------------------------------F.multi_head_attention_forward start (inlined below)----------------------------
    out_proj_weight = out_proj.weight
    out_proj_bias = out_proj.bias
    key = key
    value = value
    embed_dim_to_check = embed_dim
    use_separate_proj_weight = False
    training = True
    key_padding_mask = None
    need_weights = True
    q_proj_weight, k_proj_weight, v_proj_weight = None, None, None
    static_k, static_v = None, None

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape  # torch.Size([900, 1, 256])
    src_len, _, _ = key.shape
    assert embed_dim == embed_dim_to_check, \
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
    #     # embed_dim can be a tensor when JIT tracing
    #     head_dim = embed_dim.div(mhaf_num_heads, rounding_mode='trunc')
        pass
    else:
        head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"

    if not use_separate_proj_weight:
        # q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
        # -----------_in_projection_packed start-----------
        # q, k, v, w, b = query, mhaf_key, mhaf_value, mhaf_in_proj_weight, mhaf_in_proj_bias
        # E = query.size(-1)
        if key is value:
            # if query is mhaf_key:
            #     # self-attention
            #     return linear(query, mhaf_in_proj_weight, mhaf_in_proj_bias).chunk(3, dim=-1)
            # else:
            #     # encoder-decoder attention
            #     w_q, w_kv = mhaf_in_proj_weight.split([E, E * 2])
            #     if mhaf_in_proj_bias is None:
            #         b_q = b_kv = None
            #     else:
            #         b_q, b_kv = mhaf_in_proj_bias.split([E, E * 2])
            #     return (linear(query, w_q, b_q),) + linear(mhaf_key, w_kv, b_kv).chunk(2, dim=-1)
            pass
        else:
            w_q, w_k, w_v = in_proj_weight.chunk(3)
            if in_proj_bias is None:
                # b_q = b_k = b_v = None
                pass
            else:
                b_q, b_k, b_v = in_proj_bias.chunk(3)
            # return linear(query, w_q, b_q), linear(mhaf_key, w_k, b_k), linear(mhaf_value, w_v, b_v)
            # F.linear(x, A, b): return x @ A.T + b
            query, key, value = F.linear(query, w_q, b_q), F.linear(key, w_k, b_k), F.linear(value, w_v, b_v)
            # inputs here: query = query + query_pos, key = query + query_pos, value = query (without query_pos)
    # ------------_in_projection_packed end------------
    # else:
    #     assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
    #     assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
    #     assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
    #     if in_proj_bias is None:
    #         b_q = b_k = b_v = None
    #     else:
    #         b_q, b_k, b_v = in_proj_bias.chunk(3)
    #     q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)

    #
    # reshape q, k, v for multihead attention and make em batch first
    query = query.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)  # [900, 1, 256] -> [900, 8, 32] -> [8, 900, 32]
    if static_k is None:
        key = key.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)  # [900, 8, 32] -> [8, 900, 32]
    # else:
    #     # TODO finish disentangling control flow so we don't do in-projections when statics are passed
    #     assert mhaf_static_k.size(0) == bsz * mhaf_num_heads, \
    #         f"expecting static_k.size(0) of {bsz * mhaf_num_heads}, but got {mhaf_static_k.size(0)}"
    #     assert mhaf_static_k.size(2) == head_dim, \
    #         f"expecting static_k.size(2) of {head_dim}, but got {mhaf_static_k.size(2)}"
    #     mhaf_key = mhaf_static_k
    if static_v is None:
        value = value.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)  # [900, 8, 32] -> [8, 900, 32]
    # else:
    #     # TODO finish disentangling control flow so we don't do in-projections when statics are passed
    #     assert mhaf_static_v.size(0) == bsz * mhaf_num_heads, \
    #         f"expecting static_v.size(0) of {bsz * mhaf_num_heads}, but got {mhaf_static_v.size(0)}"
    #     assert mhaf_static_v.size(2) == head_dim, \
    #         f"expecting static_v.size(2) of {head_dim}, but got {mhaf_static_v.size(2)}"
    #     mhaf_value = mhaf_static_v

    # update source sequence length after adjustments
    src_len = key.size(1)

    # attn_output, attn_output_weights = _scaled_dot_product_attention(query, key, value, attn_mask, dropout)
    # ------------_scaled_dot_product_attention start------------
    # q: Tensor,
    # k: Tensor,
    # v: Tensor,
    # attn_mask: Optional[Tensor] = None,
    # dropout_p: float = 0.0,
    B, Nt, E = query.shape  # torch.Size([8, 900, 32]); key and value have the same shape
    query = query / math.sqrt(E)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    attn = torch.bmm(query, key.transpose(-2, -1))  # [8, 900, 32] @ [8, 32, 900] -> [8, 900, 900]
    # if mhaf_attn_mask is not None:
    #     attn += mhaf_attn_mask
    attn = F.softmax(attn, dim=-1)
    if dropout > 0.0:
        attn = F.dropout(attn, p=dropout)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, value)  # [8, 900, 900] @ [8, 900, 32] -> torch.Size([8, 900, 32])
    # return output, attn
    attn_output, attn_output_weights = output, attn
    # -------------_scaled_dot_product_attention end-------------
    # tgt_len: 900  # [8, 900, 32]->[900, 8, 32]->[900, 1, 256]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)  # torch.Size([900, 1, 256])
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)  # nn.Linear

out = attn_output
# ------------------------------self.attn end------------------------------

# return mha_identity + self.dropout_layer(self.proj_drop(out))
query = identity + dropout_layer(mha_proj_drop(out))
# torch.Size([900, 1, 256]) + # torch.Size([900, 1, 256])
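The inlined trace above can be condensed into one function. This NumPy sketch follows the same path for the case used in the decoder (packed in_proj_weight, no masks; dropout is left out): chunk the projection into W_q, W_k, W_v, reshape [900, 1, 256] into [8, 900, 32], apply scaled dot-product attention, then merge heads and apply out_proj. The weights are random placeholders, and the function name multi_head_attention is illustrative, not a PyTorch API; unlike PyTorch, it returns per-head attention weights rather than head-averaged ones.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(query, key, value, in_proj_weight, in_proj_bias,
                         out_proj_weight, out_proj_bias, num_heads):
    """Sequence-first MHA without masks or dropout. query: [L, bsz, E]; key/value: [S, bsz, E]."""
    tgt_len, bsz, embed_dim = query.shape
    head_dim = embed_dim // num_heads
    # in_proj_weight [3E, E] holds W_q, W_k, W_v stacked, like in_proj_weight.chunk(3)
    w_q, w_k, w_v = np.split(in_proj_weight, 3)
    b_q, b_k, b_v = np.split(in_proj_bias, 3)
    q = query @ w_q.T + b_q  # F.linear(x, W, b) == x @ W.T + b
    k = key @ w_k.T + b_k
    v = value @ w_v.T + b_v
    # [L, bsz, E] -> [L, bsz*heads, head_dim] -> [bsz*heads, L, head_dim]
    q = q.reshape(tgt_len, bsz * num_heads, head_dim).transpose(1, 0, 2)
    k = k.reshape(-1, bsz * num_heads, head_dim).transpose(1, 0, 2)
    v = v.reshape(-1, bsz * num_heads, head_dim).transpose(1, 0, 2)
    attn = softmax((q / np.sqrt(head_dim)) @ k.transpose(0, 2, 1))  # [bsz*heads, L, S]
    out = attn @ v                                                    # [bsz*heads, L, head_dim]
    out = out.transpose(1, 0, 2).reshape(tgt_len, bsz, embed_dim)     # back to [L, bsz, E]
    return out @ out_proj_weight.T + out_proj_bias, attn

rng = np.random.default_rng(0)
E, H, L = 256, 8, 900
query = rng.standard_normal((L, 1, E))
query_pos = rng.standard_normal((L, 1, E))
in_w = rng.standard_normal((3 * E, E)) / np.sqrt(E)
out_w = rng.standard_normal((E, E)) / np.sqrt(E)
# As in the decoder self-attention: query = key = query + query_pos, value = query
attn_output, attn_weights = multi_head_attention(
    query + query_pos, query + query_pos, query,
    in_w, np.zeros(3 * E), out_w, np.zeros(E), num_heads=H)
```

The [8, 900, 900] attention matrix is what lets every object query look at every other query, which is how duplicate predictions can be suppressed before cross-attention.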

2.2 CustomMSDeformableAttention

功能:

  • 利用可变形注意力机制将encoder模块输出的bev_embed融入object_query,如下图所示;
  • 输出该层的output,将其作为下一层DetrTransformerDecoderLayer的输入,并利用该层output生成该层对应的reference_points。

Walkthrough:

#-------------------------CustomMSDeformableAttention init(in part)---------------------------------
sampling_offsets = nn.Linear(ca_embed_dims, ca_num_heads * ca_num_levels * ca_num_points * 2).cuda()
attention_weights = nn.Linear(ca_embed_dims, ca_num_heads * ca_num_levels * ca_num_points).cuda()
value_proj = nn.Linear(ca_embed_dims, ca_embed_dims).cuda()
output_proj = nn.Linear(ca_embed_dims, ca_embed_dims).cuda()
#-------------------------CustomMSDeformableAttention init(in part)---------------------------------
if value is None:
    value = query

if identity is None:
    identity = query
if query_pos is not None:
    query = query + query_pos
if not self.batch_first:
    # change to (bs, num_query ,embed_dims)
    #query: torch.Size([1, 900, 256])
    query = query.permute(1, 0, 2)
    #value (i.e. bev_embed): torch.Size([1, 50*50, 256])
    value = value.permute(1, 0, 2)

bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value

#value (i.e. bev_embed): torch.Size([1, 50*50, 256])
value = self.value_proj(value)
if key_padding_mask is not None:
    value = value.masked_fill(key_padding_mask[..., None], 0.0)
#value: torch.Size([1, 50*50, 8, 32]), split into heads for multi-head attention
value = value.view(bs, num_value, self.num_heads, -1)

sampling_offsets = self.sampling_offsets(query).view(
    bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
#    1,    900,          8,            1,             4,             2
attention_weights = self.attention_weights(query).view(
    bs, num_query, self.num_heads, self.num_levels * self.num_points)
#    1,    900,          8,                      4,             
attention_weights = attention_weights.softmax(-1)

#attention_weights: torch.Size([1, 900, 8, 1, 4])
attention_weights = attention_weights.view(bs, num_query,
                                            self.num_heads,
                                            self.num_levels,
                                            self.num_points)
#reference_points:torch.Size([1, 900, 1, 2])                                            
if reference_points.shape[-1] == 2:
    offset_normalizer = torch.stack(
        [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
    #sampling_locations:torch.Size([1, 900, 8, 1, 4, 2])
    sampling_locations = reference_points[:, :, None, :, None, :] \
        + sampling_offsets \
        / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
    sampling_locations = reference_points[:, :, None, :, None, :2] \
        + sampling_offsets / self.num_points \
        * reference_points[:, :, None, :, None, 2:] \
        * 0.5
else:
    raise ValueError(
        f'Last dim of reference_points must be'
        f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if torch.cuda.is_available() and value.is_cuda:

    # using fp16 deformable attention is unstable because it performs many sum operations
    if value.dtype == torch.float16:
        MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
    else:
        MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
    output = MultiScaleDeformableAttnFunction.apply(
        value, spatial_shapes, level_start_index, sampling_locations,
        attention_weights, self.im2col_step)
else:
    #output: torch.Size([1, 900, 256])
    #Deformable attention: each query gathers useful information from value (bev_embed)
    output = multi_scale_deformable_attn_pytorch(
        value, spatial_shapes, sampling_locations, attention_weights)
        
#output:torch.Size([1, 900, 256])
output = self.output_proj(output)

if not self.batch_first:
    # (num_query, bs ,embed_dims)
    output = output.permute(1, 0, 2)

return self.dropout(output) + identity
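What multi_scale_deformable_attn_pytorch does for the single BEV level used here can be sketched in NumPy with explicit loops: offsets are normalized by (W, H), each sampling location is converted to pixel coordinates with the align_corners=False convention that mmcv's grid_sample call uses, the head's feature map is bilinearly sampled (zero outside the map), and the samples are summed with the softmax-normalized attention weights. This is a readability-oriented sketch assuming batch size 1 and one level; the function names are hypothetical, and the real implementation is vectorized over batch, levels, and heads.

```python
import numpy as np

def bilinear_sample(feat, x, y):
    """Bilinearly sample feat [H, W, C] at pixel coords (x, y); out-of-range neighbors count as 0."""
    H, W, C = feat.shape
    x0, y0 = int(np.floor(x)), int(np.floor(y))
    out = np.zeros(C)
    for yi in (y0, y0 + 1):
        for xi in (x0, x0 + 1):
            w = (1.0 - abs(x - xi)) * (1.0 - abs(y - yi))  # bilinear weight of this neighbor
            if 0 <= xi < W and 0 <= yi < H:
                out += w * feat[yi, xi]
    return out

def deformable_attn_single_level(value, H, W, reference_points, sampling_offsets, attention_weights):
    """Single-level, single-batch deformable attention.

    value:             [H*W, num_heads, head_dim] (already passed through value_proj)
    reference_points:  [num_query, 2], normalized (x, y) in [0, 1]
    sampling_offsets:  [num_query, num_heads, num_points, 2], in BEV-grid cells
    attention_weights: [num_query, num_heads, num_points], softmax-normalized over points
    returns            [num_query, num_heads * head_dim]
    """
    num_query, num_heads, num_points, _ = sampling_offsets.shape
    head_dim = value.shape[-1]
    feat = value.reshape(H, W, num_heads, head_dim)  # flattened index is h * W + w
    # Same role as offset_normalizer: divide (dx, dy) by (W, H) to get normalized offsets
    loc = reference_points[:, None, None, :] + sampling_offsets / np.array([W, H])
    out = np.zeros((num_query, num_heads, head_dim))
    for q in range(num_query):
        for h in range(num_heads):
            for p in range(num_points):
                # normalized -> pixel coordinates (grid_sample with align_corners=False)
                px = loc[q, h, p, 0] * W - 0.5
                py = loc[q, h, p, 1] * H - 0.5
                out[q, h] += attention_weights[q, h, p] * bilinear_sample(feat[:, :, h, :], px, py)
    return out.reshape(num_query, num_heads * head_dim)

rng = np.random.default_rng(0)
H = W = 50
num_heads, head_dim, num_points, num_query = 8, 32, 4, 10
value = np.ones((H * W, num_heads, head_dim))  # constant BEV map for an easy check
reference_points = rng.uniform(0.2, 0.8, (num_query, 2))
sampling_offsets = rng.uniform(-2.0, 2.0, (num_query, num_heads, num_points, 2))
logits = rng.standard_normal((num_query, num_heads, num_points))
attention_weights = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
output = deformable_attn_single_level(value, H, W, reference_points, sampling_offsets, attention_weights)
```

With a constant value map and all samples inside the grid, the output is exactly that constant, because both the bilinear weights and the attention weights sum to 1. Each query touches only num_heads * num_points = 32 locations instead of all 2500 BEV cells.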

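3 cls_branches & reg_branches

What it does:

  • Uses the outputs returned by the transformer, namely bev_embed (from the encoder), inter_states (the outputs of the 6 decoder layers), init_reference_out (the initial reference_points generated from query_pos) and inter_references_out (the reference_points output by the 6 decoder layers), to predict object classes and bboxes;
  • Builds the outs dict containing bev_embed, all_cls_scores and all_bbox_preds. all_cls_scores and all_bbox_preds are used to compute the loss and backpropagate gradients; bev_embed can be used for other tasks such as semantic segmentation in the BEV view.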

Walkthrough:

#For the meaning of the following variables, see "BEVFormer Open-Source Algorithm Explained Line by Line (I): Encoder Part"
bs, num_cam, _, _, _ = mlvl_feats[0].shape
dtype = mlvl_feats[0].dtype
object_query_embeds = self.query_embedding.weight.to(dtype)
bev_queries = self.bev_embedding.weight.to(dtype)
bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
                        device=bev_queries.device).to(dtype)
bev_pos = self.positional_encoding(bev_mask).to(dtype)

if only_bev:  # only use encoder to obtain BEV features, TODO: refine the workaround
    return self.transformer.get_bev_features(
        mlvl_feats,
        bev_queries,
        self.bev_h,
        self.bev_w,
        grid_length=(self.real_h / self.bev_h,
                        self.real_w / self.bev_w),
        bev_pos=bev_pos,
        img_metas=img_metas,
        prev_bev=prev_bev,
    )
else:
    #outputs is the final result of passing object_query_embeds, bev_pos, bev_queries, img_metas and mlvl_feats
    #through the encoder and decoder modules
    #outputs: bev_embed, inter_states, init_reference_out, inter_references_out
    outputs = self.transformer(
        mlvl_feats,
        bev_queries,
        object_query_embeds,
        self.bev_h,
        self.bev_w,
        grid_length=(self.real_h / self.bev_h,
                        self.real_w / self.bev_w),
        bev_pos=bev_pos,
        reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501
        cls_branches=self.cls_branches if self.as_two_stage else None,
        img_metas=img_metas,
        prev_bev=prev_bev
    )

bev_embed, hs, init_reference, inter_references = outputs
hs = hs.permute(0, 2, 1, 3)
outputs_classes = []
outputs_coords = []
for lvl in range(hs.shape[0]):
    if lvl == 0:
        reference = init_reference
    else:
        reference = inter_references[lvl - 1]
    reference = inverse_sigmoid(reference)
    #outputs_class:torch.Size([1, 900, 10])
    outputs_class = self.cls_branches[lvl](hs[lvl])
    #tmp:torch.Size([1, 900, 10])
    tmp = self.reg_branches[lvl](hs[lvl])

    # TODO: check the shape of reference
    assert reference.shape[-1] == 3
    tmp[..., 0:2] += reference[..., 0:2]
    tmp[..., 0:2] = tmp[..., 0:2].sigmoid()
    tmp[..., 4:5] += reference[..., 2:3]
    tmp[..., 4:5] = tmp[..., 4:5].sigmoid()
    #The "* (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]" below
    #restores the bbox center x, y, z coordinates to real-world (metric) scale
    tmp[..., 0:1] = (tmp[..., 0:1] * (self.pc_range[3] -
                        self.pc_range[0]) + self.pc_range[0])
    tmp[..., 1:2] = (tmp[..., 1:2] * (self.pc_range[4] -
                        self.pc_range[1]) + self.pc_range[1])
    tmp[..., 4:5] = (tmp[..., 4:5] * (self.pc_range[5] -
                        self.pc_range[2]) + self.pc_range[2])

    # TODO: check if using sigmoid
    outputs_coord = tmp
    outputs_classes.append(outputs_class)
    outputs_coords.append(outputs_coord)
#outputs_classes:torch.Size([6, 1, 900, 10])
outputs_classes = torch.stack(outputs_classes)
#outputs_coords:torch.Size([6, 1, 900, 10])
outputs_coords = torch.stack(outputs_coords)

outs = {
    'bev_embed': bev_embed,
    'all_cls_scores': outputs_classes,
    'all_bbox_preds': outputs_coords,
    'enc_cls_scores': None,
    'enc_bbox_preds': None,
}

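#Once outs is returned, the loss can be computed against the class labels and bbox labels,
#and gradients are backpropagated to update the model's learnable parameters:
#the linear layers, object_query_embeds, bev_queries, bev_pos, etc.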
return outs

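The structure of tmp[..., 0:2] and tmp[..., 4:5] generated by the highlighted code is shown in the figure below; they are essentially the reference_points generated in DetectionTransformerDecoder.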
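The center decoding in the loop above can be checked with a small NumPy sketch: the reference point for each layer (init_reference for layer 0, inter_references[lvl - 1] otherwise) is moved to logit space, the regressed offsets in channels 0:2 and 4:5 are added, sigmoid maps them back to (0, 1), and pc_range rescales them to metric coordinates. The pc_range value below is an assumption (the common nuScenes-style setting), decode_centers is a hypothetical helper name, and the other bbox channels are left untouched as in the original loop.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def inverse_sigmoid(x, eps=1e-5):
    x = np.clip(x, 0.0, 1.0)
    return np.log(np.clip(x, eps, None) / np.clip(1.0 - x, eps, None))

def decode_centers(tmp, reference, pc_range):
    """tmp: [bs, Q, 10] raw reg_branches output; reference: [bs, Q, 3] in (0, 1)."""
    out = tmp.copy()
    ref = inverse_sigmoid(reference)
    out[..., 0:2] = sigmoid(out[..., 0:2] + ref[..., 0:2])  # normalized x, y
    out[..., 4:5] = sigmoid(out[..., 4:5] + ref[..., 2:3])  # normalized z (channel 4)
    # Map [0, 1] back to metric coordinates inside the point-cloud range
    out[..., 0:1] = out[..., 0:1] * (pc_range[3] - pc_range[0]) + pc_range[0]
    out[..., 1:2] = out[..., 1:2] * (pc_range[4] - pc_range[1]) + pc_range[1]
    out[..., 4:5] = out[..., 4:5] * (pc_range[5] - pc_range[2]) + pc_range[2]
    return out

pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]  # assumed nuScenes-style range
tmp = np.zeros((1, 900, 10))
reference = np.full((1, 900, 3), 0.5)            # the center of the normalized range
boxes = decode_centers(tmp, reference, pc_range)
```

With zero offsets, a reference point at 0.5 decodes to x = y = 0 m and z = -1 m, the midpoint of each pc_range axis.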

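Conclusion:

This completes the line-by-line code walkthrough of BEVFormer's Encoder and Decoder parts. If there is demand, a follow-up article analyzing the loss computation may be published; that part is fairly basic, and interested readers can study it on their own alongside the source code.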
