The figure above shows the overall architecture of Swin Transformer. The model adopts a hierarchical design with four stages in total; each stage reduces the resolution of its input feature map, enlarging the receptive field layer by layer the way a CNN does.
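To make the hierarchy concrete, the sketch below prints the token-grid resolution and channel width of each stage. The numbers assume a Swin-T configuration (224x224 input, patch_size=4, embed_dim=96, taken from the defaults in the code below); they are illustrative, not part of the figure itself.

# Per-stage token grid and channel width, assuming Swin-T defaults
# (img_size=224, patch_size=4, embed_dim=96)
img_size, patch_size, embed_dim = 224, 4, 96
res, dim = img_size // patch_size, embed_dim  # 56x56 tokens after patch partition
for stage in range(1, 5):
    print(f"Stage {stage}: {res}x{res} tokens, {dim} channels")
    if stage < 4:  # Patch Merging between stages halves resolution, doubles channels
        res, dim = res // 2, dim * 2
# -> Stage 1: 56x56, 96 | Stage 2: 28x28, 192 | Stage 3: 14x14, 384 | Stage 4: 7x7, 768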
Patch Partition
First comes the patch partition module. It splits the raw input image into patch_size x patch_size patches (note: patch_size, not window_size) via a Conv2d whose kernel size and stride both equal patch_size; the number of output channels sets the size of the embedding vectors. Finally the H and W dimensions are flattened into a single token dimension and moved in front of the channel dimension. The code is as follows:
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple


class PatchEmbed(nn.Module):
    # in essence a patch_size x patch_size (4x4 by default) convolution over the input image
    r""" Image to Patch Embedding
    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        # number of patches along H and W (not window_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]  # total number of patches
        self.in_chans = in_chans  # number of input image channels
        self.embed_dim = embed_dim  # number of linear projection output channels
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)  # LayerNorm
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # flatten(2): flatten all dimensions from dim 2 onward, collapsing H and W into one
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x
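A quick shape check of PatchEmbed (a minimal usage sketch; the input size matches the default img_size=224):

embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96)
x = torch.randn(1, 3, 224, 224)  # B C H W
print(embed(x).shape)            # torch.Size([1, 3136, 96]); 3136 = (224 // 4) ** 2

Each of the 56x56 = 3136 tokens is a 96-dimensional embedding of one 4x4 patch.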
Patch Merging
This module is attached after each stage except the last (no PatchMerging follows the final stage). It downsamples the feature map by a slicing operation: the spatial size shrinks to half, the channels are concatenated to four times the original (4C), and a linear layer then reduces them to twice the original (2C).
class PatchMerging(nn.Module):
    # the slicing below is similar to the Focus operation in YOLOv5
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)  # 4C -> 2C
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        x = x.view(B, H, W, C)
        # slice out the four even/odd row-column offsets, each B x H/2 x W/2 x C
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)  # B, H/2*W/2, 4C
        x = self.norm(x)
        x = self.reduction(x)
        return x
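A matching shape check for PatchMerging (a minimal sketch, assuming the stage-1 feature map from the Swin-T example above: a 56x56 token grid with 96 channels):

merge = PatchMerging(input_resolution=(56, 56), dim=96)
x = torch.randn(1, 56 * 56, 96)  # B, H*W, C
print(merge(x).shape)            # torch.Size([1, 784, 192]); 784 = 28*28, 192 = 2*96

which confirms the halve-resolution / double-channel behavior described above.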