Deformable DETR source code: a detailed, illustrated walkthrough

Paper: https://arxiv.org/pdf/2010.04159.pdf

Model code: https://github.com/open-mmlab/mmdetection/blob/main/mmdet/models/detectors/deformable_detr.py

Preface: plenty of people have already explained the principles and source code of Deformable DETR, but what you read always feels a bit shallow; you really have to debug the code line by line to understand it thoroughly. Hence this post, written mainly so I can review it later.

1. Data preprocessing

mmdet/models/data_preprocessors/data_preprocessor.py

def forward(self, data: dict, training: bool = False) -> dict:
    """Perform normalization,padding and bgr2rgb conversion based on
    ``BaseDataPreprocessor``.

    Args:
        data (dict): Data sampled from dataloader.
        training (bool): Whether to enable training time augmentation.

    Returns:
        dict: Data in the same format as the model input (the images in a batch are unified to the same size).
    """
    batch_pad_shape = self._get_pad_shape(data)	# [(672, 807), (800, 800)]
    data = super().forward(data=data, training=training)
    inputs, data_samples = data['inputs'], data['data_samples']	# inputs(2, 3, 800, 807)

    if data_samples is not None:
        # NOTE the batched image size information may be useful, e.g.
        # in DETR, this is needed for the construction of masks, which is
        # then used for the transformer_head.
        batch_input_shape = tuple(inputs[0].size()[-2:])	# (596, 652)
        for data_sample, pad_shape in zip(data_samples, batch_pad_shape):
            data_sample.set_metainfo({
                'batch_input_shape': batch_input_shape,
                'pad_shape': pad_shape
            })

        if self.boxtype2tensor:
            samplelist_boxtype2tensor(data_samples)

        if self.pad_mask and training:
            self.pad_gt_masks(data_samples)

        if self.pad_seg and training:
            self.pad_gt_sem_seg(data_samples)

    if training and self.batch_augments is not None:
        for batch_aug in self.batch_augments:
            inputs, data_samples = batch_aug(inputs, data_samples)

    return {'inputs': inputs, 'data_samples': data_samples}

Input: the data of one batch (a dict).

The structure of data is as follows: inputs is a list holding the image tensors of this batch; data_samples is also a list (it does not need attention here).

Processing: F.pad

Output: the processed data (still a dict).

The images in a batch are padded to a common size. For example, two images of size (3, 576, 768) and (3, 544, 778) are unified to (3, 576, 778), as shown in the figure below:
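
A minimal sketch of this padding step (the real implementation lives in the data preprocessor's forward shown above and additionally performs normalization and bgr2rgb conversion; the shapes are the ones from the example):

import torch
import torch.nn.functional as F

# two images of different sizes, as delivered by the dataloader
imgs = [torch.rand(3, 576, 768), torch.rand(3, 544, 778)]

# the batch canvas is the per-dimension maximum: (576, 778)
max_h = max(img.shape[-2] for img in imgs)
max_w = max(img.shape[-1] for img in imgs)

# pad each image at the bottom/right with zeros, then stack into one batch tensor
batch = torch.stack([
    F.pad(img, (0, max_w - img.shape[-1], 0, max_h - img.shape[-2]))
    for img in imgs
])
print(batch.shape)  # torch.Size([2, 3, 576, 778])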

2. Backbone

The backbone used here is ResNet-50. Its detailed network structure is not covered here, since it is not the focus of this article.

mmdet/models/detectors/base_detr.py

def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
    """Extract features.

    Args:
        batch_inputs (Tensor): Image tensor, has shape (bs, dim, H, W).

    Returns:
        tuple[Tensor]: Tuple of feature maps from neck. Each feature map
        has shape (bs, dim, H, W).
    """
    x = self.backbone(batch_inputs)
    if self.with_neck:
        x = self.neck(x)
    return x

Input: (2, 3, 576, 778)

Processing: ResNet-50

Output: [(2, 512, 72, 98), (2, 1024, 36, 49), (2, 2048, 18, 25)]

3. Neck

Input: [(2, 512, 72, 98), (2, 1024, 36, 49), (2, 2048, 18, 25)]

Processing: convolution, normalization and activation (ChannelMapper)

Output: ((2, 256, 72, 98), (2, 256, 36, 49), (2, 256, 18, 25), (2, 256, 9, 13))

mmdet/models/necks/channel_mapper.py

def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
    """Forward function."""
    assert len(inputs) == len(self.convs)
    outs = [self.convs[i](inputs[i]) for i in range(len(inputs))]
    if self.extra_convs:
        for i in range(len(self.extra_convs)):
            if i == 0:
                outs.append(self.extra_convs[0](inputs[-1]))
            else:
                outs.append(self.extra_convs[i](outs[-1]))
    return tuple(outs)

outs = [self.convs[i](inputs[i]) for i in range(len(inputs))] keeps the spatial size of each feature map unchanged and maps the channel dimension to 256; the result is as follows:

After self.extra_convs, an additional feature map of size (9, 13) is appended; the result is as follows:
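
The extra level comes from a stride-2 convolution applied to the last backbone feature map. A small shape check (the kernel/stride/padding here are assumed, matching the usual ChannelMapper configuration of a 3x3 conv with stride 2 and padding 1):

import torch
from torch import nn

# assumed extra-conv configuration: 3x3 conv, stride 2, padding 1
extra_conv = nn.Conv2d(2048, 256, kernel_size=3, stride=2, padding=1)

x = torch.rand(2, 2048, 18, 25)   # last backbone level
print(extra_conv(x).shape)        # torch.Size([2, 256, 9, 13])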

4. Head (the core part)

def forward_transformer(self,
                        img_feats: Tuple[Tensor],
                        batch_data_samples: OptSampleList = None) -> Dict:
    encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
        img_feats, batch_data_samples)

    encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict)

    tmp_dec_in, head_inputs_dict = self.pre_decoder(**encoder_outputs_dict)
    decoder_inputs_dict.update(tmp_dec_in)

    decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)
    head_inputs_dict.update(decoder_outputs_dict)
    return head_inputs_dict

def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
    """Extract features.

    Args:
        batch_inputs (Tensor): Image tensor, has shape (bs, dim, H, W).

    Returns:
        tuple[Tensor]: Tuple of feature maps from neck. Each feature map
        has shape (bs, dim, H, W).
    """
    x = self.backbone(batch_inputs)
    if self.with_neck:
        x = self.neck(x)
    return x

4.1. pre_transformer

This performs some preparation before the transformer, for example:

  • reshape img_feats ([(batch_size, emb_size, h1, w1), (batch_size, emb_size, h2, w2), (batch_size, emb_size, h3, w3), (batch_size, emb_size, h4, w4)]) into (batch_size, feature_nums, emb_size);
  • record which positions of each feature map must be masked, i.e. excluded from the computation;
  • compute the positional encoding of every position on each feature map;
  • record the size of each feature map;
  • record the start index of each feature map in the flattened sequence;
  • record the ratio of valid feature points on each feature map.
encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(img_feats, batch_data_samples)

Input: img_feats contains the four levels of feature maps; batch_data_samples contains the meta information of every image in the current batch.

Output

  • encoder_inputs_dict
    • feat (batch_size, feature_nums, emb_size): the feature values
    • feat_mask (batch_size, feature_nums): the feature mask, i.e. which positions must be ignored because of padding
    • feat_pos (batch_size, feature_nums, emb_size): the positional encoding of every feature point
    • spatial_shapes (feature_levels, 2): the size of each feature level
    • level_start_index (feature_levels,): the start index of each feature level
    • valid_ratios (batch_size, feature_levels, 2): the ratio of valid features on each feature level
  • decoder_inputs_dict
    • memory_mask (batch_size, feature_nums): the feature mask
    • spatial_shapes (feature_levels, 2): the size of each feature level
    • level_start_index (feature_levels,): the start index of each feature level
    • valid_ratios (batch_size, feature_levels, 2): the ratio of valid features on each feature level

Summary: pre_transformer does the preparation work needed before the transformer.

masks

masks = mlvl_feats[0].new_ones((batch_size, input_img_h, input_img_w))
for img_id in range(batch_size):
    img_h, img_w = img_shape_list[img_id]
    masks[img_id, :img_h, :img_w] = 0

The masks tensor has the unified batch size, here (576, 778); the region covered by each image's actual size is then filled with 0. The actual sizes here are (576, 768) and (544, 778).
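
A small worked example for the first image (sizes taken from the example above): only the 10 padded columns on the right remain 1.

import torch

batch_input_h, batch_input_w = 576, 778   # unified batch size
img_h, img_w = 576, 768                   # actual size of the first image

mask = torch.ones(batch_input_h, batch_input_w)
mask[:img_h, :img_w] = 0                  # 0 = real pixels, 1 = padding
print(mask.sum().item())                  # 5760.0, i.e. 576 rows * 10 padded columns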

mlvl_masks, mlvl_pos_embeds

mlvl_masks = []
mlvl_pos_embeds = []
for feat in mlvl_feats:
    mlvl_masks.append(
        F.interpolate(masks[None], size=feat.shape[-2:]).to(
            torch.bool).squeeze(0))

mlvl_masks: F.interpolate (nearest-neighbor by default) resizes the (576, 778) mask to the size of each feature map: (72, 98), (36, 49), (18, 25), (9, 13).

mlvl_pos_embeds: holds the positional encodings; the encoding of each position is 256-dimensional.

# flatten the spatial dimensions of the tensors
feat_flatten = []
lvl_pos_embed_flatten = []
mask_flatten = []
spatial_shapes = []
for lvl, (feat, mask, pos_embed) in enumerate(
        zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
    batch_size, c, h, w = feat.shape
    spatial_shape = torch._shape_as_tensor(feat)[2:].to(feat.device)
    # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c]
    feat = feat.view(batch_size, c, -1).permute(0, 2, 1)			# (2, h*w, 256)
    pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1)	# (2, h*w, 256)
    # self.level_embed: (4, 256). Each feature level also has its own learnable embedding,
    # used to tell the levels apart: per-position positional encoding + per-level embedding
    lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)	# (2, h*w, 256)
    # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl]
    if mask is not None:
        mask = mask.flatten(1)	# (2, h*w)

    feat_flatten.append(feat)
    lvl_pos_embed_flatten.append(lvl_pos_embed)
    mask_flatten.append(mask)
    spatial_shapes.append(spatial_shape)

# (bs, num_feat_points, dim)
feat_flatten = torch.cat(feat_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
# (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl)
if mask_flatten[0] is not None:
    mask_flatten = torch.cat(mask_flatten, 1)
else:
    mask_flatten = None

# (num_level, 2)
spatial_shapes = torch.cat(spatial_shapes).view(-1, 2)
level_start_index = torch.cat((
    spatial_shapes.new_zeros((1, )),  # (num_level)
    spatial_shapes.prod(1).cumsum(0)[:-1]))
if mlvl_masks[0] is not None:
    valid_ratios = torch.stack(  # (bs, num_level, 2)
        [self.get_valid_ratio(m) for m in mlvl_masks], 1)
else:
    valid_ratios = mlvl_feats[0].new_ones(batch_size, len(mlvl_feats),
                                          2)

encoder_inputs_dict = dict(
    feat=feat_flatten,
    feat_mask=mask_flatten,
    feat_pos=lvl_pos_embed_flatten,
    spatial_shapes=spatial_shapes,
    level_start_index=level_start_index,
    valid_ratios=valid_ratios)
decoder_inputs_dict = dict(
    memory_mask=mask_flatten,
    spatial_shapes=spatial_shapes,
    level_start_index=level_start_index,
    valid_ratios=valid_ratios)
return encoder_inputs_dict, decoder_inputs_dict

In the code above,
self.level_embed has shape (4, 256): every feature level has a level embedding (256-dimensional) that indicates which level a feature belongs to.

feat_flatten = torch.cat(feat_flatten, 1) has shape (2, 9387, 256), where 9387 = 72*98 + 36*49 + 18*25 + 9*13, i.e. the four feature levels concatenated into one sequence, as shown in the figure below:

lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) has shape (2, 9387, 256), where 9387 = 72*98 + 36*49 + 18*25 + 9*13. It stores the positional encodings plus the level embeddings, as shown in the figure below:

mask_flatten = torch.cat(mask_flatten, 1) has shape (2, 9387) and stores the mask of every feature map; 2 is the batch_size. In the figure, the colored positions are the masked-out ones.

spatial_shapes = torch.cat(spatial_shapes).view(-1, 2) has shape (4, 2) and stores the size of every feature level: [72, 98], [36, 49], [18, 25], [9, 13].

level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) stores the start index of each feature level: [0, 7056, 8820, 9270].

valid_ratios = torch.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1) has shape (2, 4, 2) and gives the ratio of valid positions on each feature map, as shown in the figure below:

Take (0.98, 1) as an example: on the (72, 98) feature map of the first image in the batch, the valid ratio along w is 0.98 and along h is 1 (the stored order is (ratio_w, ratio_h)). What is it used for? This is explained in detail later.
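
For reference, a sketch of how such a ratio is computed, assuming get_valid_ratio behaves like the mmdet helper, which counts the unmasked rows along the first column and the unmasked columns along the first row:

import torch

def get_valid_ratio(mask):
    # mask: (bs, H, W); True marks padded positions
    _, H, W = mask.shape
    valid_H = torch.sum(~mask[:, :, 0], 1)   # number of valid rows (checked on the first column)
    valid_W = torch.sum(~mask[:, 0, :], 1)   # number of valid columns (checked on the first row)
    return torch.stack([valid_W.float() / W, valid_H.float() / H], -1)   # (bs, 2) as (ratio_w, ratio_h)

# toy level mask of size 72x98 whose rightmost column is padding
mask = torch.zeros(1, 72, 98, dtype=torch.bool)
mask[:, :, 97:] = True
print(get_valid_ratio(mask))   # tensor([[0.9898, 1.0000]]), roughly the (0.98, 1) seen above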

4.2. forward_encoder

What does the encoder actually do? After stepping through the source several times, the picture is roughly this: it reconstructs the input feature maps by mixing information across the different levels. Concretely, every feature point aggregates information from the corresponding location on all 4 feature levels, and the information at each level is in turn obtained by sampling 4 points around that location, each sampling point contributing with its own weight. Repeating this process over several layers makes up the encoder.

Input

encoder_inputs_dict

  • feat (batch_size, feature_nums, emb_size): the feature values
  • feat_mask (batch_size, feature_nums): the feature mask, i.e. which positions must be ignored because of padding
  • feat_pos (batch_size, feature_nums, emb_size): the positional encoding of every feature point
  • spatial_shapes (feature_levels, 2): the size of each feature level
  • level_start_index (feature_levels,): the start index of each feature level
  • valid_ratios (batch_size, feature_levels, 2): the ratio of valid features on each feature level

Output

encoder_outputs_dict

  • memory: the result of reconstructing the input features; the shape does not change (batch_size, feature_nums, emb_size)
  • memory_mask (batch_size, feature_nums): the feat_mask, returned unchanged
  • spatial_shapes (feature_levels, 2): returned unchanged

mmdet/models/detectors/base_detr.py

encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict)

mmdet/models/detectors/deformable_detr.py

def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
                    feat_pos: Tensor, spatial_shapes: Tensor,
                    level_start_index: Tensor,
                    valid_ratios: Tensor) -> Dict:
    # memory: (b, num_value, num_heads*emb), e.g. (2, 120, 256)
    memory = self.encoder(
        query=feat,
        query_pos=feat_pos,
        key_padding_mask=feat_mask,  # for self_attn
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index,
        valid_ratios=valid_ratios)
    encoder_outputs_dict = dict(
        memory=memory,
        memory_mask=feat_mask,
        spatial_shapes=spatial_shapes)
    return encoder_outputs_dict

mmdet/models/layers/transformer/deformable_detr_layers.py

def forward(self, query: Tensor, query_pos: Tensor,
            key_padding_mask: Tensor, spatial_shapes: Tensor,
            level_start_index: Tensor, valid_ratios: Tensor,
            **kwargs) -> Tensor:
    reference_points = self.get_encoder_reference_points(
        spatial_shapes, valid_ratios, device=query.device)
    for layer in self.layers:
        query = layer(
            query=query,
            query_pos=query_pos,
            key_padding_mask=key_padding_mask,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reference_points=reference_points,
            **kwargs)
    return query

4.2.1. reference_points

Input: spatial_shapes, valid_ratios

Output: reference_points (batch_size, num_features, 4, 2)

reference_points: every point on the feature maps gets a normalized position relative to every feature level.

The following piece of code illustrates what reference_points represents:

# assume batch_size = 1
reference_points = torch.tensor(
                   [[[0.1, 0.1], [0.3, 0.1], [0.5, 0.1], [0.7, 0.1], [0.9, 0.1],
                     [0.1, 0.3], [0.3, 0.3], [0.5, 0.3], [0.7, 0.3], [0.9, 0.3],
                     [0.1, 0.5], [0.3, 0.5], [0.5, 0.5], [0.7, 0.5], [0.9, 0.5],
                     [0.1, 0.7], [0.3, 0.7], [0.5, 0.7], [0.7, 0.7], [0.9, 0.7],
                     [0.1, 0.9], [0.3, 0.9], [0.5, 0.9], [0.7, 0.9], [0.9, 0.9],

                     [0.12, 0.12], [0.37, 0.12], [0.62, 0.12], [0.87, 0.12],
                     [0.12, 0.37], [0.37, 0.37], [0.62, 0.37], [0.87, 0.37],
                     [0.12, 0.62], [0.37, 0.62], [0.62, 0.62], [0.87, 0.62],
                     [0.12, 0.87], [0.37, 0.87], [0.62, 0.87], [0.87, 0.87],

                     [0.16, 0.16], [0.50, 0.16], [0.83, 0.16],
                     [0.16, 0.50], [0.50, 0.50], [0.83, 0.50],
                     [0.16, 0.83], [0.50, 0.83], [0.83, 0.83],

                     [0.25, 0.25], [0.75, 0.25],
                     [0.25, 0.75], [0.75, 0.75]]])
# with only one image in the batch, the valid_ratio of h and w on every level is 1
valid_ratios = torch.tensor([[[1, 1], [1, 1], [1, 1], [1, 1]]])
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
print(reference_points)

Taking the point (0.1, 0.1) as an example, the positions it maps to on each feature level are shown in the figure below:

mmdet/models/layers/transformer/deformable_detr_layers.py

def get_encoder_reference_points(
        spatial_shapes: Tensor, valid_ratios: Tensor,
        device: Union[torch.device, str]) -> Tensor:
    reference_points_list = []
    # spatial_shapes:(72, 98), (36, 49), (18, 25), (9, 13)
    for lvl, (H, W) in enumerate(spatial_shapes):
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(
                0.5, H - 0.5, H, dtype=torch.float32, device=device),
            torch.linspace(
                0.5, W - 0.5, W, dtype=torch.float32, device=device))
        ref_y = ref_y.reshape(-1)[None] / (
            valid_ratios[:, None, lvl, 1] * H)
        ref_x = ref_x.reshape(-1)[None] / (
            valid_ratios[:, None, lvl, 0] * W)
        ref = torch.stack((ref_x, ref_y), -1)
        reference_points_list.append(ref)
    reference_points = torch.cat(reference_points_list, 1)
    # [bs, sum(hw), num_level, 2]
    reference_points = reference_points[:, :, None] * valid_ratios[:, None]
    return reference_points

In the code above,

ref_y, ref_x = torch.meshgrid(
    torch.linspace(
        0.5, H - 0.5, H, dtype=torch.float32, device=device),
    torch.linspace(
        0.5, W - 0.5, W, dtype=torch.float32, device=device))

The ref_y, ref_x produced for the four spatial_shapes are as follows:

Normalize the x and y coordinates:

ref_y = ref_y.reshape(-1)[None] / (
    valid_ratios[:, None, lvl, 1] * H)
ref_x = ref_x.reshape(-1)[None] / (
    valid_ratios[:, None, lvl, 0] * W)

Stack the x and y coordinates together:

ref = torch.stack((ref_x, ref_y), -1)

Map every point in reference_points onto every feature level:

reference_points = torch.cat(reference_points_list, 1)
# [bs, sum(hw), num_level, 2]
reference_points = reference_points[:, :, None] * valid_ratios[:, None]

Below, two feature levels of size (4, 5) and (2, 3) are used to illustrate how reference_points is constructed; a runnable toy version follows.
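
A runnable toy version of the same construction (batch size 1, two levels of size (4, 5) and (2, 3), all valid ratios set to 1):

import torch

spatial_shapes = torch.tensor([[4, 5], [2, 3]])
valid_ratios = torch.ones(1, 2, 2)   # (bs, num_levels, 2) as (ratio_w, ratio_h)

reference_points_list = []
for lvl, (H, W) in enumerate(spatial_shapes):
    H, W = int(H), int(W)
    ref_y, ref_x = torch.meshgrid(
        torch.linspace(0.5, H - 0.5, H),
        torch.linspace(0.5, W - 0.5, W))
    ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H)
    ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W)
    reference_points_list.append(torch.stack((ref_x, ref_y), -1))

reference_points = torch.cat(reference_points_list, 1)                 # (1, 4*5 + 2*3, 2)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
print(reference_points.shape)   # torch.Size([1, 26, 2, 2])
print(reference_points[0, 0])   # first point of level 0 mapped onto both levels:
                                # tensor([[0.1000, 0.1250], [0.1000, 0.1250]])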

4.2.2. DeformableDetrTransformerEncoderLayer

The layer contains the following modules (only the self-attention part is listed here):

(self_attn): MultiScaleDeformableAttention(
  (dropout): Dropout(p=0.1, inplace=False)
  (sampling_offsets): Linear(in_features=256, out_features=256, bias=True)
  (attention_weights): Linear(in_features=256, out_features=128, bias=True)
  (value_proj): Linear(in_features=256, out_features=256, bias=True)
  (output_proj): Linear(in_features=256, out_features=256, bias=True)
)

4.2.2.1. MultiScaleDeformableAttention (self_attn)

Input

  • query=query, # the query matrix (num_features_total, emb_size)
  • query_pos=query_pos, # the positional matrix (num_features_total, emb_size)
  • key_padding_mask=key_padding_mask, # the mask matrix (2, num_features_total), i.e. which points do not take part in the attention
  • spatial_shapes=spatial_shapes, # the spatial shapes at the different scales
  • level_start_index=level_start_index, # the start index of each level
  • valid_ratios=valid_ratios, # the valid-feature ratio of each level
  • reference_points=reference_points, # the reference-point coordinates of every feature point

Output

query: the query after fusing the multi-level features; the feature dimension stays unchanged.

query = self.self_attn(
    query=query,
    key=query,
    value=query,
    query_pos=query_pos,
    key_pos=query_pos,
    key_padding_mask=key_padding_mask,
    **kwargs)

def forward(self,
            query: torch.Tensor,									# query
            key: Optional[torch.Tensor] = None,						# query
            value: Optional[torch.Tensor] = None,					# query
            identity: Optional[torch.Tensor] = None,
            query_pos: Optional[torch.Tensor] = None,
            key_padding_mask: Optional[torch.Tensor] = None,
            reference_points: Optional[torch.Tensor] = None,
            spatial_shapes: Optional[torch.Tensor] = None,
            level_start_index: Optional[torch.Tensor] = None,
            **kwargs) -> torch.Tensor:
    """Forward Function of MultiScaleDeformAttention.

    Args:
        query (torch.Tensor): Query of Transformer with shape
            (num_query, bs, embed_dims).
        key (torch.Tensor): The key tensor with shape
            `(num_key, bs, embed_dims)`.
        value (torch.Tensor): The value tensor with shape
            `(num_key, bs, embed_dims)`.
        identity (torch.Tensor): The tensor used for addition, with the
            same shape as `query`. Default None. If None,
            `query` will be used.
        query_pos (torch.Tensor): The positional encoding for `query`.
            Default: None.
        key_padding_mask (torch.Tensor): ByteTensor for `query`, with
            shape [bs, num_key].
        reference_points (torch.Tensor):  The normalized reference
            points with shape (bs, num_query, num_levels, 2),
            all elements is range in [0, 1], top-left (0,0),
            bottom-right (1, 1), including padding area.
            or (N, Length_{query}, num_levels, 4), add
            additional two dimensions is (w, h) to
            form reference boxes.
        spatial_shapes (torch.Tensor): Spatial shape of features in
            different levels. With shape (num_levels, 2),
            last dimension represents (h, w).
        level_start_index (torch.Tensor): The start index of each level.
            A tensor has shape ``(num_levels, )`` and can be represented
            as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].

    Returns:
        torch.Tensor: forwarded results with shape
        [num_query, bs, embed_dims].
    """

    if value is None:
        value = query

    if identity is None:
        identity = query
    if query_pos is not None:
        query = query + query_pos
    if not self.batch_first:
        # change to (bs, num_query ,embed_dims)
        query = query.permute(1, 0, 2)
        value = value.permute(1, 0, 2)

    bs, num_query, _ = query.shape
    bs, num_value, _ = value.shape
    assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
    # a single linear projection over the flattened features gives the value
    value = self.value_proj(value)
    if key_padding_mask is not None:
        # masked positions, i.e. the padded ones, should not contribute to the value
        value = value.masked_fill(key_padding_mask[..., None], 0.0)
    # reshape to (b, num_value, num_heads, emb)
    value = value.view(bs, num_value, self.num_heads, -1)
    # another linear layer predicts the sampling-point offsets directly from the query:
    # (b, num_query, num_heads, num_levels, num_points, 2)
    sampling_offsets = self.sampling_offsets(query).view(
        bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
    # a further fully-connected layer predicts the sampling-point weights from the query:
    # (b, num_query, num_heads, num_levels*num_points)
    # each query samples num_points (4 here) points per level, and attention_weights are their weights
    attention_weights = self.attention_weights(query).view(
        bs, num_query, self.num_heads, self.num_levels * self.num_points)
    # softmax over the weights
    attention_weights = attention_weights.softmax(-1)
    # reshape to (b, num_query, num_heads, num_levels, num_points)
    attention_weights = attention_weights.view(bs, num_query,
                                               self.num_heads,
                                               self.num_levels,
                                               self.num_points)
    if reference_points.shape[-1] == 2:
        # (w, h)
        offset_normalizer = torch.stack(
            [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
        # the core step: (b, num_query, num_heads, num_levels, num_points, 2)
        # see the illustrated explanation and the numeric check after this function
        sampling_locations = reference_points[:, :, None, :, None, :] \
            + sampling_offsets \
            / offset_normalizer[None, None, None, :, None, :]
        
    elif reference_points.shape[-1] == 4:
        sampling_locations = reference_points[:, :, None, :, None, :2] \
            + sampling_offsets / self.num_points \
            * reference_points[:, :, None, :, None, 2:] \
            * 0.5
    else:
        raise ValueError(
            f'Last dim of reference_points must be'
            f' 2 or 4, but get {reference_points.shape[-1]} instead.')
    if ((IS_CUDA_AVAILABLE and value.is_cuda)
            or (IS_MLU_AVAILABLE and value.is_mlu)):
        output = MultiScaleDeformableAttnFunction.apply(
            value, spatial_shapes, level_start_index, sampling_locations,
            attention_weights, self.im2col_step)
    else:
        # (b, num_query, num_heads*embed_dims)
        output = multi_scale_deformable_attn_pytorch(
            value, spatial_shapes, sampling_locations, attention_weights)
    # output projection (a linear layer)
    output = self.output_proj(output)

    if not self.batch_first:
        # (num_query, bs ,embed_dims)
        output = output.permute(1, 0, 2)
    # dropout plus the identity (residual) connection
    return self.dropout(output) + identity
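
A tiny numeric check of the sampling_locations formula for the 2-d reference-point case (the numbers are made up): offsets are predicted in units of feature cells and divided by the level size (w, h) to bring them into the same normalized space as the reference points.

import torch

reference_point = torch.tensor([0.5, 0.5])    # normalized (x, y) on the 72x98 level
sampling_offset = torch.tensor([1.2, -0.8])   # predicted (dx, dy) in feature cells
offset_normalizer = torch.tensor([98., 72.])  # (w, h) of that level

sampling_location = reference_point + sampling_offset / offset_normalizer
print(sampling_location)   # tensor([0.5122, 0.4889]), still a normalized (x, y) position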

4.2.2.1.1. multi_scale_deformable_attn_pytorch

The multi-scale deformable attention (pure-PyTorch version)

def multi_scale_deformable_attn_pytorch(
        value: torch.Tensor, value_spatial_shapes: torch.Tensor,
        sampling_locations: torch.Tensor,
        attention_weights: torch.Tensor) -> torch.Tensor:
    """CPU version of multi-scale deformable attention.

    Args:
        value (torch.Tensor): The value has shape
            (bs, num_keys, num_heads, embed_dims//num_heads)
        value_spatial_shapes (torch.Tensor): Spatial shape of
            each feature map, has shape (num_levels, 2),
            last dimension 2 represent (h, w)
        sampling_locations (torch.Tensor): The location of sampling points,
            has shape
            (bs ,num_queries, num_heads, num_levels, num_points, 2),
            the last dimension 2 represent (x, y).
        attention_weights (torch.Tensor): The weight of sampling points used
            when calculate the attention, has shape
            (bs ,num_queries, num_heads, num_levels, num_points),

    Returns:
        torch.Tensor: has shape (bs, num_queries, embed_dims)
    """

    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ =\
        sampling_locations.shape
    # [(bs, num_keys1, num_heads, embed_dims//num_heads), (bs, num_keys2, num_heads, embed_dims//num_heads), 
    #  (bs, num_keys3, num_heads, embed_dims//num_heads), (bs, num_keys4, num_heads, embed_dims//num_heads), ]
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
                             dim=1)
    # rescale sampling_locations from [0, 1] to [-1, 1], as required by F.grid_sample
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # bs, H_*W_, num_heads, embed_dims ->
        # bs, H_*W_, num_heads*embed_dims ->
        # bs, num_heads*embed_dims, H_*W_ ->
        # bs*num_heads, embed_dims, H_, W_
        value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(
            bs * num_heads, embed_dims, H_, W_)
        # bs, num_queries, num_heads, num_points, 2 ->
        # bs, num_heads, num_queries, num_points, 2 ->
        # bs*num_heads, num_queries, num_points, 2
        sampling_grid_l_ = sampling_grids[:, :, :,
                                          level].transpose(1, 2).flatten(0, 1)
        # bs*num_heads, embed_dims, num_queries, num_points
        sampling_value_l_ = F.grid_sample(
            value_l_,
            sampling_grid_l_,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (bs, num_queries, num_heads, num_levels, num_points) ->
    # (bs, num_heads, num_queries, num_levels, num_points) ->
    # (bs*num_heads, 1, num_queries, num_levels*num_points)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points)
    # weighted sum over all sampling points of all levels for every query
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
              attention_weights).sum(-1).view(bs, num_heads * embed_dims,
                                              num_queries)
    return output.transpose(1, 2).contiguous()

The figure below illustrates how output is computed; a toy shape check follows.
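
A toy call with made-up shapes (bs=1, 2 heads, 2 levels of size (4, 5) and (2, 3), 2 sampling points), just to see the input and output shapes; assuming mmcv is installed, the function imported below is the same pure-PyTorch implementation quoted above (the import path is the one used by recent mmcv versions).

import torch
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch

bs, num_heads, dims_per_head = 1, 2, 4
spatial_shapes = torch.tensor([[4, 5], [2, 3]])   # two toy levels
num_keys = int(spatial_shapes.prod(1).sum())      # 4*5 + 2*3 = 26
num_queries, num_levels, num_points = 26, 2, 2

value = torch.rand(bs, num_keys, num_heads, dims_per_head)
sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
# normalize so the weights of each query sum to 1 over all levels and points
attention_weights = attention_weights / attention_weights.sum((-2, -1), keepdim=True)

out = multi_scale_deformable_attn_pytorch(
    value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)   # torch.Size([1, 26, 8]) = (bs, num_queries, num_heads * dims_per_head)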

4.3. pre_decoder

Preparation work before the decoder. It includes:

Generating proposals and transforming memory: proposal coordinates are generated from the input memory; every feature point gets four values, output_proposals, representing (cx, cy, w, h). The input memory is also transformed, giving output_memory.

Classification branch and regression branch: output_memory goes through the classification branch to obtain a per-class score for every feature point, where the first channel is treated as the foreground score. output_memory also goes through the regression branch to obtain box offsets for every feature point; adding these offsets to output_proposals gives enc_outputs_coord_unact.

Selecting proposals: num_queries proposals are selected from enc_outputs_coord_unact.

Generating query and query_pos: the num_queries selected coordinates are positionally encoded and then passed through Linear and LayerNorm layers, giving 512-dimensional features; the first 256 dimensions serve as the query positional encoding and the last 256 dimensions as the query features.

tmp_dec_in, head_inputs_dict = self.pre_decoder(**encoder_outputs_dict)

Input

encoder_outputs_dict

  • memory: the reconstructed features (b, feature_nums, emb_size)
  • memory_mask: the feature mask (b, feature_nums)
  • spatial_shapes: the size of each feature level (4, 2)

Output

decoder_inputs_dict

  • query: 256-dimensional features obtained by passing the selected coordinates through Linear and LayerNorm layers;
  • query_pos: 256-dimensional features obtained from the same Linear and LayerNorm layers, used as the positional features;
  • memory: passed through unchanged;
  • reference_points: the num_queries selected coordinates squashed into the range 0-1 by a sigmoid.

head_inputs_dict

  • enc_outputs_class: the per-class score of every feature point (bs, num_queries, num_classes)
  • enc_outputs_coord: the proposal coordinates after the sigmoid (bs, num_queries, 4)

4.3.1. Two-stage

4.3.1.1. gen_encoder_output_proposals

Proposals are generated from the transformed query (memory): every query point regresses four values (cx, cy, w, h). These act as the initial proposals (kept in inverse-sigmoid space) to which the offsets predicted by the regression branch are later added.

output_memory, output_proposals = self.gen_encoder_output_proposals(memory, memory_mask, spatial_shapes)

def gen_encoder_output_proposals(
        self, memory: Tensor, memory_mask: Tensor,
        spatial_shapes: Tensor) -> Tuple[Tensor, Tensor]:
    """Generate proposals from encoded memory. The function will only be
    used when `as_two_stage` is `True`.

    Args:
        memory (Tensor): The output embeddings of the Transformer encoder,
            has shape (bs, num_feat_points, dim).
        memory_mask (Tensor): ByteTensor, the padding mask of the memory,
            has shape (bs, num_feat_points).
        spatial_shapes (Tensor): Spatial shapes of features in all levels,
            has shape (num_levels, 2), last dimension represents (h, w).

    Returns:
        tuple: A tuple of transformed memory and proposals.

        - output_memory (Tensor): The transformed memory for obtaining
          top-k proposals, has shape (bs, num_feat_points, dim).
        - output_proposals (Tensor): The inverse-normalized proposal, has
          shape (batch_size, num_keys, 4) with the last dimension arranged
          as (cx, cy, w, h).
    """

    bs = memory.size(0)
    proposals = []
    _cur = 0  # start index in the sequence of the current level
    for lvl, HW in enumerate(spatial_shapes):
        H, W = HW

        if memory_mask is not None:
            mask_flatten_ = memory_mask[:, _cur:(_cur + H * W)].view(
                bs, H, W, 1)
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0],
                                1).unsqueeze(-1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0],
                                1).unsqueeze(-1)
            scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2)
        else:
            if not isinstance(HW, torch.Tensor):
                HW = memory.new_tensor(HW)
            scale = HW.unsqueeze(0).flip(dims=[0, 1]).view(1, 1, 1, 2)
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(
                0, H - 1, H, dtype=torch.float32, device=memory.device),
            torch.linspace(
                0, W - 1, W, dtype=torch.float32, device=memory.device))
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
        grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale
        wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
        proposal = torch.cat((grid, wh), -1).view(bs, -1, 4)
        proposals.append(proposal)
        _cur += (H * W)
    output_proposals = torch.cat(proposals, 1)
    # do not use `all` to make it exportable to onnx
    output_proposals_valid = (
        (output_proposals > 0.01) & (output_proposals < 0.99)).sum(
            -1, keepdim=True) == output_proposals.shape[-1]
    # inverse_sigmoid
    output_proposals = torch.log(output_proposals / (1 - output_proposals))
    if memory_mask is not None:
        output_proposals = output_proposals.masked_fill(
            memory_mask.unsqueeze(-1), float('inf'))
    output_proposals = output_proposals.masked_fill(
        ~output_proposals_valid, float('inf'))

    output_memory = memory
    if memory_mask is not None:
        output_memory = output_memory.masked_fill(
            memory_mask.unsqueeze(-1), float(0))
    output_memory = output_memory.masked_fill(~output_proposals_valid,
                                              float(0))
    output_memory = self.memory_trans_fc(output_memory)
    output_memory = self.memory_trans_norm(output_memory)
    # [bs, sum(hw), 2]
    return output_memory, output_proposals

In the code above, output_proposals = torch.log(output_proposals / (1 - output_proposals)) is the inverse sigmoid (logit); its graph is shown below. The closer the input gets to 0 or 1, the closer the output gets to negative or positive infinity.
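
A quick check of this behaviour:

import torch

p = torch.tensor([0.001, 0.01, 0.5, 0.99, 0.999])
print(torch.log(p / (1 - p)))
# tensor([-6.9068, -4.5951,  0.0000,  4.5951,  6.9068]), exploding towards ±inf at 0 and 1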

The following example code is used to illustrate, with a drawing, how output_proposals is constructed.

import torch
from torch import nn

torch.manual_seed(42)
memory = torch.rand((1, 26, 8))
memory_mask = torch.tensor([[False, False, False, False, True,
                             False, False, False, False, True,
                             False, False, False, False, True,
                             False, False, False, False, True,
                             False, False, True,
                             False, False, True]])
spatial_shapes = torch.tensor([[4, 5], [2, 3]])

memory_trans_fc = nn.Linear(8, 8)
memory_trans_norm = nn.LayerNorm(8)


def gen_encoder_output_proposals(memory, memory_mask, spatial_shapes):
    bs = memory.size(0)
    proposals = []
    _cur = 0  # start index in the sequence of the current level
    for lvl, HW in enumerate(spatial_shapes):
        H, W = HW

        if memory_mask is not None:
            mask_flatten_ = memory_mask[:, _cur:(_cur + H * W)].view(
                bs, H, W, 1)
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0],
                                1).unsqueeze(-1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0],
                                1).unsqueeze(-1)
            scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2)
        else:
            if not isinstance(HW, torch.Tensor):
                HW = memory.new_tensor(HW)
            scale = HW.unsqueeze(0).flip(dims=[0, 1]).view(1, 1, 1, 2)
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(
                0, H - 1, H, dtype=torch.float32, device=memory.device),
            torch.linspace(
                0, W - 1, W, dtype=torch.float32, device=memory.device))
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
        grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale
        wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
        proposal = torch.cat((grid, wh), -1).view(bs, -1, 4)
        proposals.append(proposal)
        _cur += (H * W)
    output_proposals = torch.cat(proposals, 1)
    # do not use `all` to make it exportable to onnx
    output_proposals_valid = (
                                     (output_proposals > 0.01) & (output_proposals < 0.99)).sum(
        -1, keepdim=True) == output_proposals.shape[-1]
    # inverse_sigmoid
    output_proposals = torch.log(output_proposals / (1 - output_proposals))
    if memory_mask is not None:
        output_proposals = output_proposals.masked_fill(
            memory_mask.unsqueeze(-1), float('inf'))
    output_proposals = output_proposals.masked_fill(
        ~output_proposals_valid, float('inf'))

    output_memory = memory
    # two filtering steps: first fill the masked (padded) positions with 0
    if memory_mask is not None:
        output_memory = output_memory.masked_fill(
            memory_mask.unsqueeze(-1), float(0))
    # then fill the positions whose proposals fall outside (0.01, 0.99) with 0
    output_memory = output_memory.masked_fill(~output_proposals_valid,
                                              float(0))
    # linear layer
    output_memory = memory_trans_fc(output_memory)
    # norm layer
    output_memory = memory_trans_norm(output_memory)
    # [bs, sum(hw), 2]
    return output_memory, output_proposals


output_memory, output_proposals = gen_encoder_output_proposals(memory, memory_mask, spatial_shapes)

output_memory

output_proposals

4.3.1.2. cls_branches

Classification branch: a linear layer produces, for every query point, a score for each class.

# (bs, num_query, num_class), e.g. (2, 120, 80)
enc_outputs_class = self.bbox_head.cls_branches[
                self.decoder.num_layers](
                    output_memory)

4.3.1.3. reg_branches

Regression branch: a linear layer produces the four coordinate values of every query.

# (bs, num_query, 4), e.g. (2, 120, 4)
enc_outputs_coord_unact = self.bbox_head.reg_branches[
    self.decoder.num_layers](output_memory) + output_proposals
# the result after the sigmoid activation, which squashes the coordinates into the range 0-1
enc_outputs_coord = enc_outputs_coord_unact.sigmoid()

4.3.1.4. Selecting proposals

The first channel of enc_outputs_class is treated as the foreground score, so the top num_queries proposals are selected according to this first channel.

# (bs, num_queries), e.g. (2, 100)
topk_proposals = torch.topk(
    enc_outputs_class[..., 0], self.num_queries, dim=1)[1]
# then gather, from enc_outputs_coord_unact, the coordinates at the selected indices
topk_coords_unact = torch.gather(
    enc_outputs_coord_unact, 1,
    topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
# (bs, num_queries, 4), e.g. (2, 100, 4)
topk_coords_unact = topk_coords_unact.detach()

4.3.1.5. reference_points

The num_queries selected coordinates are passed through a sigmoid to squash them into the range 0-1, giving reference_points. These reference_points play the same role as in the encoder stage: they give the relative position of each query on every feature level.

# (bs, num_queries, 4)
reference_points = topk_coords_unact.sigmoid()

4.3.1.6. query and query_pos

The num_queries selected coordinates are positionally encoded and then passed through Linear and LayerNorm layers, giving 512-dimensional features. The first 256 dimensions are used as the query positional encoding and the last 256 dimensions as the query features (a shape sketch follows the code).

pos_trans_out = self.pos_trans_fc(
    self.get_proposal_pos_embed(topk_coords_unact))
pos_trans_out = self.pos_trans_norm(pos_trans_out)
query_pos, query = torch.split(pos_trans_out, c, dim=2)
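
Shape-wise, the flow is roughly the following (a sketch with stand-in layers: pos_trans_fc and pos_trans_norm are assumed to operate on 2*c = 512 features, and the 512-d positional embedding of the (cx, cy, w, h) proposals is faked with random numbers):

import torch
from torch import nn

bs, num_queries, c = 2, 100, 256

pos_trans_fc = nn.Linear(2 * c, 2 * c)    # stand-in for self.pos_trans_fc
pos_trans_norm = nn.LayerNorm(2 * c)      # stand-in for self.pos_trans_norm

# stand-in for get_proposal_pos_embed(topk_coords_unact): a 512-d embedding per proposal
proposal_pos_embed = torch.rand(bs, num_queries, 2 * c)

pos_trans_out = pos_trans_norm(pos_trans_fc(proposal_pos_embed))
query_pos, query = torch.split(pos_trans_out, c, dim=2)
print(query_pos.shape, query.shape)   # torch.Size([2, 100, 256]) torch.Size([2, 100, 256])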

4.4. forward_decoder

decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)

4.4.1. self_attn

The call looks the same as in the encoder stage, but here self_attn is a standard MultiheadAttention and the query consists of the N selected proposals. In other words, attention is computed among these N proposals.

query = self.self_attn(
    query=query,
    key=query,
    value=query,
    query_pos=query_pos,
    key_pos=query_pos,
    attn_mask=self_attn_mask,
    **kwargs)

def forward(self,
            query,
            key=None,
            value=None,
            identity=None,
            query_pos=None,
            key_pos=None,
            attn_mask=None,
            key_padding_mask=None,
            **kwargs):
    """Forward function for `MultiheadAttention`.

    **kwargs allow passing a more general data flow when combining
    with other operations in `transformerlayer`.

    Args:
        query (Tensor): The input query with shape [num_queries, bs,
            embed_dims] if self.batch_first is False, else
            [bs, num_queries embed_dims].
        key (Tensor): The key tensor with shape [num_keys, bs,
            embed_dims] if self.batch_first is False, else
            [bs, num_keys, embed_dims] .
            If None, the ``query`` will be used. Defaults to None.
        value (Tensor): The value tensor with same shape as `key`.
            Same in `nn.MultiheadAttention.forward`. Defaults to None.
            If None, the `key` will be used.
        identity (Tensor): This tensor, with the same shape as x,
            will be used for the identity link.
            If None, `x` will be used. Defaults to None.
        query_pos (Tensor): The positional encoding for query, with
            the same shape as `x`. If not None, it will
            be added to `x` before forward function. Defaults to None.
        key_pos (Tensor): The positional encoding for `key`, with the
            same shape as `key`. Defaults to None. If not None, it will
            be added to `key` before forward function. If None, and
            `query_pos` has the same shape as `key`, then `query_pos`
            will be used for `key_pos`. Defaults to None.
        attn_mask (Tensor): ByteTensor mask with shape [num_queries,
            num_keys]. Same in `nn.MultiheadAttention.forward`.
            Defaults to None.
        key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
            Defaults to None.

    Returns:
        Tensor: forwarded results with shape
        [num_queries, bs, embed_dims]
        if self.batch_first is False, else
        [bs, num_queries embed_dims].
    """

    if key is None:
        key = query
    if value is None:
        value = key
    if identity is None:
        identity = query
    if key_pos is None:
        if query_pos is not None:
            # use query_pos if key_pos is not available
            if query_pos.shape == key.shape:
                key_pos = query_pos
            else:
                warnings.warn(f'position encoding of key is'
                              f'missing in {self.__class__.__name__}.')
    if query_pos is not None:
        query = query + query_pos
    if key_pos is not None:
        key = key + key_pos

    # Because the dataflow('key', 'query', 'value') of
    # ``torch.nn.MultiheadAttention`` is (num_query, batch,
    # embed_dims), We should adjust the shape of dataflow from
    # batch_first (batch, num_query, embed_dims) to num_query_first
    # (num_query ,batch, embed_dims), and recover ``attn_output``
    # from num_query_first to batch_first.
    if self.batch_first:
        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        value = value.transpose(0, 1)

    out = self.attn(
        query=query,
        key=key,
        value=value,
        attn_mask=attn_mask,
        key_padding_mask=key_padding_mask)[0]

    if self.batch_first:
        out = out.transpose(0, 1)

    return identity + self.dropout_layer(self.proj_drop(out))

Stepping into self.attn, the self-attention is implemented with PyTorch's built-in multi_head_attention_forward function:

# attn_output(300, 2, 256)
# attn_output_weights(2, 300, 300)
attn_output, attn_output_weights = F.multi_head_attention_forward(
    query, key, value, self.embed_dim, self.num_heads,
    self.in_proj_weight, self.in_proj_bias,
    self.bias_k, self.bias_v, self.add_zero_attn,
    self.dropout, self.out_proj.weight, self.out_proj.bias,
    training=self.training,
    key_padding_mask=key_padding_mask, need_weights=need_weights,
    attn_mask=attn_mask, average_attn_weights=average_attn_weights)

4.4.2. cross_attn

query = self.cross_attn(
    query=query,
    key=key,
    value=value,
    query_pos=query_pos,
    key_pos=key_pos,
    attn_mask=cross_attn_mask,
    key_padding_mask=key_padding_mask,
    **kwargs)

Here, query is the query coming out of the decoder's self_attn, while value is the encoder output, i.e. the memory produced by the encoder stage, hence cross-attention. The code is the same MultiScaleDeformableAttention as in the encoder stage. A sketch of the full per-layer ordering follows.
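
Putting it together, one decoder layer runs roughly in this order (a pseudocode-level sketch; the names are illustrative and not the exact mmdet API):

# a minimal sketch of the per-layer ordering in the deformable DETR decoder
def decoder_layer(query, query_pos, memory, self_attn, cross_attn, ffn, norms, **attn_kwargs):
    # 1) self-attention among the num_queries object queries
    query = self_attn(query=query, key=query, value=query,
                      query_pos=query_pos, key_pos=query_pos)
    query = norms[0](query)
    # 2) deformable cross-attention: the queries attend to the encoder memory at
    #    sampling locations predicted around their reference points
    query = cross_attn(query=query, value=memory, query_pos=query_pos, **attn_kwargs)
    query = norms[1](query)
    # 3) feed-forward network plus a final norm
    query = norms[2](ffn(query))
    return query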

5. Loss

Debugging this code is exhausting; time for a short break before tackling the loss part...
