Paper: https://arxiv.org/pdf/2010.04159.pdf
Model code: https://github.com/open-mmlab/mmdetection/blob/main/mmdet/models/detectors/deformable_detr.py
Preface: plenty of people have already explained the principles and source code of Deformable DETR, but reading about it only goes so far. You understand it far more thoroughly by stepping through the code line by line in a debugger. Hence this post, which also doubles as my own review notes.
1. Data preprocessing
mmdet/models/data_preprocessors/data_preprocessor.py
def forward(self, data: dict, training: bool = False) -> dict:
"""Perform normalization, padding and bgr2rgb conversion based on
``BaseDataPreprocessor``.
Args:
data (dict): Data sampled from dataloader.
training (bool): Whether to enable training time augmentation.
Returns:
dict: Data in the same format as the model input. (The images are unified to a common shape.)
"""
batch_pad_shape = self._get_pad_shape(data) # [(672, 807), (800, 800)]
data = super().forward(data=data, training=training) #
inputs, data_samples = data['inputs'], data['data_samples'] # inputs(2, 3, 800, 807)
if data_samples is not None:
# NOTE the batched image size information may be useful, e.g.
# in DETR, this is needed for the construction of masks, which is
# then used for the transformer_head.
batch_input_shape = tuple(inputs[0].size()[-2:]) # (596, 652)
for data_sample, pad_shape in zip(data_samples, batch_pad_shape):
data_sample.set_metainfo({
'batch_input_shape': batch_input_shape,
'pad_shape': pad_shape
})
if self.boxtype2tensor:
samplelist_boxtype2tensor(data_samples)
if self.pad_mask and training:
self.pad_gt_masks(data_samples)
if self.pad_seg and training:
self.pad_gt_sem_seg(data_samples)
if training and self.batch_augments is not None:
for batch_aug in self.batch_augments:
inputs, data_samples = batch_aug(inputs, data_samples)
return {'inputs': inputs, 'data_samples': data_samples}
Input: the data of one batch (a dict).
The structure of data: inputs is a list holding the batch's image tensors; data_samples is also a list (can be ignored here).
Processing: F.pad
Output: the processed data (still a dict).
The images in a batch are unified to the same size: for example, two images of shapes (3, 576, 768) and (3, 544, 778) become (3, 576, 778), as shown in the figure below:
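The batching step can be sketched with plain F.pad. This is a toy illustration using the shapes from the example above, not the actual mmengine implementation:

```python
import torch
import torch.nn.functional as F

# Toy sketch: pad every image up to the per-batch maximum H and W, then stack.
imgs = [torch.rand(3, 576, 768), torch.rand(3, 544, 778)]
max_h = max(img.shape[-2] for img in imgs)  # 576
max_w = max(img.shape[-1] for img in imgs)  # 778

# F.pad takes (left, right, top, bottom) for the last two dims.
batch = torch.stack([
    F.pad(img, (0, max_w - img.shape[-1], 0, max_h - img.shape[-2]))
    for img in imgs
])
print(batch.shape)  # torch.Size([2, 3, 576, 778])
```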
2. Backbone
The backbone here is ResNet-50. Its architecture is standard and not the focus of this post, so I won't go through it.
mmdet/models/detectors/base_detr.py
def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
"""Extract features.
Args:
batch_inputs (Tensor): Image tensor, has shape (bs, dim, H, W).
Returns:
tuple[Tensor]: Tuple of feature maps from neck. Each feature map
has shape (bs, dim, H, W).
"""
x = self.backbone(batch_inputs)
if self.with_neck:
x = self.neck(x)
return x
Input: (2, 3, 576, 778)
Processing: resnet50
Output: [(2, 512, 72, 98), (2, 1024, 36, 49), (2, 2048, 18, 25)]
3. Neck
Input: [(2, 512, 72, 98), (2, 1024, 36, 49), (2, 2048, 18, 25)]
Processing: convolutions (with norm/activation)
Output: ((2, 256, 72, 98), (2, 256, 36, 49), (2, 256, 18, 25), (2, 256, 9, 13))
mmdet/models/necks/channel_mapper.py
def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
"""Forward function."""
assert len(inputs) == len(self.convs)
outs = [self.convs[i](inputs[i]) for i in range(len(inputs))]
if self.extra_convs:
for i in range(len(self.extra_convs)):
if i == 0:
outs.append(self.extra_convs[0](inputs[-1]))
else:
outs.append(self.extra_convs[i](outs[-1]))
return tuple(outs)
outs = [self.convs[i](inputs[i]) for i in range(len(inputs))]
The spatial sizes are unchanged; the channel dimension becomes 256. The result is shown below:
After self.extra_convs, an extra feature map of size (9, 13) is appended. The result is shown below:
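A minimal sketch of what ChannelMapper computes (the layer shapes here are assumptions derived from the shapes above; the real mmdet module wraps the convs in ConvModule with GroupNorm):

```python
import torch
from torch import nn

# 1x1 convs map each backbone level to 256 channels; an extra stride-2 3x3
# conv on the last input produces the additional (9, 13) level.
inputs = [torch.rand(2, c, h, w)
          for c, (h, w) in zip([512, 1024, 2048],
                               [(72, 98), (36, 49), (18, 25)])]
convs = nn.ModuleList([nn.Conv2d(c, 256, kernel_size=1)
                       for c in [512, 1024, 2048]])
extra_conv = nn.Conv2d(2048, 256, kernel_size=3, stride=2, padding=1)

outs = [conv(x) for conv, x in zip(convs, inputs)]
outs.append(extra_conv(inputs[-1]))
print([tuple(o.shape) for o in outs])
# [(2, 256, 72, 98), (2, 256, 36, 49), (2, 256, 18, 25), (2, 256, 9, 13)]
```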
4. Head (the core)
def forward_transformer(self,
img_feats: Tuple[Tensor],
batch_data_samples: OptSampleList = None) -> Dict:
encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
img_feats, batch_data_samples)
encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict)
tmp_dec_in, head_inputs_dict = self.pre_decoder(**encoder_outputs_dict)
decoder_inputs_dict.update(tmp_dec_in)
decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)
head_inputs_dict.update(decoder_outputs_dict)
return head_inputs_dict
4.1. pre_transformer
Preparation before the transformer, for example:
- reshape img_feats ([(batch_size, emb_size, h1, w1), (batch_size, emb_size, h2, w2), (batch_size, emb_size, h3, w3), (batch_size, emb_size, h4, w4)]) into (batch_size, feature_nums, emb_size);
- record which positions on each feature map must be masked, i.e. excluded from computation;
- compute the positional encoding of every position on every feature map;
- record the size of each feature map;
- record the start index of each feature map;
- record the ratio of valid feature points on each feature map.
encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(img_feats, batch_data_samples)
Input: img_feats holds the four feature levels; batch_data_samples holds the per-image information of the current batch.
Output:
- encoder_inputs_dict
  - feat (batch_size, feature_nums, emb_size): feature values
  - feat_mask (batch_size, feature_nums): feature mask, i.e. which positions to ignore because of image padding
  - feat_pos (batch_size, feature_nums, emb_size): positional encoding of each feature point
  - spatial_shapes (feature_levels, 2): size of each feature level
  - level_start_index (feature_levels,): start index of each feature level
  - valid_ratios (batch_size, feature_levels, 2): ratio of valid feature points on each level
- decoder_inputs_dict
  - memory_mask (batch_size, feature_nums): feature mask
  - spatial_shapes (feature_levels, 2): size of each feature level
  - level_start_index (feature_levels,): start index of each feature level
  - valid_ratios (batch_size, feature_levels, 2): ratio of valid feature points on each level
Summary: pre_transformer does the bookkeeping needed before the transformer runs.
masks
masks = mlvl_feats[0].new_ones((batch_size, input_img_h, input_img_w))
for img_id in range(batch_size):
img_h, img_w = img_shape_list[img_id]
masks[img_id, :img_h, :img_w] = 0
The mask has the unified batch size, here (576, 778); the region covered by each image's actual size is filled with 0. The actual sizes here are (576, 768) and (544, 778).
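A toy version of the snippet above, with the sizes shrunk for readability (valid pixels get 0, padding stays 1):

```python
import torch

# Toy sizes: batch canvas 6x8; actual image sizes (6, 7) and (5, 8).
batch_input_h, batch_input_w = 6, 8
img_shape_list = [(6, 7), (5, 8)]

masks = torch.ones(2, batch_input_h, batch_input_w)
for img_id, (img_h, img_w) in enumerate(img_shape_list):
    masks[img_id, :img_h, :img_w] = 0  # valid region -> 0, padding stays 1

print(masks[0])  # only the last column of image 0 is padding
```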
mlvl_masks, mlvl_pos_embeds
mlvl_masks = []
mlvl_pos_embeds = []
for feat in mlvl_feats:
mlvl_masks.append(
F.interpolate(masks[None], size=feat.shape[-2:]).to(
torch.bool).squeeze(0))
mlvl_masks: F.interpolate (nearest-neighbor by default) resizes the (576, 778) mask down to each feature-map size: (72, 98), (36, 49), (18, 25), (9, 13)
mlvl_pos_embeds: stores the positional encodings; each position is encoded as a 256-dim vector
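The mask downsampling can be checked in isolation (toy sizes; note that F.interpolate defaults to nearest-neighbor for this call):

```python
import torch
import torch.nn.functional as F

# Downsample a (H, W) padding mask to each feature-map resolution.
masks = torch.zeros(2, 6, 8)
masks[0, :, 7:] = 1  # the padded column of image 0

feat_sizes = [(3, 4), (2, 2)]
mlvl_masks = [
    F.interpolate(masks[None], size=size).to(torch.bool).squeeze(0)
    for size in feat_sizes
]
print([tuple(m.shape) for m in mlvl_masks])  # [(2, 3, 4), (2, 2, 2)]
```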
# flatten the spatial dims of each tensor
feat_flatten = []
lvl_pos_embed_flatten = []
mask_flatten = []
spatial_shapes = []
for lvl, (feat, mask, pos_embed) in enumerate(
zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
batch_size, c, h, w = feat.shape
spatial_shape = torch._shape_as_tensor(feat)[2:].to(feat.device)
# [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c]
feat = feat.view(batch_size, c, -1).permute(0, 2, 1) # (2, h*w, 256)
pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1) # (2, h*w, 256)
# self.level_embed: (4, 256); each feature level also has its own embedding marking which map it is
# per-position positional encoding + per-level embedding
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) # (2, h*w, 256)
# [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl]
if mask is not None:
mask = mask.flatten(1) # (2, h*w)
feat_flatten.append(feat)
lvl_pos_embed_flatten.append(lvl_pos_embed)
mask_flatten.append(mask)
spatial_shapes.append(spatial_shape)
# (bs, num_feat_points, dim)
feat_flatten = torch.cat(feat_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
# (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl)
if mask_flatten[0] is not None:
mask_flatten = torch.cat(mask_flatten, 1)
else:
mask_flatten = None
# (num_level, 2)
spatial_shapes = torch.cat(spatial_shapes).view(-1, 2)
level_start_index = torch.cat((
spatial_shapes.new_zeros((1, )), # (num_level)
spatial_shapes.prod(1).cumsum(0)[:-1]))
if mlvl_masks[0] is not None:
valid_ratios = torch.stack( # (bs, num_level, 2)
[self.get_valid_ratio(m) for m in mlvl_masks], 1)
else:
valid_ratios = mlvl_feats[0].new_ones(batch_size, len(mlvl_feats),
2)
encoder_inputs_dict = dict(
feat=feat_flatten,
feat_mask=mask_flatten,
feat_pos=lvl_pos_embed_flatten,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
valid_ratios=valid_ratios)
decoder_inputs_dict = dict(
memory_mask=mask_flatten,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
valid_ratios=valid_ratios)
return encoder_inputs_dict, decoder_inputs_dict
In the code above, self.level_embed has shape (4, 256): each feature level gets a 256-dim level embedding that identifies which level a feature belongs to.
feat_flatten = torch.cat(feat_flatten, 1)
has shape (2, 9387, 256), where 9387 = 72*98 + 36*49 + 18*25 + 9*13; the four feature levels are concatenated into one long sequence, as shown in the figure below:
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
has shape (2, 9387, 256), where 9387 = 72*98 + 36*49 + 18*25 + 9*13; it stores the positional encodings plus the level embeddings, as shown in the figure below:
mask_flatten = torch.cat(mask_flatten, 1)
has shape (2, 9387) and stores the mask over every feature map; 2 is the batch_size. In the figure, the colored positions are the masked ones.
spatial_shapes = torch.cat(spatial_shapes).view(-1, 2)
has shape (4, 2) and stores each level's spatial size: [72, 98], [36, 49], [18, 25], [9, 13].
level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
stores the start index of each level: [0, 7056, 8820, 9270].
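This is easy to verify by hand; the start indices are an exclusive prefix sum of the per-level sizes:

```python
import torch

spatial_shapes = torch.tensor([[72, 98], [36, 49], [18, 25], [9, 13]])
level_start_index = torch.cat((
    spatial_shapes.new_zeros((1,)),          # level 0 starts at index 0
    spatial_shapes.prod(1).cumsum(0)[:-1]))  # cumulative h*w of earlier levels
print(level_start_index)  # tensor([   0, 7056, 8820, 9270])
```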
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1) has shape (2, 4, 2): the fraction of valid positions on each feature map, as shown in the figure below:
Take (0.98, 1) as an example: on the (72*98) feature map of the first image in the batch, the valid fraction of h is 0.98 and the valid fraction of w is 1. What is this used for? Explained in detail later.
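get_valid_ratio can be sketched as follows; it mirrors the mmdet helper, counting the non-padded rows and columns of one level's mask:

```python
import torch

# mask: (bs, H, W), True = padded. Returns (bs, 2) in (w_ratio, h_ratio) order.
def get_valid_ratio(mask):
    _, H, W = mask.shape
    valid_H = torch.sum(~mask[:, :, 0], 1)  # valid rows, read off column 0
    valid_W = torch.sum(~mask[:, 0, :], 1)  # valid cols, read off row 0
    return torch.stack([valid_W / W, valid_H / H], -1)

mask = torch.zeros(1, 10, 10, dtype=torch.bool)
mask[0, 8:, :] = True  # bottom two rows are padding
mask[0, :, 9:] = True  # rightmost column is padding
print(get_valid_ratio(mask))  # tensor([[0.9000, 0.8000]])
```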
4.2. forward_encoder
What is the encoder actually doing? After many debugging passes through the source, here is a one-sentence summary: it reconstructs the input features by combining information across the feature levels. Each feature point aggregates information from all 4 levels; on each level, that information comes from 4 points sampled around its reference location, and each sampled point contributes with its own learned weight. Stacking several such layers completes the encoder.
Input:
encoder_inputs_dict
- feat (batch_size, feature_nums, emb_size): feature values
- feat_mask (batch_size, feature_nums): feature mask, i.e. which positions to ignore because of image padding
- feat_pos (batch_size, feature_nums, emb_size): positional encoding of each feature point
- spatial_shapes (feature_levels, 2): size of each feature level
- level_start_index (feature_levels,): start index of each feature level
- valid_ratios (batch_size, feature_levels, 2): ratio of valid feature points on each level
Output:
encoder_outputs_dict
- memory: the result of reconstructing the input features; the shape is unchanged (batch_size, feature_nums, emb_size)
- feat_mask (batch_size, feature_nums): returned untouched
- spatial_shapes (feature_levels, 2): returned untouched
mmdet/models/detectors/base_detr.py
encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict)
mmdet/models/detectors/deformable_detr.py
def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
feat_pos: Tensor, spatial_shapes: Tensor,
level_start_index: Tensor,
valid_ratios: Tensor) -> Dict:
# memory (b, num_value, num_heads*emb), e.g. (2, 120, 256)
memory = self.encoder(
query=feat,
query_pos=feat_pos,
key_padding_mask=feat_mask, # for self_attn
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
valid_ratios=valid_ratios)
encoder_outputs_dict = dict(
memory=memory,
memory_mask=feat_mask,
spatial_shapes=spatial_shapes)
return encoder_outputs_dict
mmdet/models/layers/transformer/deformable_detr_layers.py
def forward(self, query: Tensor, query_pos: Tensor,
key_padding_mask: Tensor, spatial_shapes: Tensor,
level_start_index: Tensor, valid_ratios: Tensor,
**kwargs) -> Tensor:
reference_points = self.get_encoder_reference_points(
spatial_shapes, valid_ratios, device=query.device)
for layer in self.layers:
query = layer(
query=query,
query_pos=query_pos,
key_padding_mask=key_padding_mask,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
reference_points=reference_points,
**kwargs)
return query
4.2.1. reference_points
Input: spatial_shapes, valid_ratios
Output: reference_points (batch_size, num_features, 4, 2)
reference_points: every feature point gets a position relative to each of the four feature levels.
A short snippet to illustrate what reference_points encodes:
# assume batch_size = 1
reference_points = torch.tensor(
[[[0.1, 0.1], [0.3, 0.1], [0.5, 0.1], [0.7, 0.1], [0.9, 0.1],
[0.1, 0.3], [0.3, 0.3], [0.5, 0.3], [0.7, 0.3], [0.9, 0.3],
[0.1, 0.5], [0.3, 0.5], [0.5, 0.5], [0.7, 0.5], [0.9, 0.5],
[0.1, 0.7], [0.3, 0.7], [0.5, 0.7], [0.7, 0.7], [0.9, 0.7],
[0.1, 0.9], [0.3, 0.9], [0.5, 0.9], [0.7, 0.9], [0.9, 0.9],
[0.12, 0.12], [0.37, 0.12], [0.62, 0.12], [0.87, 0.12],
[0.12, 0.37], [0.37, 0.37], [0.62, 0.37], [0.87, 0.37],
[0.12, 0.62], [0.37, 0.62], [0.62, 0.62], [0.87, 0.62],
[0.12, 0.87], [0.37, 0.87], [0.62, 0.87], [0.87, 0.87],
[0.16, 0.16], [0.50, 0.16], [0.83, 0.16],
[0.16, 0.50], [0.50, 0.50], [0.83, 0.50],
[0.16, 0.83], [0.50, 0.83], [0.83, 0.83],
[0.25, 0.25], [0.75, 0.25],
[0.25, 0.75], [0.75, 0.75]]])
# with only one image in the batch, the valid_ratio of h and w is 1 on every level
valid_ratios = torch.tensor([[[1, 1], [1, 1], [1, 1], [1, 1]]])
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
print(reference_points)
Taking the point (0.1, 0.1) as an example, its mapped position on each feature level is shown in the figure below:
mmdet/models/layers/transformer/deformable_detr_layers.py
def get_encoder_reference_points(
spatial_shapes: Tensor, valid_ratios: Tensor,
device: Union[torch.device, str]) -> Tensor:
reference_points_list = []
# spatial_shapes:(72, 98), (36, 49), (18, 25), (9, 13)
for lvl, (H, W) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(
torch.linspace(
0.5, H - 0.5, H, dtype=torch.float32, device=device),
torch.linspace(
0.5, W - 0.5, W, dtype=torch.float32, device=device))
ref_y = ref_y.reshape(-1)[None] / (
valid_ratios[:, None, lvl, 1] * H)
ref_x = ref_x.reshape(-1)[None] / (
valid_ratios[:, None, lvl, 0] * W)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
# [bs, sum(hw), num_level, 2]
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
In the code above,
ref_y, ref_x = torch.meshgrid(
torch.linspace(
0.5, H - 0.5, H, dtype=torch.float32, device=device),
torch.linspace(
0.5, W - 0.5, W, dtype=torch.float32, device=device))
the ref_y, ref_x produced for the four spatial_shapes are shown below:
Normalize the coordinates:
ref_y = ref_y.reshape(-1)[None] / (
valid_ratios[:, None, lvl, 1] * H)
ref_x = ref_x.reshape(-1)[None] / (
valid_ratios[:, None, lvl, 0] * W)
Stack the x and y coordinates:
ref = torch.stack((ref_x, ref_y), -1)
Map each reference point onto every feature level:
reference_points = torch.cat(reference_points_list, 1)
# [bs, sum(hw), num_level, 2]
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
Below, two toy levels of sizes (4, 5) and (2, 3) are used to illustrate how reference_points is built.
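The same construction, runnable at toy scale (one fully valid image, so valid_ratios is all ones):

```python
import torch

# Reference points for two toy levels, (4, 5) and (2, 3).
spatial_shapes = torch.tensor([[4, 5], [2, 3]])
valid_ratios = torch.ones(1, 2, 2)  # (bs, num_levels, 2)

reference_points_list = []
for lvl, (H, W) in enumerate(spatial_shapes):
    H, W = int(H), int(W)
    ref_y, ref_x = torch.meshgrid(
        torch.linspace(0.5, H - 0.5, H),
        torch.linspace(0.5, W - 0.5, W),
        indexing='ij')
    ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H)
    ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W)
    reference_points_list.append(torch.stack((ref_x, ref_y), -1))

reference_points = torch.cat(reference_points_list, 1)             # (1, 26, 2)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
print(reference_points.shape)  # torch.Size([1, 26, 2, 2])
print(reference_points[0, 0])  # first point: (0.1, 0.125) on both levels
```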
4.2.2. DeformableDetrTransformerEncoderLayer
It contains the following layers:
- (self_attn): MultiScaleDeformableAttention(
- (dropout): Dropout(p=0.1, inplace=False)
- (sampling_offsets): Linear(in_features=256, out_features=256, bias=True)
- (attention_weights): Linear(in_features=256, out_features=128, bias=True)
- (value_proj): Linear(in_features=256, out_features=256, bias=True)
- (output_proj): Linear(in_features=256, out_features=256, bias=True)
4.2.2.1. MultiScaleDeformableAttention (self_attn)
Input:
- query=query, # the query matrix (total feature count, feature dim)
- query_pos=query_pos, # the positional matrix (total feature count, feature dim)
- key_padding_mask=key_padding_mask, # the mask matrix (2, total feature count): which points skip attention
- spatial_shapes=spatial_shapes, # spatial shapes at the different scales
- level_start_index=level_start_index, # start index of each level
- valid_ratios=valid_ratios, # valid ratio of each feature map
- reference_points=reference_points, # reference-point coordinates of every feature point
Output:
query: the query after fusing the multi-level features; the feature dim is unchanged.
query = self.self_attn(
query=query,
key=query,
value=query,
query_pos=query_pos,
key_pos=query_pos,
key_padding_mask=key_padding_mask,
**kwargs)
def forward(self,
query: torch.Tensor, # query
key: Optional[torch.Tensor] = None, # query
value: Optional[torch.Tensor] = None, # query
identity: Optional[torch.Tensor] = None,
query_pos: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
reference_points: Optional[torch.Tensor] = None,
spatial_shapes: Optional[torch.Tensor] = None,
level_start_index: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
"""Forward Function of MultiScaleDeformAttention.
Args:
query (torch.Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (torch.Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (torch.Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`.
identity (torch.Tensor): The tensor used for addition, with the
same shape as `query`. Default None. If None,
`query` will be used.
query_pos (torch.Tensor): The positional encoding for `query`.
Default: None.
key_padding_mask (torch.Tensor): ByteTensor for `query`, with
shape [bs, num_key].
reference_points (torch.Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
spatial_shapes (torch.Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
level_start_index (torch.Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
torch.Tensor: forwarded results with shape
[num_query, bs, embed_dims].
"""
if value is None:
value = query
if identity is None:
identity = query
if query_pos is not None:
query = query + query_pos
if not self.batch_first:
# change to (bs, num_query ,embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
# a single linear projection is all it takes to get value -- surprisingly simple
value = self.value_proj(value)
if key_padding_mask is not None:
# masked (i.e. padded) positions need no value, so zero them out
value = value.masked_fill(key_padding_mask[..., None], 0.0)
# reshape to (b, num_value, num_heads, emb)
value = value.view(bs, num_value, self.num_heads, -1)
# another linear layer predicts the sampling offsets (b, num_query, num_heads, num_levels, num_points, 2)
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
# one more linear layer predicts the sampling-point weights (b, num_query, num_heads, num_levels*num_points)
# each feature point samples 4 points per level; attention_weights are their weights
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points)
# softmax over the weights
attention_weights = attention_weights.softmax(-1)
# reshape to (b, num_query, num_heads, num_levels, num_points)
attention_weights = attention_weights.view(bs, num_query,
self.num_heads,
self.num_levels,
self.num_points)
if reference_points.shape[-1] == 2:
# (w, h)
offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
# the core line: compute sampling locations (b, num_query, num_heads, num_levels, num_points, 2)
# see the illustration
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets \
/ offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.num_points \
* reference_points[:, :, None, :, None, 2:] \
* 0.5
else:
raise ValueError(
f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if ((IS_CUDA_AVAILABLE and value.is_cuda)
or (IS_MLU_AVAILABLE and value.is_mlu)):
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
else:
# (b, num_query, num_heads*embed_dims)
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights)
# linear projection
output = self.output_proj(output)
if not self.batch_first:
# (num_query, bs ,embed_dims)
output = output.permute(1, 0, 2)
# add the identity (residual) connection
return self.dropout(output) + identity
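The shapes involved can be traced with randomly initialized layers. This is a sketch whose sizes match the module printout above (8 heads, 4 levels, 4 sampling points, 256-dim embeddings); offset_fc and weight_fc stand in for sampling_offsets and attention_weights:

```python
import torch
from torch import nn

bs, num_query, embed_dims = 2, 10, 256
num_heads, num_levels, num_points = 8, 4, 4

query = torch.rand(bs, num_query, embed_dims)
offset_fc = nn.Linear(embed_dims, num_heads * num_levels * num_points * 2)  # 256 out
weight_fc = nn.Linear(embed_dims, num_heads * num_levels * num_points)      # 128 out

sampling_offsets = offset_fc(query).view(
    bs, num_query, num_heads, num_levels, num_points, 2)
attention_weights = weight_fc(query).view(
    bs, num_query, num_heads, num_levels * num_points).softmax(-1).view(
        bs, num_query, num_heads, num_levels, num_points)

# Shift each normalized reference point by the predicted offsets per level.
reference_points = torch.rand(bs, num_query, num_levels, 2)
spatial_shapes = torch.tensor([[72, 98], [36, 49], [18, 25], [9, 13]])
offset_normalizer = torch.stack(
    [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)  # (w, h) per level
sampling_locations = reference_points[:, :, None, :, None, :] \
    + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
print(sampling_locations.shape)  # torch.Size([2, 10, 8, 4, 4, 2])
```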
4.2.2.1.1. multi_scale_deformable_attn_pytorch
The pure-PyTorch implementation of the multi-scale deformable attention.
def multi_scale_deformable_attn_pytorch(
value: torch.Tensor, value_spatial_shapes: torch.Tensor,
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor) -> torch.Tensor:
"""CPU version of multi-scale deformable attention.
Args:
value (torch.Tensor): The value has shape
(bs, num_keys, num_heads, embed_dims//num_heads)
value_spatial_shapes (torch.Tensor): Spatial shape of
each feature map, has shape (num_levels, 2),
last dimension 2 represent (h, w)
sampling_locations (torch.Tensor): The location of sampling points,
has shape
(bs ,num_queries, num_heads, num_levels, num_points, 2),
the last dimension 2 represent (x, y).
attention_weights (torch.Tensor): The weight of sampling points used
when calculate the attention, has shape
(bs ,num_queries, num_heads, num_levels, num_points),
Returns:
torch.Tensor: has shape (bs, num_queries, embed_dims)
"""
bs, _, num_heads, embed_dims = value.shape
_, num_queries, num_heads, num_levels, num_points, _ =\
sampling_locations.shape
# [(bs, num_keys1, num_heads, embed_dims//num_heads), (bs, num_keys2, num_heads, embed_dims//num_heads),
# (bs, num_keys3, num_heads, embed_dims//num_heads), (bs, num_keys4, num_heads, embed_dims//num_heads), ]
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
dim=1)
# scale sampling_locations from [0, 1] to [-1, 1], the range F.grid_sample expects
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level, (H_, W_) in enumerate(value_spatial_shapes):
# bs, H_*W_, num_heads, embed_dims ->
# bs, H_*W_, num_heads*embed_dims ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_
value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(
bs * num_heads, embed_dims, H_, W_)
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :,
level].transpose(1, 2).flatten(0, 1)
# bs*num_heads, embed_dims, num_queries, num_points
sampling_value_l_ = F.grid_sample(
value_l_,
sampling_grid_l_,
mode='bilinear',
padding_mode='zeros',
align_corners=False)
sampling_value_list.append(sampling_value_l_)
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
bs * num_heads, 1, num_queries, num_levels * num_points)
# weighted sum of the sampled values for each feature point
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
attention_weights).sum(-1).view(bs, num_heads * embed_dims,
num_queries)
return output.transpose(1, 2).contiguous()
The figure below illustrates how output is computed.
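F.grid_sample itself can be demystified with a tiny example: sampling coordinates are normalized to [-1, 1], and the value is bilinearly interpolated from the surrounding cells.

```python
import torch
import torch.nn.functional as F

value = torch.arange(12, dtype=torch.float32).view(1, 1, 3, 4)  # (N, C, H, W)
# Sample the exact center of the map; grid is (N, H_out, W_out, 2) in (x, y) order.
grid = torch.tensor([[[[0.0, 0.0]]]])
out = F.grid_sample(value, grid, mode='bilinear',
                    padding_mode='zeros', align_corners=False)
print(out)  # tensor([[[[5.5000]]]]) -- midway between cells 5 and 6
```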
4.3. pre_decoder
Preparation before the decoder, including:
Generate proposals and transform memory: from the input memory, generate proposal coordinates output_proposals, four values (cx, cy, w, h) per feature point, and transform memory into out_memory;
Classification and regression branches: the classification branch scores out_memory per class (the first channel serves as the foreground score); the regression branch predicts box deltas per feature point, which are added to output_proposals to give enc_outputs_coord_unact;
Select proposals: keep num_queries proposals out of enc_outputs_coord_unact;
Generate query and query_pos: positionally encode the selected num_queries coordinates, then pass them through Linear and LayerNorm layers to get 512-dim features; the first 256 dims serve as the query positional encoding and the last 256 dims as the query features.
tmp_dec_in, head_inputs_dict = self.pre_decoder(**encoder_outputs_dict)
Input:
encoder_outputs_dict
- memory: the reconstructed queries (b, num_queries, num_heads*embed_dims)
- feat_mask: the feature mask (b, num_query)
- spatial_shapes: size of each feature level (4, 2)
Output:
decoder_inputs_dict
- query: 256-dim features obtained by passing the selected coordinates through Linear and LayerNorm layers;
- query_pos: 256-dim features from the same transform, used as the positional features;
- memory: passed through untouched;
- reference_points: the selected num_queries coordinates squashed into [0, 1] by a sigmoid
head_inputs_dict
- enc_outputs_class: the per-class score of every feature point (bs, num_queries, num_classes)
- enc_outputs_coord: the proposal coordinates after the sigmoid (bs, num_queries, 4)
4.3.1. Two-stage
4.3.1.1. gen_encoder_output_proposals
Generate proposals from the reconstructed queries (memory): each query point regresses four coordinates (cx, cy, w, h), which serve as the base boxes that the regression deltas are added to.
output_memory, output_proposals = self.gen_encoder_output_proposals(memory, memory_mask, spatial_shapes)
def gen_encoder_output_proposals(
self, memory: Tensor, memory_mask: Tensor,
spatial_shapes: Tensor) -> Tuple[Tensor, Tensor]:
"""Generate proposals from encoded memory. The function will only be
used when `as_two_stage` is `True`.
Args:
memory (Tensor): The output embeddings of the Transformer encoder,
has shape (bs, num_feat_points, dim).
memory_mask (Tensor): ByteTensor, the padding mask of the memory,
has shape (bs, num_feat_points).
spatial_shapes (Tensor): Spatial shapes of features in all levels,
has shape (num_levels, 2), last dimension represents (h, w).
Returns:
tuple: A tuple of transformed memory and proposals.
- output_memory (Tensor): The transformed memory for obtaining
top-k proposals, has shape (bs, num_feat_points, dim).
- output_proposals (Tensor): The inverse-normalized proposal, has
shape (batch_size, num_keys, 4) with the last dimension arranged
as (cx, cy, w, h).
"""
bs = memory.size(0)
proposals = []
_cur = 0 # start index in the sequence of the current level
for lvl, HW in enumerate(spatial_shapes):
H, W = HW
if memory_mask is not None:
mask_flatten_ = memory_mask[:, _cur:(_cur + H * W)].view(
bs, H, W, 1)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0],
1).unsqueeze(-1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0],
1).unsqueeze(-1)
scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2)
else:
if not isinstance(HW, torch.Tensor):
HW = memory.new_tensor(HW)
scale = HW.unsqueeze(0).flip(dims=[0, 1]).view(1, 1, 1, 2)
grid_y, grid_x = torch.meshgrid(
torch.linspace(
0, H - 1, H, dtype=torch.float32, device=memory.device),
torch.linspace(
0, W - 1, W, dtype=torch.float32, device=memory.device))
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale
wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
proposal = torch.cat((grid, wh), -1).view(bs, -1, 4)
proposals.append(proposal)
_cur += (H * W)
output_proposals = torch.cat(proposals, 1)
# do not use `all` to make it exportable to onnx
output_proposals_valid = (
(output_proposals > 0.01) & (output_proposals < 0.99)).sum(
-1, keepdim=True) == output_proposals.shape[-1]
# inverse_sigmoid
output_proposals = torch.log(output_proposals / (1 - output_proposals))
if memory_mask is not None:
output_proposals = output_proposals.masked_fill(
memory_mask.unsqueeze(-1), float('inf'))
output_proposals = output_proposals.masked_fill(
~output_proposals_valid, float('inf'))
output_memory = memory
if memory_mask is not None:
output_memory = output_memory.masked_fill(
memory_mask.unsqueeze(-1), float(0))
output_memory = output_memory.masked_fill(~output_proposals_valid,
float(0))
output_memory = self.memory_trans_fc(output_memory)
output_memory = self.memory_trans_norm(output_memory)
# [bs, sum(hw), 2]
return output_memory, output_proposals
In the code above, output_proposals = torch.log(output_proposals / (1 - output_proposals)) is the inverse sigmoid; its graph is shown below. As the input approaches 0 or 1, the output tends to negative or positive infinity.
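A quick check of the saturation behavior:

```python
import torch

# inverse_sigmoid = log(p / (1 - p)): finite in the middle, saturating to
# -inf / +inf as p approaches 0 or 1.
p = torch.tensor([0.01, 0.5, 0.99, 0.0, 1.0])
print(torch.log(p / (1 - p)))
# tensor([-4.5951,  0.0000,  4.5951,    -inf,     inf])
```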
The toy code below, together with a diagram, walks through how output_proposals is constructed.
import torch
from torch import nn
torch.manual_seed(42)
memory = torch.rand((1, 26, 8))
memory_mask = torch.tensor([[False, False, False, False, True,
False, False, False, False, True,
False, False, False, False, True,
False, False, False, False, True,
False, False, True,
False, False, True]])
spatial_shapes = torch.tensor([[4, 5], [2, 3]])
memory_trans_fc = nn.Linear(8, 8)
memory_trans_norm = nn.LayerNorm(8)
def gen_encoder_output_proposals(memory, memory_mask, spatial_shapes):
bs = memory.size(0)
proposals = []
_cur = 0 # start index in the sequence of the current level
for lvl, HW in enumerate(spatial_shapes):
H, W = HW
if memory_mask is not None:
mask_flatten_ = memory_mask[:, _cur:(_cur + H * W)].view(
bs, H, W, 1)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0],
1).unsqueeze(-1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0],
1).unsqueeze(-1)
scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2)
else:
if not isinstance(HW, torch.Tensor):
HW = memory.new_tensor(HW)
scale = HW.unsqueeze(0).flip(dims=[0, 1]).view(1, 1, 1, 2)
grid_y, grid_x = torch.meshgrid(
torch.linspace(
0, H - 1, H, dtype=torch.float32, device=memory.device),
torch.linspace(
0, W - 1, W, dtype=torch.float32, device=memory.device))
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale
wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
proposal = torch.cat((grid, wh), -1).view(bs, -1, 4)
proposals.append(proposal)
_cur += (H * W)
output_proposals = torch.cat(proposals, 1)
# do not use `all` to make it exportable to onnx
output_proposals_valid = (
(output_proposals > 0.01) & (output_proposals < 0.99)).sum(
-1, keepdim=True) == output_proposals.shape[-1]
# inverse_sigmoid
output_proposals = torch.log(output_proposals / (1 - output_proposals))
if memory_mask is not None:
output_proposals = output_proposals.masked_fill(
memory_mask.unsqueeze(-1), float('inf'))
output_proposals = output_proposals.masked_fill(
~output_proposals_valid, float('inf'))
output_memory = memory
# two-step filtering: first zero out the masked positions
if memory_mask is not None:
output_memory = output_memory.masked_fill(
memory_mask.unsqueeze(-1), float(0))
# then zero out the positions whose proposals fall outside (0.01, 0.99)
output_memory = output_memory.masked_fill(~output_proposals_valid,
float(0))
# linear layer
output_memory = memory_trans_fc(output_memory)
# norm layer
output_memory = memory_trans_norm(output_memory)
# [bs, sum(hw), 2]
return output_memory, output_proposals
output_memory, output_proposals = gen_encoder_output_proposals(memory, memory_mask, spatial_shapes)
The resulting output_memory and output_proposals are shown below.
4.3.1.2. cls_branches
Classification branch: a linear layer scores every query point for every class.
# (bs, num_query, num_class), e.g. (2, 120, 80)
enc_outputs_class = self.bbox_head.cls_branches[
self.decoder.num_layers](
output_memory)
4.3.1.3. reg_branches
Regression branch: a linear layer predicts four coordinate values for every query.
# (bs, num_query, 4), e.g. (2, 120, 4)
enc_outputs_coord_unact = self.bbox_head.reg_branches[
self.decoder.num_layers](output_memory) + output_proposals
# the sigmoid squashes the coordinates into [0, 1]
enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
4.3.1.4. Selecting proposals
The first channel of enc_outputs_class is interpreted as the foreground score, so the top num_queries proposals are selected by that channel.
# (bs, num_queries), e.g. (2, 100)
topk_proposals = torch.topk(
enc_outputs_class[..., 0], self.num_queries, dim=1)[1]
# then gather the coordinates at the selected indices from enc_outputs_coord_unact
topk_coords_unact = torch.gather(
enc_outputs_coord_unact, 1,
topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
# (bs, num_queries, 4), e.g. (2, 100, 4)
topk_coords_unact = topk_coords_unact.detach()
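The topk-plus-gather pattern in miniature (k = 2 instead of num_queries; scores stand in for enc_outputs_class[..., 0]):

```python
import torch

scores = torch.tensor([[0.1, 0.9, 0.4, 0.7]])                 # (bs, n)
coords = torch.arange(16, dtype=torch.float32).view(1, 4, 4)  # (bs, n, 4)

topk_inds = torch.topk(scores, k=2, dim=1)[1]  # indices of the best scores
# expand the indices over the 4 box coordinates before gathering
topk_coords = torch.gather(coords, 1,
                           topk_inds.unsqueeze(-1).repeat(1, 1, 4))
print(topk_inds)    # tensor([[1, 3]])
print(topk_coords)  # rows 1 and 3 of coords
```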
4.3.1.5. reference_points
A sigmoid squashes the selected num_queries coordinates into [0, 1]; that is reference_points. As in the encoder, they express each query's relative position on every feature level.
# (bs, num_queries, 4)
reference_points = topk_coords_unact.sigmoid()
4.3.1.6. query and query_pos
Positionally encode the selected num_queries coordinates, then pass them through Linear and LayerNorm layers to get 512-dim features; the first 256 dims serve as the query positional encoding and the last 256 dims as the query features.
pos_trans_out = self.pos_trans_fc(
self.get_proposal_pos_embed(topk_coords_unact))
pos_trans_out = self.pos_trans_norm(pos_trans_out)
query_pos, query = torch.split(pos_trans_out, c, dim=2)
4.4. forward_decoder
decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)
4.4.1. self_attn
Unlike the encoder, whose self-attention is the deformable attention above, the decoder's self_attn is a standard MultiheadAttention, and the queries here are the N selected proposals; in other words, the N proposals attend to one another.
query = self.self_attn(
query=query,
key=query,
value=query,
query_pos=query_pos,
key_pos=query_pos,
attn_mask=self_attn_mask,
**kwargs)
def forward(self,
query,
key=None,
value=None,
identity=None,
query_pos=None,
key_pos=None,
attn_mask=None,
key_padding_mask=None,
**kwargs):
"""Forward function for `MultiheadAttention`.
**kwargs allow passing a more general data flow when combining
with other operations in `transformerlayer`.
Args:
query (Tensor): The input query with shape [num_queries, bs,
embed_dims] if self.batch_first is False, else
[bs, num_queries embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims] .
If None, the ``query`` will be used. Defaults to None.
value (Tensor): The value tensor with same shape as `key`.
Same in `nn.MultiheadAttention.forward`. Defaults to None.
If None, the `key` will be used.
identity (Tensor): This tensor, with the same shape as x,
will be used for the identity link.
If None, `x` will be used. Defaults to None.
query_pos (Tensor): The positional encoding for query, with
the same shape as `x`. If not None, it will
be added to `x` before forward function. Defaults to None.
key_pos (Tensor): The positional encoding for `key`, with the
same shape as `key`. Defaults to None. If not None, it will
be added to `key` before forward function. If None, and
`query_pos` has the same shape as `key`, then `query_pos`
will be used for `key_pos`. Defaults to None.
attn_mask (Tensor): ByteTensor mask with shape [num_queries,
num_keys]. Same in `nn.MultiheadAttention.forward`.
Defaults to None.
key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
Defaults to None.
Returns:
Tensor: forwarded results with shape
[num_queries, bs, embed_dims]
if self.batch_first is False, else
[bs, num_queries embed_dims].
"""
if key is None:
key = query
if value is None:
value = key
if identity is None:
identity = query
if key_pos is None:
if query_pos is not None:
# use query_pos if key_pos is not available
if query_pos.shape == key.shape:
key_pos = query_pos
else:
warnings.warn(f'position encoding of key is'
f'missing in {self.__class__.__name__}.')
if query_pos is not None:
query = query + query_pos
if key_pos is not None:
key = key + key_pos
# Because the dataflow('key', 'query', 'value') of
# ``torch.nn.MultiheadAttention`` is (num_query, batch,
# embed_dims), We should adjust the shape of dataflow from
# batch_first (batch, num_query, embed_dims) to num_query_first
# (num_query ,batch, embed_dims), and recover ``attn_output``
# from num_query_first to batch_first.
if self.batch_first:
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
out = self.attn(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask)[0]
if self.batch_first:
out = out.transpose(0, 1)
return identity + self.dropout_layer(self.proj_drop(out))
Inside self.attn, the attention is computed with PyTorch's built-in multi_head_attention_forward:
# attn_output(300, 2, 256)
# attn_output_weights(2, 300, 300)
attn_output, attn_output_weights = F.multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, average_attn_weights=average_attn_weights)
4.4.2. cross_attn
query = self.cross_attn(
query=query,
key=key,
value=value,
query_pos=query_pos,
key_pos=key_pos,
attn_mask=cross_attn_mask,
key_padding_mask=key_padding_mask,
**kwargs)
query here is the query after the decoder's self_attn, while value comes from the encoder stage, i.e. the memory produced by the encoder's self-attention; hence cross-attention. The code is the same MultiScaleDeformableAttention as in the encoder.
5. Loss
Debugging this much code is exhausting; time to recharge before tackling the loss...