Understanding the Transformer (vanilla / Vision / Swin)
First, a brief introduction to the Transformer and the Vision Transformer.
Transformer
The focus here is multi-head attention.
Dimension derivations: Transformer详解:各个特征维度分析推导 | Hello World 💓 (troublemeeter.github.io)
m is the maximum sequence length (maxlen): after the text is fed in, every word / character is one token, the basic input unit of the model. Each token is mapped to something the computer can work with; the result of that mapping is the token's embedding, whose length is typically set to 512. The maxlen is fixed to unify dimensions; shorter inputs are padded with 0.
d is $d_{model}$, the dimension of the embedding / positional vectors; the output dimension is also d.
h is the number of attention heads.
input: [batchsize, maxlen, d], i.e. (per sample) $X \in \mathbb{R}^{m \times d}$
- Meaning of Q, K, V
  - $Q, K, V \in \mathbb{R}^{m \times d}$
  - Strictly speaking, Q: [maxlen_q, d] and K = V: [maxlen_k, d].
  - I am not sure why the Transformer write-ups stress that the maxlen of Q and of K/V may differ while also saying Q = K = V; in practice the shapes of Q, K and V are essentially always the same, including in the Transformer itself.
  - Strictly, Q and K are used to compute similarity between tokens, forming a score matrix of pairwise similarities (think of CLIP's similarity matrix): the dot product of two vectors expresses how similar they are, i.e. the attention score. Physically, Q, K and V mean the same thing; they are all matrices built from the tokens of the same sentence, each row being one token's embedding. For the sentence "Hello, how are you?" of length 6 with embedding dimension 300, Q, K and V are all (6, 300) matrices. (A small illustration follows after this list.)
  - If Q and K were the same, the score matrix would generalize poorly, because the vectors are projected into a single space; such a matrix does a worse job of "refining" V in $QK^T$.
  - With distinct projections for Q, K and V, the embedding matrix is projected into different spaces, which generalizes better.
  - My own explanation is a bit odd: multiplying the first matrix by the second captures, to some extent, the before/after relationship between tokens; if Q = K the score matrix is symmetric (its upper and lower triangles mirror each other), which probably cannot express that. I am not entirely sure.
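The toy check below (made-up random embeddings, not real word vectors) just confirms the shapes in the "Hello, how are you?" example above: Q, K, V are (6, 300), $QK^T$ is the 6 × 6 score matrix, and the output is again (6, 300).

```python
import torch

Q = K = V = torch.randn(6, 300)          # same sentence, same token embeddings
scores = Q @ K.T / 300 ** 0.5            # (6, 6) token-to-token similarity / attention scores
out = scores.softmax(dim=-1) @ V         # (6, 300) refined token representations
print(scores.shape, out.shape)
```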
First, linear projections are applied to Q, K and V. Initially Q, K and V are all $X \in \mathbb{R}^{\mathcal{B} \times m \times d}$; for multi-head attention we also set $d_h = d / h$.
$$
Q^*_i = Q W_i^Q, \quad W_i^Q \in \mathbb{R}^{d \times d_h}, \quad Q^*_i \in \mathbb{R}^{\mathcal{B} \times m \times d_h}\\
K^*_i = K W_i^K, \quad W_i^K \in \mathbb{R}^{d \times d_h}, \quad K^*_i \in \mathbb{R}^{\mathcal{B} \times m \times d_h}\\
V^*_i = V W_i^V, \quad W_i^V \in \mathbb{R}^{d \times d_h}, \quad V^*_i \in \mathbb{R}^{\mathcal{B} \times m \times d_h}
$$
Self-attention is then computed for each head ($\text{head}_i$):
$$
f=\text{Attention}(Q, K, V)=\text{SoftMax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V
$$
Here $d_k$ is the key dimension: in the single-matrix view above $d_k = d_{model} = 512$, while within each head it is $d_h = d / h$.
Strictly speaking, the computation is per query:
$$
f=\text{Attention}(q, K, V)=\text{SoftMax}\left(\frac{q K^T}{\sqrt{d_k}}\right) V
$$
that is, each query $q$ is handled separately and the results are stacked back together; the matrix form simply batches all the queries in one go.
The softmax computes token-to-token similarity, so the matrix obtained from $QK^T$ has shape $m \times m$. The similarity (score) matrix is then multiplied by $V$, giving an $m \times d$ result: through these scores, every token's vector in $V$ is adjusted, dimension by dimension (column by column), by all the other tokens.
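To make the shape bookkeeping concrete, here is a minimal sketch of the multi-head flow (illustrative values for B, m, d, h; not the reference implementation):

```python
import torch

B, m, d, h = 2, 6, 512, 8
d_h = d // h                                    # per-head dimension d_h = d / h

x = torch.randn(B, m, d)                        # X in R^{B x m x d}
W_q = torch.randn(d, d)                         # all heads' W_i^Q stacked: d x (h * d_h) = d x d
W_k = torch.randn(d, d)
W_v = torch.randn(d, d)

# project, then split into heads: (B, m, d) -> (B, h, m, d_h)
Q = (x @ W_q).view(B, m, h, d_h).transpose(1, 2)
K = (x @ W_k).view(B, m, h, d_h).transpose(1, 2)
V = (x @ W_v).view(B, m, h, d_h).transpose(1, 2)

scores = Q @ K.transpose(-2, -1) / d_h ** 0.5   # (B, h, m, m): token-to-token similarity
attn = scores.softmax(dim=-1)
out = attn @ V                                  # (B, h, m, d_h)
out = out.transpose(1, 2).reshape(B, m, d)      # concatenate heads back to (B, m, d)
print(scores.shape, out.shape)                  # torch.Size([2, 8, 6, 6]) torch.Size([2, 6, 512])
```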
Vision Transformer
Core idea: split the image into patches; each patch is one token.
A text input is a 2-D tensor (tokens × embedding) while an image is a 3-D tensor (H × W × C), so ViT cuts the image into patches.
ViT cuts the image into fixed-size patches and flattens each patch to serve as an image token. For a 224 × 224 × 3 image with 16 × 16 patches, there are 14 × 14 = 196 patches, i.e. 196 tokens.
Each patch is first turned into a token: flattening one patch gives 16 × 16 × 3 = 768 values. These are then linearly projected to the desired embedding size; if we want 768, this can be done either with a linear layer or with a convolution of kernel_size = 16 × 16, stride = 16, out_channels = 768, ending up with 196 × 768 (conv arithmetic: (224 - 16 + 2 · 0) / 16 + 1 = 14, i.e. [3, 224, 224] -> [768, 14, 14]).
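A small sketch of this (the sizes are the ViT-Base defaults quoted above) showing that a strided convolution realizes "split into 16 × 16 patches + linear projection":

```python
import torch
import torch.nn as nn

img = torch.randn(1, 3, 224, 224)
proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)  # 768 output channels chosen to match 16*16*3

tokens = proj(img)               # [1, 768, 14, 14]
tokens = tokens.flatten(2)       # [1, 768, 196]
tokens = tokens.transpose(1, 2)  # [1, 196, 768]: 196 patch tokens, embedding dim 768
print(tokens.shape)
```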
The tokens are then fed into multi-head self-attention.
When the softmax is computed, the attention map is seq_len × seq_len = 196 (14 × 14) × 196.
❓The problem: if the image resolution grows while the patch size stays fixed, the number of tokens grows. At 448 × 448 there are already 28 × 28 patches, which is quite a jump, and the attention cost is quadratic in the token count.
Hence the Swin Transformer.
- Computational complexity (a quick numeric check follows after this list)
  - First, the three linear projections of Q, K and V:
    - After patchifying, Q, K, V: [h*w, d(c)] = [14*14, 768]; multiplied by the weight W: [d(c), d_h], the cost is 3 * h * w * c * c, output -> [h*w, d_h]
  - The softmax part:
    - $QK^T$: [h*w, d_h] times [d_h, h*w], cost h * w * h * w * c, output -> [h*w, h*w]
    - $QK^T \times V$: [h*w, h*w] times [h*w, d_h(c)], cost h * w * h * w * c, output -> [h*w, d_h(c)]
  - The output linear projection:
    - The softmax output [h*w, d_h(c)] goes through W: [d_h(c), d], so the cost is h * w * c * c
  - In total:
    $$\Omega(\mathrm{MSA})=4 h w C^2+2(h w)^2 C$$
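A quick numeric check of the formula (token counts for 224² and 448² inputs with 16 × 16 patches; C = 768 is an assumed ViT-Base-like width):

```python
def msa_flops(hw, C):
    # Omega(MSA) = 4*h*w*C^2 + 2*(h*w)^2*C
    return 4 * hw * C ** 2 + 2 * hw ** 2 * C

print(f"{msa_flops(14 * 14, 768) / 1e9:.2f} GFLOPs")  # 224x224 image -> 196 tokens
print(f"{msa_flops(28 * 28, 768) / 1e9:.2f} GFLOPs")  # 448x448 image -> 784 tokens: quadratic term dominates
```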
Swin Transformer
Overview
Core idea: partition the feature map into windows, treat each window as its own little image, and compute self-attention within each window.
To fix ViT's drawback, the window concept is introduced.
The image is divided into 8 × 8 windows, each containing 7 × 7 patches with a patch size of 4 × 4; multi-head attention is then computed window by window, so the later softmax works on seq_len × seq_len = 49 (7 × 7) × 49.
Note that Swin defines the partition directly by window size: for a window size of M, there are (H / M) × (W / M) windows.
- Computational complexity (a comparison with global attention follows after this list)
  - For a single window of M × M patches:
    $$\Omega(\mathrm{MSA})=4 M^2 C^2+2 (M^2)^2 C$$
  - Multiplying by the number of windows, $hw / M^2$:
    $$\begin{aligned} \Omega(\mathrm{W\text{-}MSA})&=\left(4 M^2 C^2+2(M^2)^2 C\right) \cdot hw / M^2 \\ & = 4hwC^2 +2M^2hwC \end{aligned}$$
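A rough comparison, assuming first-stage sizes (56 × 56 tokens, C = 96, window M = 7), of the two formulas:

```python
def msa_flops(hw, C):
    return 4 * hw * C ** 2 + 2 * hw ** 2 * C          # Omega(MSA), global attention

def wmsa_flops(hw, C, M):
    return 4 * hw * C ** 2 + 2 * M ** 2 * hw * C      # Omega(W-MSA), windowed attention

hw, C, M = 56 * 56, 96, 7
print(f"MSA  : {msa_flops(hw, C) / 1e9:.2f} GFLOPs")  # the quadratic (hw)^2 term dominates
print(f"W-MSA: {wmsa_flops(hw, C, M) / 1e9:.2f} GFLOPs")
```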
Code walkthrough
The "world-famous painting" (the overall architecture figure from the paper)
The full SwinTransformer module
It is built essentially in the order of that architecture figure.
class SwinTransformer(nn.Module):
r""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
"""
    # initialization
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, fused_window_process=False, **kwargs):
        # num_heads=[3, 6, 12, 24]: the head count doubles each stage because patch merging doubles the embedding dim each time
        # depths 2 2 6 2 are all even
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
        # patch embedding turns the 224*224*3 image into 56*56*96 (done with a convolution)
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# absolute position embedding
        # whether to add an absolute position embedding (as in ViT / the vanilla Transformer)
        # optional either way
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate) # dropout
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
        # the drop-path rates are spaced evenly
        # i.e. evenly spaced values from 0 to drop_path_rate, one per block (sum(depths) values in total)
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),  # the embedding dim really does double each stage
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
fused_window_process=fused_window_process)
            # the loop above builds the four stages
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)  # global average pooling
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()  # together with the pooling this forms the classification head
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def forward_features(self, x):
        x = self.patch_embed(x)  # patch embedding first
        if self.ape:  # optionally add the absolute position embedding
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
return x
def forward(self, x):
        x = self.forward_features(x)  # extract features with forward_features
x = self.head(x)
return x
def flops(self):
flops = 0
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
flops += self.num_features * self.num_classes
return flops
patch partition & Linear embedding
This step is better described as downsampling: it turns H × W × 3 into H/4 × W/4 × 16C (= 48 channels), implemented with a kernel_size = 4 × 4, stride = 4 convolution. In the paper the patches are 4 × 4 blocks, but after this downsampling each 4 × 4 block is stacked along the channel dimension, so on the reduced feature map a "patch" can be treated as 1 × 1, i.e. a single pixel; strictly speaking it is these 1 × 1 patches that are turned into embeddings.
Moreover, patch partition and linear embedding are both folded into PatchEmbed and done by a single convolution, turning 224 × 224 × 3 into 56 × 56 × (48 →) 96. A small shape check follows below.
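A small shape check (sizes taken from the text above; the unfold-based two-step version only mirrors the "partition then embed" view and does not reproduce the conv weights):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 224, 224)
proj = nn.Conv2d(3, 96, kernel_size=4, stride=4)
print(proj(x).shape)                        # torch.Size([1, 96, 56, 56])

# equivalent two-step view: stack each 4x4 block on the channel dim (-> 48), then a linear map (-> 96)
patches = x.unfold(2, 4, 4).unfold(3, 4, 4)                            # [1, 3, 56, 56, 4, 4]
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 56, 56, 48)     # 48 = 3*4*4
print(nn.Linear(48, 96)(patches).shape)     # torch.Size([1, 56, 56, 96])
```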
class PatchEmbed(nn.Module):
    # this is the initial downsampling
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
        img_size = to_2tuple(img_size)  # convert to a tuple
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        # this produces (H/4, W/4, embed_dim=96); each patch contributes 3*4*4=48 input values
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
        # self.norm = norm_layer(embed_dim) creates a normalization layer over embed_dim features
        # and assigns it to self.norm,
        # so the embedded features can be normalized with self.norm during the forward pass
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
        # check that the input size matches the expected image size
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
        # the conv projection above, then flatten the spatial dims: 56*56=3136 tokens
if self.norm is not None:
x = self.norm(x)
        # apply the normalization
return x
def flops(self):
        # count FLOPs
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
# flops = 56 * 56 * 96 * 3 * (4 * 4)
        # standard conv FLOPs: (per output pixel) (2 * k_h * k_w * c_in) * (over the whole output map) (c_out * h_o * w_o) (multiply, then add)
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
swin transformer block
As the "two successive blocks" figure shows, the input first goes through a LayerNorm and then self-attention. The attention itself is implemented in a WindowAttention class; the video BV1bq4y1r75w explains it well. Unlike the Transformer and the Vision Transformer, Swin adds a relative position encoding inside window attention, which is the part that is hardest to understand.
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm,
fused_window_process=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
        # shift_size must be >= 0 and < window_size; violating either condition raises the assertion error
self.norm1 = norm_layer(dim)
        # the most important part is attn
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # stochastic depth (drop path); the mask matrix is built further below
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        # then the mask matrix
if self.shift_size > 0:
            # > 0 means the windows really are shifted
# calculate attention mask for SW-MSA
H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1: a single "image" with only H and W, no real batch or channel
            h_slices = (slice(0, -self.window_size),  # with window_size = 7 (7 patches), this covers 0 to -7
                        slice(-self.window_size, -self.shift_size),  # this slices out -7 to -3
                        slice(-self.shift_size, None))  # this slices out -3 to the end
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
            # the single img_mask is partitioned into labelled regions
            # everything here is still initialization
mask_windows = window_partition(img_mask, self.window_size) # nW * B, window_size, window_size, 1
# mask_windows(num_windows*B, window_size, window_size, C=1)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # reshape to 2-D: dim 0 is the number of windows, dim 1 is the flattened window
            # this is really a region-label matrix: after the shift a window can contain unrelated regions,
            # and patches from the same region carry the same label (WindowAttention uses this later)
            # flatten the labels into a row and a column so pairwise label differences can be computed
            # a difference of 0 means the two patches come from the same region; nonzero means they are unrelated
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # unsqueeze adds a dim for broadcasting: a row of 4 against a column of 4 broadcasts to 4*4=16 entries
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
            # keep 0 where the difference is 0; fill -100 (effectively -inf) where it is nonzero: that is the mask
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
self.fused_window_process = fused_window_process
        # everything above is initialization: self.norm1, self.attn, etc. are set up in advance for later use,
        # and the mask is also defined here for later
        # the forward pass then implements each part of the Transformer block step by step,
        # calling the modules defined in __init__
def forward(self, x):
H, W = self.input_resolution # x: 1, 3136, 96 B = 1
B, L, C = x.shape # L = 56*56=3136
assert L == H * W, "input feature has wrong size"
shortcut = x
        x = self.norm1(x)  # layer normalization
x = x.view(B, H, W, C) # -> 1, 56, 56, 96
# cyclic shift
        # if the windows are shifted, roll the pixels along the height and width first, then partition into windows
        if self.shift_size > 0:  # shifted windows
if not self.fused_window_process:
# x: B, H, W, C
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # shifts=(-self.shift_size, -self.shift_size) shifts along the H and W dims
                # torch.roll is neat: it moves the leading pixels to the end (cyclic shift)
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
else:
x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
        # the attn_mask was prepared in __init__ (it is None when shift_size == 0); when shift > 0 it is used here
        # WindowAttention then computes the relative position index and adds the bias
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        # reshape the windows back to spatial form
# reverse cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
                # window_reverse merges the windows back to (B H' W' C); the roll below undoes the shift
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
else:
            # if there was no shift, window_reverse alone restores the feature map
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = shifted_x
x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)  # residual connection, with drop path applied to x
# FFN
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
        # indeed window_size * window_size, matching the hw * hw term in the attention FLOPs above
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
First come W-MSA and SW-MSA; for now treat WindowAttention as a black box.
- In the initialization stage, self.attn is created first, followed by drop path, LayerNorm and the MLP; __init__ then builds the mask matrix, and whether one is needed depends on shift_size.
- If shift_size > 0, the windows really are shifted, i.e. this is SW-MSA.
- In that case a mask matrix has to be constructed. Why? Because the image is shifted: some of the left-most columns and top-most rows are moved to the right and to the bottom (the shift itself happens in forward). After this move, the shifted pixels have little to do with their new neighbours, so the "outsider" pixels must be set to -inf; otherwise self-attention would mix the outsiders with the locals.
- Concretely, if the shifted map is divided into small windows, the windows along the right and bottom edges clearly contain outsiders, so each region gets its own label; what the labels are for is explained below.
- Treat the mask as a (1, H, W, 1) tensor, a single 56 × 56 "image" with no batch or channel. How do we find the unrelated pixels? By labelling the shifted mask region by region, like drawing national borders on a map: you are you, I am I. The processing of the mask is in the code comments.
- Why slice out (-7, -3) and (-3, None)? Because each shift moves 3 pixels to the far end, while the window size is 7 × 7; that means a window can end up containing 3 rows or 3 columns of outsiders (see the figure above), and they have to be told apart:
h_slices = (slice(0, -self.window_size),  # with window_size = 7 (7 patches), this covers 0 to -7
            slice(-self.window_size, -self.shift_size),  # this slices out -7 to -3
            slice(-self.shift_size, None))  # this slices out -3 to the end
w_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
- After the labelling, take the bottom-right window as an example; it is a mix of regions. The window_partition function cuts the feature map into small windows, with the shape going (B, H, W, C) -> (B, H // window_size, window_size, W // window_size, window_size, C) -> (B, nWindow', nWindow', window_size, window_size, C) -> (B*nWindow, window_size, window_size, C). These small windows are what window self-attention works on:
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
    # B, H // window_size, window_size, W // window_size, window_size, C
    # -> (num_windows*B, window_size, window_size, C)
    # where num_windows = (H // window_size) * (W // window_size)
    # i.e. the windows are stacked along dim 0 (very much like patchifying)
- Next is the code that produces the window mask, attn_mask, with the shape going (B*nWindow, window_size, window_size) -> (B*nWindow, window_size*window_size). After each small window's mask is flattened to 2-D, it is stretched out into one row and one column (DASOU's video BV1zT4y197Fe is highly recommended; watch it and then walk through the code again, assuming you already know the Transformer and the Vision Transformer). Remember the mask matrix was labelled, one label per region. Suppose the window is 2 × 2: flatten it, then compute the difference between every pair of elements. A difference of 0 means the two elements belong to the same region; a nonzero difference means they do not, so they are unrelated and must be covered up, banished! (i.e. set to a large negative value, so that in the later self-attention softmax their contribution is negligibly small and does not affect the result). The stretching uses unsqueeze, which simply adds a dimension and lets broadcasting do the rest:
# everything here is still initialization
mask_windows = window_partition(img_mask, self.window_size)  # nW * B, window_size, window_size, 1
# mask_windows: (num_windows*B, window_size, window_size, C=1)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
# reshape to 2-D: dim 0 is the number of windows, dim 1 is the flattened window
# this is really a region-label matrix: after the shift a window can contain unrelated regions,
# and patches from the same region carry the same label (WindowAttention uses this later)
# flatten the labels into a row and a column so pairwise label differences can be computed
# a difference of 0 means the two patches come from the same region; nonzero means they are unrelated
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
# unsqueeze adds a dim for broadcasting: a row of 4 against a column of 4 broadcasts to 4*4=16 entries
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
# keep 0 where the difference is 0; fill -100 (effectively -inf) where it is nonzero: that is the mask
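A tiny illustration with a made-up 4 × 4 "image", window 2 and shift 1 (not Swin's real sizes) of how the region labels turn into the attention mask:

```python
import torch

H = W = 4
window, shift = 2, 1
img_mask = torch.zeros(1, H, W, 1)
cnt = 0
for h in (slice(0, -window), slice(-window, -shift), slice(-shift, None)):
    for w in (slice(0, -window), slice(-window, -shift), slice(-shift, None)):
        img_mask[:, h, w, :] = cnt          # label each of the 9 regions 0..8
        cnt += 1

# partition into 2x2 windows, flatten each window, compare labels pairwise
mw = img_mask.view(1, H // window, window, W // window, window, 1)
mw = mw.permute(0, 1, 3, 2, 4, 5).reshape(-1, window * window)
attn_mask = mw.unsqueeze(1) - mw.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0)
print(attn_mask[-1])   # bottom-right window mixes regions -> -100 entries off the diagonal
```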
- If there is no shift, none of this fuss is needed and attn_mask is simply None.
- forward
  - With the mask prepared in __init__, the forward pass does the real work. If the windows need to be shifted, the torch.roll function is used; it moves the first few slices of a dimension to the end, as the example below shows:
import torch

x = torch.arange(80).reshape((1, 4, 4, 5))
print(x)
print('------------------------------------')
# cyclic shift with torch.roll()
shifted_x = torch.roll(x, shifts=(-3, -2), dims=(1, 2))
print(shifted_x)

Output:

tensor([[[[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [10, 11, 12, 13, 14],
          [15, 16, 17, 18, 19]],

         [[20, 21, 22, 23, 24],
          [25, 26, 27, 28, 29],
          [30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39]],

         [[40, 41, 42, 43, 44],
          [45, 46, 47, 48, 49],
          [50, 51, 52, 53, 54],
          [55, 56, 57, 58, 59]],

         [[60, 61, 62, 63, 64],
          [65, 66, 67, 68, 69],
          [70, 71, 72, 73, 74],
          [75, 76, 77, 78, 79]]]])
------------------------------------
tensor([[[[70, 71, 72, 73, 74],
          [75, 76, 77, 78, 79],
          [60, 61, 62, 63, 64],
          [65, 66, 67, 68, 69]],

         [[10, 11, 12, 13, 14],
          [15, 16, 17, 18, 19],
          [ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9]],

         [[30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39],
          [20, 21, 22, 23, 24],
          [25, 26, 27, 28, 29]],

         [[50, 51, 52, 53, 54],
          [55, 56, 57, 58, 59],
          [40, 41, 42, 43, 44],
          [45, 46, 47, 48, 49]]]])
- This is the code that shifts the image and produces the shifted windows x_windows, with shape (nW*B, window_size, window_size, C):
if self.shift_size > 0:  # shifted windows
    if not self.fused_window_process:
        # x: B, H, W, C
        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        # shifts=(-self.shift_size, -self.shift_size) shifts along the H and W dims
        # torch.roll is neat: it moves the leading pixels to the end (cyclic shift)
        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
    else:
        x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
else:
    shifted_x = x
    # partition windows
    x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
- Next, x_windows is reshaped to (nW*B, window_size*window_size, C). At this point we have the shifted windows x_windows and the attn_mask built during initialization, so window self-attention can be applied (via WindowAttention, already declared in __init__ as self.attn()):
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
# the attn_mask was prepared in __init__ (None when there is no shift); when shift > 0 it is used here
# WindowAttention then computes the relative position index and adds the bias
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
# reshape the windows back to spatial form
- The windows then have to be put back with a window_reverse, restoring the original (B, H, W, C); the shapes go (B, num_window, num_window, window_size, window_size, C) -> (B, num_window, window_size, num_window, window_size, C) -> (B, H, W, C):
def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    # (B, num_window, num_window, window_size, window_size, C)
    # -> (B, num_window, window_size, num_window, window_size, C)
    # -> (B, H, W, C)
    return x
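A quick round-trip check (sizes assumed: a 56 × 56 feature map, window 7) that window_partition followed by window_reverse, i.e. the two functions above, recovers the original feature map:

```python
import torch

B, H, W, C, ws = 2, 56, 56, 96, 7
x = torch.randn(B, H, W, C)
windows = window_partition(x, ws)               # (B * 8 * 8, 7, 7, 96)
restored = window_reverse(windows, ws, H, W)    # (B, 56, 56, 96)
print(windows.shape, torch.equal(x, restored))  # torch.Size([128, 7, 7, 96]) True
```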
- Besides the reverse, if the windows were shifted we also have to roll back in the opposite direction: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)); earlier the shifts were (-self.shift_size, -self.shift_size), here they are (self.shift_size, self.shift_size).
- If there was no shift, no un-rolling is needed, of course.
- One question of my own: is fused_window_process really necessary? The source used in the video walkthrough did not have it; if anyone understands it, I would appreciate an explanation.
# reverse cyclic shift
if self.shift_size > 0:
    if not self.fused_window_process:
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
        # window_reverse merges the windows back to (B H' W' C); the roll below undoes the shift
        x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
    else:
        x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
else:
    # if there was no shift, window_reverse alone is enough
    shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
    x = shifted_x
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(x)  # residual connection, with drop path applied to x
- Finally, the FLOPs are counted:
def flops(self):
    flops = 0
    H, W = self.input_resolution
    # norm1
    flops += self.dim * H * W
    # W-MSA/SW-MSA
    nW = H * W / self.window_size / self.window_size
    flops += nW * self.attn.flops(self.window_size * self.window_size)
    # indeed window_size * window_size, matching the hw * hw term in the attention FLOPs
    # mlp
    flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
    # norm2
    flops += self.dim * H * W
    return flops
windowAttention
This part is the heart of it. Code first (there are plenty of my notes in the comments):
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        # dim is the input dimension
        # there is dropout in here too, an interesting implementation detail
        # qk_scale overrides the default 1/sqrt(d_k) scaling factor
        # num_heads grows as patch merging increases the channel (embedding) dim, so that each head keeps the same per-head size
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
        # if qk_scale is given it is used directly; otherwise the scale is head_dim ** -0.5, i.e. 1/sqrt(d_k)
# define a parameter table of relative position bias
        # relative position bias parameters
        # i.e. softmax(QK^T / sqrt(d_k) + B) * V
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
        # this computes the relative position index
        # the index is fixed (not learnable), so it ends up in register_buffer
        # the relative position bias table B, however, is learnable; each entry of the index matrix selects the corresponding row of the table
        # conceptually: for i, value in enumerate(table):
        #                   position[torch.where(position == i)] = value
coords_h = torch.arange(self.window_size[0]) #
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
        # standard operations after the relative position encoding
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # nn.Linear(in_features, out_features, bias)
        # out_features * 3 so that query, key and value are produced in one go
        # this defines the linear projection layer
self.attn_drop = nn.Dropout(attn_drop)
        # dropout randomly zeroes activations
self.proj = nn.Linear(dim, dim)
        # project back to the original feature size
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
        # everything above is initialization
    # forward pass
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
        # the linear layer below stacks Q, K, V along the channel dim
        # N is the sequence length
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # permute reorders the dims
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # transpose the last two dims
        # the attention map is (window_size * window_size, window_size * window_size) over the patches in a window
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
        # TODO: could this be written as
        # for i, value in enumerate(self.relative_position_bias_table):
        #     self.relative_position_index[torch.where(self.relative_position_index == i)] = value
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
            # mask shape: (num_windows, Wh*Ww, Wh*Ww), where Wh*Ww is the per-window sequence length
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            # the mask broadcasts as [1, nW, 1, Mh*Mw, Mh*Mw]
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
        # calculate flops for 1 window with token length N (e.g. N = 7*7 = 49, so the attention map is 49*49)
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N  # the hw * hw * c term from the complexity analysis
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads) # hw * hw * c
# x = self.proj(x)
flops += N * self.dim * self.dim # hw * c * c
return flops
- Initialization
  - The initialization defines some parameters; the essential new piece is the relative position encoding. Read it together with BV1bq4y1r75w and "Swin Transformer之相对位置编码详解-CSDN博客"; the figure in that post helps a lot.
  - In short, a (2, window_size, window_size) grid is created, one plane filled 0, 1, 2, ... along the rows and the other along the columns. Both are flattened, a dimension is added, and broadcasting takes all pairwise differences (the M in the figure is window_size). After shifting and scaling the row and column offsets, the two planes are added; the resulting indices cover (2M-1) × (2M-1) possible values, so a learnable bias table with (2M-1) × (2M-1) entries is initialized, and each value of the index matrix looks up its entry in that table.
  - So initialization produces the relative position index, i.e. the tensor summed along dim 0, which I think of as "adding the two little slabs together":
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
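A toy run of the same computation with window_size = 2 (illustrative only) makes the index values easy to see:

```python
import torch

M = 2
coords = torch.stack(torch.meshgrid([torch.arange(M), torch.arange(M)], indexing="ij"))  # 2, M, M
coords_flatten = torch.flatten(coords, 1)                                 # 2, M*M
rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]             # 2, M*M, M*M
rel = rel.permute(1, 2, 0).contiguous()
rel[:, :, 0] += M - 1        # shift row offsets to start from 0
rel[:, :, 1] += M - 1        # shift col offsets to start from 0
rel[:, :, 0] *= 2 * M - 1    # encode each (row, col) offset pair as a single index
relative_position_index = rel.sum(-1)
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])  -> values in [0, (2M-1)^2 - 1] = [0, 8], indexing a 9-entry bias table
```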
- Then come the usual operations: Q, K and V produced by a single linear layer, dropout, and the output projection:
# standard operations after the relative position encoding
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# nn.Linear(in_features, out_features, bias)
# out_features * 3 so that query, key and value are produced in one go
# this defines the linear projection layer
self.attn_drop = nn.Dropout(attn_drop)
# dropout randomly zeroes activations
self.proj = nn.Linear(dim, dim)
# project back to the original feature size
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
# everything above is initialization
- forward
  - forward first computes $QK^T$, which involves some shape bookkeeping: (num_windows*B, sequence length N = window_size*window_size, 3 for q/k/v, num_heads, per-head dim) -> (3 for q/k/v, num_windows*B, num_heads, sequence length, per-head dim):
# forward pass
def forward(self, x, mask=None):
    """
    Args:
        x: input features with shape of (num_windows*B, N, C)
        mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
    """
    B_, N, C = x.shape
    # the linear layer below stacks Q, K, V along the channel dim
    # N is the sequence length
    qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # permute reorders the dims
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
    q = q * self.scale
    attn = (q @ k.transpose(-2, -1))  # transpose the last two dims
    # the attention map is (window_size * window_size, window_size * window_size) over the patches in a window
- Next the relative_position_bias is computed. The bias table is a learnable nn.Parameter; the index tensor built above holds indices into this table, and the looked-up values fill a (window_size*window_size, window_size*window_size, nH) tensor (nH is the number of heads). The bias is then added to the attention map (unsqueeze adds a dimension so it broadcasts):
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
    self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
- Then the softmax: the similarity scores and the relative position bias are already in place, so the mask is added (if there is one), reshaped so that it broadcasts as (1, nW, 1, Mh*Mw, Mh*Mw).
- After that come the output projection, dropout and so on:
if mask is not None:
    nW = mask.shape[0]
    # mask shape: (num_windows, Wh*Ww, Wh*Ww), Wh*Ww being the per-window sequence length
    attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
    # the mask broadcasts as [1, nW, 1, Mh*Mw, Mh*Mw]
    attn = attn.view(-1, self.num_heads, N, N)
    attn = self.softmax(attn)
else:
    attn = self.softmax(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
Patch Merging
This part does the downsampling. Code first:
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
        # fine as long as H and W are even
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
        x = self.reduction(x)  # linear projection 4C -> 2C
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
Not much to explain here. The input of shape (B, H*W, C) is viewed as (B, H, W, C); each merge halves H and W while the channels are first quadrupled to 4C by concatenation and then reduced to 2C by the linear layer. Concretely, every other element is taken (a 2 × 2 neighbourhood) and the four shifted copies are concatenated along the channel dimension:
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
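A quick shape check (first-stage sizes assumed: 56 × 56 feature map, C = 96) of what PatchMerging does:

```python
import torch

B, H, W, C = 1, 56, 56, 96
x = torch.randn(B, H, W, C)
x0, x1, x2, x3 = x[:, 0::2, 0::2, :], x[:, 1::2, 0::2, :], x[:, 0::2, 1::2, :], x[:, 1::2, 1::2, :]
merged = torch.cat([x0, x1, x2, x3], dim=-1)                 # (1, 28, 28, 384): H, W halved, C -> 4C
reduced = torch.nn.Linear(4 * C, 2 * C, bias=False)(merged)  # (1, 28, 28, 192): 4C -> 2C
print(merged.shape, reduced.shape)
```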
Basic layer
Slightly different from the architecture figure: a BasicLayer contains both SwinTransformerBlock and PatchMerging; the Swin blocks are stacked and then followed by a PatchMerging, which plays the role of the downsampling.
SwinTransformer
Finally, everything is stacked up following the architecture figure.