之前的研究方向主要是NLP,最近因为接触多模态,对遗忘的CV知识还需要进一步的巩固
本篇文章跳过过多的废话,从源码角度更直观的理解Swin Transformer,很多人看完论文或者一些讲解,对于代码的编写以及整体的流程还是懵的,所以本篇从代码的输入->输出一步一步讲解各个模块,从而使得论文和源码串联起来。(文章最后放有原文的阅读哦)
话不多说,和大家一起学习,直接开干!
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
paper:https://arxiv.org/abs/2103.14030v1?ref=hackernoon.com
code:https://github.com/microsoft/Swin-Transformer
1 Swin Transformer出发点
1.1Transformer在CV领域的应用遇到的问题
- 图像的尺度变化范围非常大,且不符合标准的固定尺度。
- 由于Transformer的计算复杂度是与token数量的平方成正比的,如果将每个像素值视作一个token,其计算量将变得非常庞大。
1.2.Swin Transformer解决上述问题的方法
- Swin Transformer提出了"Shifted Window Mechanism"的方法来解决图像尺度变化范围大且不符合标准的固定尺度问题。即通过将输入的图像分割成一组具有不同尺度的小块(称为局部窗口),然后利用Transformer对这些小块进行处理,从而有效处理尺度变化较大的图像。
- Swin Transformer使用了分层的机制,每个层级的尺度略有不同,这种分层的设计使得模型可以在不同的层级上进行处理,从而减少了整体计算的复杂度。其次,而Swin Transformer采用了局部窗口的注意力机制,即每个位置只与其周围的局部窗口进行关联。这种方式大大减少了注意力机制的计算量。
1.3. Swin Transformer与VIT的区别
- Swin Transformer对图像进行不同倍数的下采样(如4倍、8倍、16倍),可以得到不同尺度的特征表示,模型可以学会在多尺度下理解和检测目标,适应不同尺度的目标区域。这种多尺度训练策略有助于提高模型在检测大目标和小目标时的性能。
- Vision Transformer(ViT)只使用单一的16倍下采样特征。这意味着ViT只能从单一分辨率的特征中学习,对于不同尺度的目标可能缺乏足够的多样性表示,可能会影响其在小目标或大目标检测任务上的性能。
2. Swin Transformer框架
Swin Transformer模型采取层次化的设计,一共包含4个Stage(可以理解为需要循环执行4层,每层包含不等的Block)。
2.1. Swin Transformer组件
- Patch Partition: 将输入的图像分割成图像块(patch),并将每个图像块转换成嵌入向量作为 Transformer 模型的输入。
- Patch Merging: 主要作用是将相邻层级的特征图进行合并,以实现多尺度的特征融合和降低计算复杂度。
- W-MSA: 窗口自注意力机制,它在处理图像时将注意力局限于特定的窗口,从而减少了计算和存储开销。
- SW-MSA: 滑动窗口自注意力机制,通过局部平移操作(shift)将窗口沿着宽度和高度方向移动,以获取不同的局部邻域信息,这样可以确保W-MSA中不同窗口内的信息之间能够交互,并减少计算量。
Note: W-MSA与SW-MSA是成对出现的,也就说每层的Swin Transformer Block都必须是偶数
2.2. Swin Transformer流程
-
输入的图像经过Patch Partition被切成一个个图像块,并将每个图像块转换成嵌入向量Embedding。
-
图像嵌入被送到Swin Transformer Stage中循环执行Stage内的操作。(4个Stage,每个Stage内2*-Block),具体一个Stage内的操作:
-
图像嵌入送入到第一个Block中执行LN->W-MSA(relative_position_bias)->LN->MLP操作,重构图像嵌入
-
重构的图像嵌入送入到第二个Block中执行LN->SW-MSA(relative_position_bias)->LN->MLP操作,得到新的图像嵌入
-
循环执行完Stage内所有的Block
-
新的图像嵌入被送入到Patch Merging中执行下采样,减少图片的分辨率。
-
Note: 该流程是按照代码的流程来执行的,Swin-T框架图是先执行Patch Merging->Block,但实际是Block->Patch Merging。
3. 从源码角度展开讲解Swin Transformer
对Swin Transformer中各组件的讲解完全按照代码的执行流程来,这方便大家看完后,既明白了Swin
Transformer框架的细节,也理解了Swin Transformer的执行逻辑和过程。
首先放入Swin Transformer框架完整的执行代码:
class SwinTransformer(nn.Module):
def __init__(
self,pretrain_img_size=224, patch_size=4,in_chans=3 embed_dim=96,
depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],window_size=7,mlp_ratio=4.0, qkv_bias=True,qk_scale=None, drop_rate=0.0,attn_drop_rate=0.0,drop_path_rate=0.2,norm_layer=nn.LayerNorm, ape=False,patch_norm=True,out_indices=(0, 1, 2, 3), frozen_stages=-1, dilation=False,use_checkpoint=False,):
super().__init__()
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.dilation = dilation
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(patch_size=patch_size,in_chans=in_chans,embed_dim=embed_dim,norm_layer=norm_layer if self.patch_norm else None,)
# absolute position embedding
if self.ape:
pretrain_img_size = to_2tuple(pretrain_img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [pretrain_img_size[0] // patch_size[0],pretrain_img_size[1] // patch_size[1]]
self.absolute_pos_embed = nn.Parameter( torch.zeros(1, embed_dim, patches_resolution[0],patches_resolution[1]))
trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
# prepare downsample list
downsamplelist = [PatchMerging for i in range(self.num_layers)]
downsamplelist[-1] = None
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
if self.dilation:
downsamplelist[-2] = None
num_features[-1] = int(embed_dim * 2 ** (self.num_layers -1)) // 2
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=num_features[i_layer],depth=depths[i_layer],num_heads=num_heads[i_layer],window_size=window_size,mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,qk_scale=qk_scale,drop=drop_rate,attn_drop=attn_drop_rate,drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],norm_layer=norm_layer,downsample=downsamplelist[i_layer],use_checkpoint=use_checkpoint,)
self.layers.append(layer)
# num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
self.num_features = num_features
# add a norm layer for each output
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f"norm{i_layer}"
self.add_module(layer_name, layer)
self._freeze_stages()
def _freeze_stages(self):
...
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
outs = []
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f"norm{i}")
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3,1,2).contiguous()
outs.append(out)
# collect for nesttensors
outs_dict = {}
for idx, out_i in enumerate(outs):
m = tensor_list.mask
assert m is not None
mask = F.interpolate(m[None].float(),size=out_i.shape[-2:]).to(torch.bool[0]
outs_dict[idx] = NestedTensor(out_i, mask)
return outs_dict
对SwinTransformer(nn.Module)
的初始化参数说明:
pretrain_img_size
: 用于训练预训练模型的输入图像大小。默认为224。patch_size
: Patch 的大小。默认为4。in_chans
: 输入图像的通道数。默认为3。embed_dim
: 线性投影输出通道数。默认为96。depths
: 每个 Swin Transformer 阶段的深度。num_heads
: 每个阶段的注意力头数。window_size
: 窗口大小。默认为7。mlp_ratio
: mlp 隐藏维度与嵌入维度的比率。默认为4。qkv_bias
: 如果为 True,将查询、键、值添加可学习的偏置。默认为 True。qk_scale
: 如果设置了,则覆盖默认的 qk 缩放 head_dim ** -0.5。drop_rate
: Dropout 比率。attn_drop_rate
: 注意力 Dropout 比率。默认为0。drop_path_rate
: 随机深度比率。默认为0.2。norm_layer
: 归一化层。默认为nn.LayerNorm。ape
: 如果为 True,将绝对位置嵌入添加到补丁嵌入中。默认为 False。patch_norm
: 如果为 True,在补丁嵌入后添加归一化。默认为 True。out_indices
: 哪些阶段的输出。frozen_stages
: 要冻结的阶段(停止梯度和设置评估模式)。-1 表示不冻结任何参数。use_checkpoint
: 是否使用检查点来节省内存。默认为 False。dilation
: 如果为 True,则输出大小为16倍降采样,否则为32倍降采样。
假设我们输入的图像大小是[3,800,800],经过batch构造最终的输入为[1,3,800,800],下面按照2.2中的流程开始执行
3.1.Patch Partition
在执行swin-transformer Stage之前,需要对输入的图像进行分块,并将每个块通过卷积操作投影到一个低维向量空间中,得到特征嵌入x,以满足transformer的输入要求:
-
首先,代码对输入的x进行填充操作。通过获取x的尺寸H和W,判断W是否不能被self.patch_size[1]整除,如果不能整除,则对x应用水平填充。同理,判断H 。填充操作旨在确保输入图像的高度和宽度都能被patch_size整除。(patch_size=(4,4))
-
接下来,将填充后的图像x传递给卷积层self.proj进行分块。self.proj作为一个卷积层,将输入图像的每一个路径块(大小为patch_size)进行卷积操作,将其投影到一个低维度的向量空间中。投影操作将会改变输入图像的通道数,使之变为embed_dim。
维度变换:[1,3,800,800]->(self.proj)->[1,96,200,200],self.proj=Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
-
然后,如果存在标准化操作self.norm,则对投影后的特征图进行处理。首先,获取特征图的尺寸Wh和Ww。然后,将特征图展平,并执行转置操作以进行标准化操作。最后,将标准化后的结果-重新调整尺寸,变为[batch_size, embed_dim, Wh, Ww]的形状。
维度变换:[1,96,200,200]->(flatten)->[1,96,40000]->(transpose(1,2))->[1,40000,96]->(view)->[1,96,200,200]
通过程序实现可以发现其并没有使用nn.Linear转换输入通道数,而是使用nn.Conv2d在进行patches转换时同时更换了通道数。
x = self.patch_embed(x)
class PatchEmbed(nn.Module):
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
x = x.flatten(2).transpose(1, 2)
return x
3.2.开始执行封装好的swin-Transformer Stage操作
在执行之前需要将图像嵌入xx 进行扁平化(flatten)和drop操作:
x = x.flatten(2).transpose(1, 2) x = self.pos_drop(x)
维度变换:[1,96,200,200]->(flatten/drop)->[1,40000,96]
swin-Transformer Stage中的Block和Patch Merging被封装到BasicLayer模块中,具体代码如下:
class BasicLayer(nn.Module):
def __init__(
self,
dim,depth, num_heads, window_size=7,mlp_ratio=4.0,qkv_bias=True,qk_scale=None,drop=0.0,attn_drop=0.0,drop_path=0.0,norm_layer=nn.LayerNorm,downsample=None,use_checkpoint=False,):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([SwinTransformerBlock(dim=dim,num_heads=num_heads,window_size=window_size,shift_size=0 if (i % 2 == 0) else window_size // 2,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,qk_scale=qk_scale,drop=drop,attn_drop=attn_drop,drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,norm_layer=norm_layer,)for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, H, W):
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None),)
w_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None),)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
for blk in self.blocks:
blk.H, blk.W = H, W
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
else:
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
在执行blocks之前,首先计算了用于SW-MSA的注意力掩码,这部分的原理以及代码将在后续介绍,本小节先行跳过。
跳过SW-MSA的mask计算之后,开始执行block内操作,首先是W-MSA。
3.2.1.Block之W-MSA
W-MSA 模块是指 Window-based Multi-head Self Attention 模块,相比传统的全连接 self-attention,W-MSA 使用了窗口化机制以降低计算复杂度。在 W-MSA 模块中,输入的特征图被划分成多个窗口,并在每个窗口内进行 self-attention 计算,以捕捉局部的交互信息。
A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T q + B ) V Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{q}}+B)V Attention(Q,K,V)=Softmax(qQKT+B)V
其中,B是相对位置编码,这也是除了窗口化之外,另一个区别于原始Attention的地方。
具体BLOCK-W-MSA逻辑/代码流程(包含了W-MSA的处理):
1.首先,从输入的x
中获取其维度信息,包括batch大小(B),高度(H),宽度(W),和通道数(C)。然后,将输入特征 x 进行归一化处理(self.norm1),再将其 reshape 成 B, H, W, C 的形状。最后,根据预设的窗口大小(self.window_size
:7)来对x
进行padding操作,以使得高宽维度(Hp,Wp)成为窗口大小的整数倍,方便后续操作。
维度变换(x->x) : [1,40000,96]->(self.norm1)->[1,40000,96]->(x.view)->[1, 200, 200, 96]->(F.pad)->[1, 203, 203, 96]
B, L, C = x.shape
H, W = self.H, self.W
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
2.然后,将特征张量shifted_x=x
进行窗口分割操作,将其分割成多个窗口,将原本的张量从 N H W C
, 划分成 num_windows*B, window_size, window_size, C
,其中 num_windows = H*W / (window_size*window_size),即窗口的个数。窗口分割的目的就是使得每个窗口之间的像素点只能与该窗口中的其他像素点进行内积从而获得信息,以达到W-MSA中W的作用。
维度变换(x->x_windows):[1, 203, 203, 96]->(window_partition.x.view)->[1, 29, 7, 29, 7, 96]->(window_partition.x.permute)->[841*1, 7, 7, 96]->(x_windows.view)->[841, 49, 96]
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) #[1, 29, 7, 29, 7, 96]
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) #[841, 7, 7, 96]
return windows
shifted_x = x
attn_mask = None
x_windows = window_partition(shifted_x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
3.将分割后的窗口特征张量x_windows
进行自注意力计算:
attn_windows = self.attn(x_windows, mask=attn_mask)
-
计算查询、键、值的映射(通过self.qkv方法),并对结果进行reshape、permute操作,将维度重新排列为
[3, num_windows*B, num_heads,window_size*window_size, C//num_heads]
。这样做是为了便于后面的注意力计算,并分配给q,k,v
维度变换(x_windows->q,k,v):[841* 1, 49, 96]->(self.qkv.reshape.permute)->[3, 841* 1, 3, 49, 32]->(qkv[0/1/2])->[841* 1, 3, 49, 32]-
B_, N, C = x.shape
qkv = (self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4))
q, k, v = qkv[0], qkv[1], qkv[2]
-
对
q
乘以一个scale
缩放系数,然后与k
(最后两个维度调换)进行相乘。得到形状为(numWindows*B, num_heads, window_size*window_size,window_size*window_size)
的attn
张量。维度变换(q->attn):[841* 1,3, 49, 32]->(q @ k)->[841* 1, 3, 49, 49]
q = q * self.scale
attn = q @ k.transpose(-2, -1)
-
计算相对位置偏差编码(公式中的B),并将其与注意力张量attn相加。(相对位置偏差编码后续再展开讲解)
维度变换(attn->attn):[841* 1,3, 49, 49]->(q @ k)->[841* 1, 3, 49, 49]
relative_position_bias=...
attn = attn + relative_position_bias.unsqueeze(0)
-
最后就是对注意力张量attn执行softmax以得到权重,再通过drop操作,与V矩阵相乘,并进行维度变换,将结果变换回原来的形状,进行MLP层和drop操作。
维度变换(attn->x):[841* 1,3, 49, 49]->(attn @ v.transpose.reshape)->[841*1, 49, 96]->(self.proj)->[841 * 1, 49, 96]
if mask is not None:
...
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
4.将计算得到的注意力窗口特征重新组合成原始形状,然后根据之前的位移操作恢复原来的位置。
维度变换(attn_windows->shifted_x):[841 * 1, 49, 96]->(window_reverse)->[1, 203, 203, 96]
def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
attn_windows = attn_windows.view(-1, self.window_size, self.window_size,C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)
5.最后,如果进行了padding操作,则通过切片操作将特征张量x
恢复到原始的高度和宽度。将特征张量x
重新展平成shape为(B,H×W,C)的形状。通过短路连接的方式将shortcut
加到计算得到的特征上,并对其进行Drop Path操作。然后,将特征x
输入到多层感知机(MLP)中进行非线性变换,并再次对其进行Drop Path操作。最后,将计算得到的特征返回。
维度变换(shifted_x->x):[1, 203, 203, 96]->(window_reverse)->[1, 40000, 96]
if self.shift_size > 0:
...
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
上面介绍了一个W-MSA的Block的整体流程,SW-MSA的Block只是将W-MSA换成了SW-MSA,并多执行了特征的平移反转以及atten的mask操作,后续只对这缺失的一部分进行讲解。
3.2.2. 相对位置编码
3.2.1在介绍W-MSA的时候对其相对位置编码进行了跳过,本节将详细介绍相对位置编码。
Swin-T 网络在 Attention 计算中引入相对位置偏置机制,其准确度能够提高 1.2%~2.3% 不等。以 2×2 的特征图为例,在计算self-attention时,每个token都要与所有的token计算QK值,如下图所示,当位置1的token计算self-attention时,要计算位置1与位置(1,2,3,4)的QK值,即以位置1的token为中心点,中心点位置坐标(0,0),其他位置计算与当前位置坐标的偏移量。具体计算逻辑如下图所示:
上图最后生成的是相对位置索引,relative_position_index.shape= M 2 ∗ M 2 M^2*M^2 M2∗M2 (M是给定的窗口大小),在网络中注册成为一个不可学习的变量,relative_position_index的作用就是根据最终的索引值找到对应的可学习的相对位置编码。
具体的代码实现:
1.定义了一个相对位置偏差参数表:
self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(self.relative_position_bias_table, std=0.02)
2.定义不参与网络学习的变量的相对位置索引:
-
首先,使用 torch.arange 函数创建了两个张量 coords_h 和 coords_w,分别表示窗口大小的高度和宽度范围内的坐标。
coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1])
-
然后,利用 torch.meshgrid 函数将 coords_h 和 coords_w 组合起来,生成一个二维坐标矩阵 coords,其中包含了所有可能的坐标组合。coords 的 shape 是 (2, Wh, Ww),其中 Wh 和 Ww 分别代表窗口高度和宽度的大小。
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
-
接下来,通过 torch.flatten 函数将 coords 展平为一维张量 coords_flatten,shape 为 (2, WhWw),表示所有坐标点的组合。
coords_flatten = torch.flatten(coords, 1)
-
计算相对坐标 relative_coords,这里通过广播操作实现了两两坐标之间的相对位置计算。计算结果的 shape 是 (2, WhWw, WhWw),因为对每个坐标点,需要计算它与所有其他坐标点的相对位置。
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] relative_coords = relative_coords.permute(1, 2, 0).contiguous()
-
调整 relative_coords 张量的维度顺序,并加上偏移量(self.window_size[0] - 1 和 self.window_size[1] - 1)以确保相对坐标的起始位置是从 0 开始。
relative_coords[:, :, 0] += self.window_size[0] - 1 relative_coords[:, :, 1] += self.window_size[1] - 1
-
根据 Swin Transformer 中相对位置索引的计算规则,通过对相对坐标进行数据变换和求和操作,得到最终的相对位置索引矩阵 relative_position_index,shape 为 (WhWw, WhWw)。
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1)
-
最后,通过 self.register_buffer 将计算得到的相对位置索引矩阵 relative_position_index 注册为网络的缓冲区(buffer),以便在网络的前向传播过程中进行使用。
self.register_buffer("relative_position_index", relative_position_index)
在前向计算中,相对位置偏置(relative_position_bias)被计算为相对位置索引矩阵(relative_position_index)对应的预先计算好的相对位置偏置表格(relative_position_bias_table)中的值:
- 首先,利用 relative_position_index.view(-1) 将相对位置索引矩阵展平为一维张量,并使用它来索引相对位置偏置表格 relative_position_bias_table 中的值。这样得到的 relative_position_bias 是一个形状为
(window_size*window_size, window_size*window_size, nhead)
的三维张量。 - 接着,通过 .permute(2, 0, 1) 操作对 relative_position_bias 进行维度交换,将nhead维度移动到最前面,得到新的相对位置偏置张量。
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
3.2.3. Block之SW-MSA
虽然 W-MSA 通过划分 Windows 的方法减少了计算量,但是由于各个 Windows 之间无法进行信息的交互,因此可以看作其“感受野”缩小,无法得到较全局准确的信息从而影响网络的准确度。为了解决不重叠窗口之间没有关联的问题,采用了shifted window的方法。
上图左边是没有重叠的Window Attention,而右边则是将窗口进行移位的Shift Window Attention。可以看到移位后的窗口包含了原本相邻窗口的元素。但这也引入了一个新问题,即window的个数翻倍了,由原本四个窗口变成了9个窗口。
在实际代码里,我们是通过对特征图移位,并给Attention设置mask来间接实现的。能在保持原有的window个数下,最后的计算结果等价。具体做法:
论文提出了cyclic shift方法,通过上图的方式可以在保证不重叠窗口间有联系的基础上不增加窗口的个数,新的窗口可能会由之前不相关的自窗口构成,为了保证shifted window self-attention计算的正确性,只能计算相同子窗口的self-attention,不同子窗口的self-attention结果要归0,在进行cyclic shift之前,需要给子窗口进行编码,编码之后通过torch.roll对窗口进行滚动,达到cyclic shift的效果:
在(5,3)(7,1)(8,6,2,0)组成的新窗口中,只有相同编码的部分才能计算self-attention,不同编码位置间计算的self-attention需要归0,根据self-attention公式,最后需要进行Softmax操作,不同编码位置间计算的self-attention结果通过mask加上-100,在Softmax计算过程中,Softmax(-100)无线趋近于0,达到归0的效果。
MASK实现
通过设置合理的mask,让Shifted Window Attention
在与Window Attention
相同的窗口个数下,达到等价的计算结果,
Mask是初始就给定的,即只有特征发生滚动去迎合MASK
具体代码实现:
-
先计算了输入图像的高度 Hp 和宽度 Wp,使其能够被 window_size 整除。然后,创建了一个形状为 (1, Hp, Wp, 1) 的全零张量 img_mask,用于存储图像的掩码信息。
Hp = int(np.ceil(H / self.window_size)) * self.window_size Wp = int(np.ceil(W / self.window_size)) * self.window_size img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)
-
接下来,定义了两个切片对象 h_slices 和 w_slices,分别表示高度和宽度上的切片范围。这些切片范围将被用于将图像掩码赋值给 img_mask 张量的不同部分。
h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None),) w_slices = ( slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None),)
-
然后,使用两个嵌套的循环遍历 h_slices 和 w_slices 中的切片范围,并为每个切片位置赋予一个递增的值 cnt。这样就可以将相应的索引值赋给 img_mask 张量,以表示每个窗口的位置。
cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1
-
接下来,使用 window_partition 函数将 img_mask 张量划分为大小为 window_size 的窗口,得到形状为 (nW, window_size, window_size, 1) 的张量 mask_windows。将 mask_windows 重塑为形状为 (nW, window_size * window_size) 的二维张量。
mask_windows = window_partition(img_mask, self.window_size) mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
-
然后,通过计算 mask_windows 与自身转置之间的差异来计算注意力掩码(attn_mask)。将 attn_mask 中不为零的位置用 -100.0 替换,将为零的位置用 0.0 替换。这样可以为注意力机制建立一个掩码,以在自注意力计算的过程中过滤不相关的位置。
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
Block之SW-MSA具体代码实现:
- 同3.2.1中的具体BLOCK-W-MSA逻辑/代码流程的第一步。
- 使用
torch.roll
函数,对张量x
进行了循环移位操作。具体来说,它将张量x
在两个维度上分别向左循环移动了-self.shift_size
个位置
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
- 同3.2.1中的具体BLOCK-W-MSA逻辑/代码流程的第二步。
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) #[1, 29, 7, 29, 7, 96]
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) #[841, 7, 7, 96]
return windows
x_windows = window_partition(shifted_x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
- 将分割后的窗口特征张量
x_windows
进行自注意力计算:
attn_windows = self.attn(x_windows, mask=attn_mask)
- 计算查询、键、值的映射(通过self.qkv方法),并对结果进行reshape、permute操作,并分配给
q,k,v
B_, N, C = x.shape
qkv = (self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4))
q, k, v = qkv[0], qkv[1], qkv[2]
- 对
q
乘以一个scale
缩放系数,然后与k
(最后两个维度调换)进行相乘,得到attn
张量。
q = q * self.scale
attn = q @ k.transpose(-2, -1)
- 计算相对位置偏差编码(公式中的B),并将其与注意力张量attn相加。
relative_position_bias=...
attn = attn + relative_position_bias.unsqueeze(0)
- 将注意力权重张量
attn
重塑形状为(B_ // nW, nW, self.num_heads, N, N)
,并加上mask
的扩展维度。重新调整attn
的形状为(-1, self.num_heads, N, N)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) +mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
- 最后就是对注意力张量attn执行softmax以得到权重,再通过drop操作,与V矩阵相乘,并进行维度变换,将结果变换回原来的形状,进行MLP层和drop操作。
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
4.将计算得到的注意力窗口特征重新组合成原始形状,然后根据之前的位移操作恢复原来的位置。
attn_windows = attn_windows.view(-1, self.window_size, self.window_size,C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)
5.将移位前的张量 x
恢复到原始状态,具体来说,它将张量 shifted_x
在两个维度上分别向右循环移动了 self.shift_size
个位置。
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
6.最后,如果进行了padding操作,则通过切片操作将特征张量x
恢复到原始的高度和宽度。将特征张量x
重新展平成shape为(B,H×W,C)的形状。通过短路连接的方式将shortcut
加到计算得到的特征上,并对其进行Drop Path操作。然后,将特征x
输入到多层感知机(MLP)中进行非线性变换,并再次对其进行Drop Path操作。最后,将计算得到的特征返回。
if self.shift_size > 0:
...
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
(加粗部分是SW-MSA比W-MSA多出的操作)
3.3.3.Patch Merging
在执行完Block后,需要执行Patch Mering操作,该模块的作用是做降采样,用于缩小分辨率,调整通道数进而形成层次化的设计,同时也能节省一定运算量。
每次降采样是两倍,因此在行方向和列方向上,间隔2选取元素。然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的4倍(因为H,W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍。具体操作:
- 获取输入张量
x
的形状信息,分别为B
(batch size)、L
(序列长度)和C
(特征通道数),并假设L
应该等于H * W
,其中H
和W
表示高度和宽度。 - 将
x
重新调整形状为(B, H, W, C)
,即将特征向量重新组织为二维图像的形式。 - 检查输入图像的高度和宽度是否为奇数,如果是则在最后一行或最后一列进行填充操作。
- 对输入图像进行 2x2 的平均池化操作,分别得到四个子区域
x0
、x1
、x2
、x3
,然后将这四个子区域按特征通道拼接在一起。 - 将拼接后的结果重新调整形状为
(B, H/2 * W/2, 4 * C)
,即每个像素点被表示为原来的四倍特征向量。 - 将调整后的张量
x
分别经过norm
和reduction
模块的处理,其中reduction是一个全连接层以调整维度。 - 返回处理后的结果
x
。
class PatchMerging(nn.Module):
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
4,总结
Swin Transformer的提出解决了之前Transformer应用在CV领域的缺点,具有以下优势:
- 分级特征表示:Swin-T引入分级特征表示,将图像按层次进行划分,允许底层特征捕获更局部、细节的信息,高层特征则能够捕获更全局、抽象的信息。这种分级特征表示的设计使得模型能够更好地处理不同尺度的特征。
- 低计算成本:相对于传统的全局自注意力机制,Swin-T的窗口注意力机制降低了计算复杂度,使得模型更加高效。此外,Swin-T还使用了相对位置编码和间距实现的移位操作,进一步提升了计算效率。
- 泛化能力:Swin-T在大规模图像分类任务上进行了训练,并在多个数据集上进行了评估,展示了良好的泛化能力和通用性。与其他视觉Transformer模型相比,Swin-T在许多图像分类任务上取得了更好的性能。
5.论文详细阅读
Abstract
本文介绍了一种名为 Swin Transformer 的新型视觉转换器,它可以作为计算机视觉的通用骨干。将 Transformer 从语言应用到视觉领域所面临的挑战来自这两个领域的差异,例如视觉实体的尺度变化很大,以及与文本中的文字相比,图像中像素的分辨率较高。为了解决这些差异,我们提出了一种分层变换器,其表示是通过移位窗口计算的。这种分层架构可以灵活地按不同尺度建模,其计算复杂度与图像大小成线性关系。Swin Trans former 的这些特性使其能够兼容广泛的视觉任务,包括图像分类(ImageNet-1K 上的最高准确率为 87.3)和密集预测任务,如物体检测(COCO 测试 dev 上的 58.7 box AP 和 51.1 mask AP)和语义分割(ADE20K val 上的 53.5 mIoU)。它的性能大大超过了以前的技术水平,在 COCO 上达到了 +2.7 box AP 和 +2.6 mask AP,在 ADE20K 上达到了 +3.2mIoU,这证明了基于 Transformer 的模型作为视觉骨干的潜力。 层次化设计和移动窗口方法也证明了对全 MLP 架构的益处。代码和模型可在 https://github. com/microsoft/Swin-Transformer 上公开获取。
1. Introduction
长期以来,计算机视觉建模一直由卷积神经网络(CNN)主导。从 AlexNet [39] 及其在 ImageNet 图像分类挑战中的革命性表现开始,CNN 架构通过更大的规模 [30, 76]、更广泛的连接 [34] 和更复杂的卷积形式 [70, 18, 84],变得越来越强大。随着 CNN 成为各种视觉任务的骨干网络,这些架构上的进步带来了性能上的提升,广泛地推动了整个领域的发展。
另一方面,网络架构在自然语言处理(NLP)领域的发展却走了一条不同的道路,如今流行的架构是 Transformer [64]。Transformer 专为序列建模和转译任务而设计,其显著特点是利用注意力对数据中的长程依赖关系进行建模。 它在语言领域取得的巨大成功促使研究人员开始研究其在计算机视觉领域的应用,最近它在一些特定任务上取得了令人鼓舞的成果,特别是图像分类 [20] 和视觉语言联合建模 [47]。
在本文中,我们试图扩大 Transformer 的适用范围,使其成为计算机视觉的通用骨干,就像它在 NLP 和 CNN 在视觉中的作用一样。我们发现,将 Transformer 在语言领域的高性能应用到视觉领域会遇到很大的困难,原因在于两种模式之间存在差异。这些差异之一涉及比例。与语言转换器中作为基本处理元素的文字标记不同,视觉元素在比例上会有很大差异,这个问题在物体保护等任务中受到关注 [42, 53, 54]。在现有的基于变换器的模式中 [64, 20],标记都是固定比例的,这一特性并不适合这些视觉应用。另一个不同之处在于,图像中像素的分辨率远高于文本中单词的分辨率。现在有很多诸如语义分割之类的视觉任务都需要像素级的密集预测,而这对于高分辨率图像上的 Transformer 来说是难以实现的,因为其自我关注的计算复杂度是图像大小的二次方。为了克服这些问题,我们提出了一种名为 Swin Transformer 的通用变换器主干,它可以构建分层特征图,其计算复杂度与图像大小成线性关系。如图 1(a)所示,Swin Transformer 从小块图斑(灰色轮廓)开始,逐步合并较深 Transformer 层中的相邻图斑,从而构建分层再现图。有了这些分层特征图,Swin Transformer 模型就能方便地利用先进的密集预测技术,如特征 pyra mid 网络(FPN)[42] 或 U-Net [51]。线性计算复杂度是通过在分割图像的非重叠窗口(红色轮廓)内局部计算自我关注来实现的。每个窗口中的斑块数量是固定的,因此复杂度与图像大小成线性关系。这些优点使 Swin Transformer 适合作为各种视觉任务的通用骨干,与之前基于 Transformer 的架构 [20] 形成鲜明对比,后者生成的特征图分辨率单一,复杂度为二次方。
如图 2 所示,Swin Transformer 的一个关键设计元素是在连续的自注意层之间移动窗口分区。移动后的窗口连接了前一层的窗口,提供了它们之间的连接,大大提高了建模能力(见表 4)。这种策略在实际延迟方面也很有效:一个窗口内的所有查询补丁都共享相同的密钥集1,这有利于硬件内存的获取。相比之下,早期基于滑动窗口的自关注方法 [33, 50],由于不同的查询像素有不同的密钥集,因此在一般硬件上延迟较低。我们的实验表明,与滑动窗口法相比,所提出的移位窗口法的延迟时间要短得多,但建模能力却相差无几(见表 5 和表 6)。事实证明,移位窗口法也有利于全 MLP 架构 [61]。
所提出的 Swin 变换器在图像分类、物体检测和语义分割等识别任务中表现出色。在这三项任务中,它以相似的延迟显著优于 ViT / DeiT [20, 63] 和 ResNe(X)t 模型 [30, 70]。在 COCO 测试开发集上,其 58.7 的盒 AP 和 51.1 的掩码 AP 超过了之前最先进结果的 +2.7 盒 AP(Copy-paste [26] without external data)和 +2.6 掩码 AP(DetectoRS [46])。在 ADE20K 语义分割方面,它在 val 集上获得了 53.5 mIoU,比之前的先进水平(SETR [81])提高了 +3.2 mIoU。在 ImageNet-1K 图像分类中,它的准确率也达到了 87.3%。
我们相信,一个横跨计算机视觉和自然语言处理的统一架构将使这两个领域受益匪浅,因为它将促进对视觉和文本信号的联合建模,并能更深入地共享这两个领域的建模知识。我们希望 Swin Transformer 在各种视觉问题上的出色表现能让这一信念深入人心,并鼓励对视觉和语言信号进行统一建模。
2. Related Work
3. Method
3.1. Overall Architecture
图 3 是 Swin T 变换器架构的预览,展示了其微小版本(Swin T)。它首先通过一个像 ViT 一样的补丁分割模块将输入的 RGB 图像分割成不重叠的补丁。每个补丁被视为一个 “标记”,其特征被设置为原始像素 RGB 值的连接。在实施过程中,我们使用的补丁大小为 4 * 4,因此每个补丁的特征维度为 4 * 4 * 3 = 48。对原始值特征采用线性嵌入层,将其投影到任意维度(记为 C)。
在这些补丁标记上应用几个经过修改的自注意计算变换器块(Swin 变换器块)。变换器模块保持标记的数量(H/4 * W/4),并与线性嵌入一起称为 “阶段 1”。
为了产生分层表示,随着网络的深入,通过补丁合并层来减少标记的数量。第一个补丁合并层将每组 2*2 个相邻补丁的特征串联起来,并在 4C 维串联特征上应用线性层。这就将标记数减少了 2 * 2=4 的倍数(分辨率降低 2 倍),并将输出维度设置为 2C。随后应用 Swin 变换器块进行特征变换,分辨率保持在 H 8 ∗ W 8 \frac{H}{8} *\frac{W}{8} 8H∗8W。这第一块补丁合并和特征变换被称为 “阶段 2”。该过程重复两次,分别为 "阶段 3 "和 “阶段 4”,输出分辨率分别为 H 16 ∗ W 16 \frac{H}{16} *\frac{W}{16} 16H∗16W 和 H H 32 ∗ W 32 \frac{H}{32} *\frac{W}{32} 32H∗32W。这些阶段共同产生分层表示,其特征图分辨率与典型卷积网络(如 VGG [52] 和 ResNet [30])相同。因此,所提出的架构可以方便地将现有方法中的骨干网络重新安置在不同的视觉任务中。
Swin Transformer block Swin Transformer 是通过将 Transformer 模块中的标准多头自我关注(MSA)模块替换为基于移位窗口的模块(见第 3.2 节)而构建的,其他层级保持不变。如图 3(b) 所示,Swin 变换器模块由一个基于移位窗口的 MSA 模块和一个 2 层 MLP 模块组成,MLP 模块之间有 GELU 非线性。每个 MSA 模块和每个 MLP 之前都有一个 LN 层,每个模块之后都有一个残差连接。
3.2. Shifted Window based Self-Attention
标准变换器架构[64]及其用于图像分类的调整[20]都是全局自我关注,即计算一个标记与所有其他标记之间的关系。全局计算会导致与标记数量相关的二次复杂性,因此不适合许多需要大量标记集进行密集预测或表示高分辨率图像的视觉问题。
Self-attention in non-overlapped windows 为了高效建模,我们建议在低视窗内计算自我注意力。这些窗口以不重叠的方式均匀分割图像。假设每个窗口包含
M
∗
M
M*M
M∗M个补丁,则全局 MSA 模块和基于窗口的模块在包含
h
∗
w
h*w
h∗w 个补丁的图像上的计算复杂度分别为
其中,前者与补丁数 hw 成二次方,而后者在 M 固定(默认设置为 7)时呈线性。对于较大的 hw,全局自我关注计算通常难以承受,而基于窗口的自我关注计算则可以扩展。
Shifted window partitioning in successive blocks 基于窗口的自我关注模块缺乏跨窗口连接,这限制了其建模能力。为了引入跨窗口连接,同时保持非重叠窗口的高效计算,我们提出了一种移动窗口分区方法,在连续的 Swin Transformer 块中交替使用两种分区配置。
如图 2 所示,第一个模块采用常规窗口分割策略,从左上角像素开始,将
8
∗
8
8*8
8∗8个特征图平均分割成大小为
4
∗
4
4*4
4∗4(M =4)的
2
∗
2
2*2
2∗2个窗口。然后,下一个模块采用的窗口配置与上一层的窗口配置相比发生了偏移,将窗口从规则分割的窗口中移出
⌊
M
2
,
M
2
⌋
\lfloor\frac{M}{2},\frac{M}{2}\rfloor
⌊2M,2M⌋个像素。 采用偏移窗口分割方法,连续 Swin 变换器块的计算公式为
其中,
z
ˉ
l
\bar{z}^l
zˉl 和
z
l
z^l
zl 分别表示块l 的 (S)W-MSA 模块和MLP模块的输出特性;W-MSA和SW-MSA分别表示使用常规和移动窗口分区配置的基于窗口的多头自注意。
如表 4 所示,移位窗口分割方法在前一层中引入了相邻非重叠窗口之间的连接,在图像分类、物体检测和语义分割方面效果显著。
Efficient batch computation for shifted configuration 移位窗口分区的一个问题是,它将导致更多的窗口,从移位配置中的 ⌈ h M ⌉ ∗ ⌈ w M ⌉ \lceil\frac{h}{M}\rceil*\lceil\frac{w}{M}\rceil ⌈Mh⌉∗⌈Mw⌉到 ( ⌈ h M ⌉ + 1 ) ∗ ( ⌈ w M ⌉ + 1 ) (\lceil\frac{h}{M}\rceil+1)*(\lceil\frac{w}{M}\rceil+1) (⌈Mh⌉+1)∗(⌈Mw⌉+1),其中一些窗口将小于 M ∗ M M*M M∗M。一个简单的解决方案是将较小的窗口填充到 M ∗ M M*M M∗M 的大小,并在计算注意力时屏蔽掉填充值。当常规分区中的窗口数较少时,例如 2 ∗ 2 2*2 2∗2个,采用这种天真的解决方案所增加的计算量相当可观( 2 ∗ 2 2*2 2∗2 -> 3 ∗ 3 3*3 3∗3,是原来的 2.25 倍)。在此,我们提出了一种更高效的批处理计算方法,即向左上方循环移动,如图 4 所示。循环移动后,一个批处理窗口可能由多个子窗口组成,而这些子窗口在特征图中并不相邻,因此我们采用了一种屏蔽机制,将自我关注计算限制在每个子窗口内。 循环移动后,批处理窗口的数量与常规窗口划分的数量相同,因此也是高效的。这种方法的低延迟如表 5 所示。
Relative position bias 在计算自我关注度时,我们遵循文献[49, 1, 32, 33]的方法,在计算相似度时,在每个头部加入相对位置偏差
B
∈
R
M
2
M
2
B\in R^{M^2 M^2}
B∈RM2M2:
其中,
Q
,
K
,
V
∈
R
M
2
×
d
Q,K,V \in R^{M^2\times d}
Q,K,V∈RM2×d是查询、键和值矩阵;
d
d
d 是查询/键维度,
M
2
M^2
M2 是窗口中的斑块数。由于每个轴的相对位置都在
⌊
−
M
+
1
,
M
−
1
⌋
\lfloor-M+1,M-1 \rfloor
⌊−M+1,M−1⌋的范围内,因此我们参数化了一个较小的偏差矩阵
B
^
∈
R
(
2
M
−
1
)
×
(
2
M
−
1
)
\hat{B}\in R^{(2M-1)\times (2M-1)}
B^∈R(2M−1)×(2M−1),
B
B
B 中的值取自
B
^
\hat{B}
B^。
如表 4 所示,与不使用该偏置项或使用绝对位置嵌入的同行相比,我们观察到了明显的改进。如文献[20]所述,在输入中进一步添加绝对位置嵌入会略微降低性能,因此我们在实现中没有采用。
在预训练中学习到的相对位置偏置也可用于初始化模型,以便通过双立方插值以不同的窗口大小进行微调[20, 63] 。
3.3. Architecture Variants
我们建立的基础模型名为 Swin-B,其模型大小和计算复杂度与 ViT B/DeiT-B 相似。我们还引入了 Swin-T、Swin-S 和 Swin-L,它们的模型大小和计算复杂度分别约为
0.25
×
0.25\times
0.25×、
0.5
×
0.5\times
0.5× 和
2
×
2\times
2×。请注意,Swin-T 和 Swin-S 的复杂度分别与 ResNet-50 (DeiT-S) 和 ResNet-101 相似。窗口大小默认设置为 M = 7。在所有实验中,每个头部的查询维数为 d = 32,每个 MLP 的扩展层为
α
\alpha
α= 4。这些模型变体的架构超参数为
其中 C 是第一阶段隐藏层的通道数。表 1 列出了 ImageNet 图像分类模型的大小、理论计算复杂度(FLOPs)和吞吐量。
4. Experiments
我们对 ImageNet-1K 图像分类 [19]、COCO 对象检测 [43] 和 ADE20K 语义分割 [83] 进行了实验。在下文中,我们首先将所提出的 Swin Transformer 架构与这三个任务的前沿技术进行比较。然后,我们将介绍 Swin Transformer 的重要设计元素。
5. Conclusion
本文介绍了一种新的视觉变换器 Swin Transformer,它能产生分层特征表示,并且计算复杂度与输入图像大小呈线性关系。Swin Transformer 在 COCO 物体检测和 ADE20K 语义分割方面达到了最先进的性能,大大超过了以前的最佳方法。我们希望 Swin Transformer 在各种视觉问题上的优异表现能推动视觉和语言信号统一建模的发展。
作为 Swin Transformer 的关键要素,基于移位赢道的自我关注已被证明对视觉问题有效且高效,我们期待着对其在自然语言处理中的应用进行研究。