1. The key forward-inference function
In the official code, the forward_spatial method of the MbViTBkV2 class is as follows:
def forward_spatial(self, x: Tensor) -> Tensor:
    x = self.resize_input_if_needed(x)

    fm = self.local_rep(x)
    # print(f'local_rep output: {fm.shape}')

    # convert feature map to patches
    if self.enable_coreml_compatible_fn:
        patches, output_size = self.unfolding_coreml(fm)
    else:
        patches, output_size = self.unfolding_pytorch(fm)
    # print(f'patches output: {patches.shape}, output_size: {output_size}')

    # learn global representations on all patches
    patches = self.global_rep(patches)
    # print(f'after global patches output: {patches.shape}')

    # [B x Patch x Patches x C] --> [B x C x Patches x Patch]
    if self.enable_coreml_compatible_fn:
        fm = self.folding_coreml(patches=patches, output_size=output_size)
    else:
        fm = self.folding_pytorch(patches=patches, output_size=output_size)

    fm = self.conv_proj(fm)
    # print(f'conv_proj fm output: {fm.shape}')
    return fm
The first part resizes the input feature map (via bilinear interpolation) so that its height and width are divisible by the patch size.
x = self.resize_input_if_needed(x)
def resize_input_if_needed(self, x):
    # print(f'before resize')
    batch_size, in_channels, orig_h, orig_w = x.shape
    if orig_h % self.patch_h != 0 or orig_w % self.patch_w != 0:
        print(f'if resize resize_input_if_needed x:{x.shape}')
        new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
        new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)
        x = F.interpolate(
            x, size=(new_h, new_w), mode="bilinear", align_corners=True
        )
    return x
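The ceil-to-multiple rounding that decides the new spatial size can be checked in isolation. Below is a minimal standalone sketch (the helper name `rounded_hw` is my own, not part of the official code) of the `new_h` / `new_w` computation:

```python
import math

def rounded_hw(orig_h, orig_w, patch_h, patch_w):
    # Round H and W up to the nearest multiple of the patch size,
    # mirroring the new_h / new_w computation in resize_input_if_needed.
    new_h = int(math.ceil(orig_h / patch_h) * patch_h)
    new_w = int(math.ceil(orig_w / patch_w) * patch_w)
    return new_h, new_w

# A 77x80 map with 2x2 patches is interpolated up to 78x80:
# rounded_hw(77, 80, 2, 2) -> (78, 80)
```

If the input is already divisible, the size is unchanged and `resize_input_if_needed` skips the interpolation entirely.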
The second part performs local feature extraction with a depthwise-separable convolution block: a 3×3 depthwise convolution followed by a 1×1 pointwise convolution.
fm = self.local_rep(x)
...
conv_3x3_in = DWConv(in_channels, in_channels, 3, 1)
conv_1x1_in = nn.Conv2d(in_channels, cnn_out_dim, 1, 1, padding=0)
self.local_rep = nn.Sequential(conv_3x3_in, conv_1x1_in)
...
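To make the depthwise-separable structure concrete, here is a minimal numpy sketch of the two halves (my own illustration, not the repository's DWConv class, which presumably wraps nn.Conv2d with per-channel groups):

```python
import numpy as np

def depthwise_conv3x3(x, weight):
    # x: [C, H, W]; weight: [C, 3, 3]. Each channel is filtered with its
    # own 3x3 kernel (zero padding 1, stride 1) -- the "depthwise" half,
    # corresponding to conv_3x3_in above.
    c, h, w = x.shape
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1)))
    out = np.zeros((c, h, w))
    for i in range(3):
        for j in range(3):
            out += weight[:, i, j, None, None] * xp[:, i:i + h, j:j + w]
    return out

def pointwise_conv1x1(x, weight):
    # weight: [C_out, C_in]. A 1x1 convolution mixes channels at each
    # pixel -- the "pointwise" half, corresponding to conv_1x1_in above.
    return np.einsum('oc,chw->ohw', weight, x)
```

Splitting a standard convolution this way cuts the cost from C_in·C_out·9 to C_in·9 + C_in·C_out multiplies per pixel, which is the point of using it in a mobile backbone.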
The third part reshapes the feature map into patches. There are two implementations for different deployment targets (CoreML vs. plain PyTorch). For example, an input feature map of size [1×64×80×80] with a patch size of 2×2 becomes [1×64×4×1600] after the transform.
if self.enable_coreml_compatible_fn:
    patches, output_size = self.unfolding_coreml(fm)
else:
    patches, output_size = self.unfolding_pytorch(fm)
def unfolding_pytorch(self, feature_map: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
    batch_size, in_channels, img_h, img_w = feature_map.shape

    # [B, C, H, W] --> [B, C, P, N]
    # F.unfold slides a window over the feature map and extracts the
    # elements under it at each step; with stride == kernel_size the
    # patches do not overlap
    patches = F.unfold(
        feature_map,
        kernel_size=(self.patch_h, self.patch_w),
        stride=(self.patch_h, self.patch_w),
    )
    patches = patches.reshape(
        batch_size, in_channels, self.patch_h * self.patch_w, -1
    )

    return patches, (img_h, img_w)
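Because the stride equals the kernel size, this unfold is simply a non-overlapping block split. The numpy sketch below (the function name is my own) reproduces the [B, C, H, W] → [B, C, P, N] layout, where P = patch_h·patch_w is the number of pixels per patch and N is the number of patches:

```python
import numpy as np

def unfold_nonoverlapping(fm, ph, pw):
    # fm: [B, C, H, W] with H % ph == 0 and W % pw == 0.
    # Equivalent to F.unfold with kernel_size == stride == (ph, pw),
    # followed by the reshape to [B, C, P, N] used in unfolding_pytorch.
    b, c, h, w = fm.shape
    x = fm.reshape(b, c, h // ph, ph, w // pw, pw)
    x = x.transpose(0, 1, 3, 5, 2, 4)  # [B, C, ph, pw, H/ph, W/pw]
    return x.reshape(b, c, ph * pw, -1)

# Reproduces the shape example from the text:
# unfold_nonoverlapping(zeros((1, 64, 80, 80)), 2, 2) has shape (1, 64, 4, 1600)
```

Each column `patches[b, c, :, n]` now holds one 2×2 patch flattened in row-major order, so attention over the last axis mixes information across patches at a fixed within-patch position.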
The fourth part performs global feature extraction; its core is the separable self-attention mechanism proposed in the paper.
patches = self.global_rep(patches)
...
self.global_rep, attn_unit_dim = self._build_attn_layer(
    d_model=attn_unit_dim,
    ffn_mult=ffn_multiplier,
    n_layers=n_attn_blocks,
    attn_dropout=attn_dropout,
    dropout=dropout,
    ffn_dropout=ffn_dropout,
    attn_norm_layer=attn_norm_layer,
)
...
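The attention layers built here are not shown in the snippet. As a rough illustration of the paper's idea (the weight matrices below are hypothetical placeholders; the official implementation works on the patch tensor with 1×1 convolutions), separable self-attention replaces the N×N attention matrix with a single pooled context vector, giving cost linear in the token count N:

```python
import numpy as np

def separable_self_attention(x, w_i, w_k, w_v, w_o):
    # x: [N, d] tokens. Sketch of MobileViTv2-style separable attention:
    # project tokens to scalar scores, softmax over tokens, pool a single
    # context vector, then gate a ReLU value branch with it.
    scores = x @ w_i                             # [N, 1] per-token scores
    scores = np.exp(scores - scores.max())
    scores = scores / scores.sum()               # softmax over the N tokens
    context = (scores * (x @ w_k)).sum(axis=0)   # [d] global context vector
    v = np.maximum(x @ w_v, 0.0)                 # [N, d] ReLU value branch
    return (v * context) @ w_o                   # broadcast context, project out
```

Nothing here forms an N×N matrix, which is why this scales to the 1600-patch example above far more cheaply than standard multi-head attention.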
The last part folds the patches back into a [B×C×H×W] feature map so that subsequent convolutions can be applied.
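The fold is the exact inverse of the non-overlapping unfold: each column of the patch tensor is placed back on its position in the block grid. A minimal numpy sketch (my own helper; the official folding_pytorch presumably uses F.fold on the stored output_size):

```python
import numpy as np

def fold_nonoverlapping(patches, out_h, out_w, ph, pw):
    # patches: [B, C, P, N] with P = ph * pw and N = (out_h//ph) * (out_w//pw).
    # Inverse of a stride == kernel unfold: place each patch back on the grid
    # to recover a [B, C, H, W] feature map.
    b, c, p, n = patches.shape
    x = patches.reshape(b, c, ph, pw, out_h // ph, out_w // pw)
    x = x.transpose(0, 1, 4, 2, 5, 3)  # [B, C, H/ph, ph, W/pw, pw]
    return x.reshape(b, c, out_h, out_w)
```

Round-tripping a feature map through the unfold and this fold returns it unchanged, which is the invariant the block relies on: only global_rep and the convolutions actually modify the features.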