Swin Transformer Explained

 


Contents

Abstract

1. Introduction

2. Method

2.1 Overall Architecture

2.1.1 Architecture

2.1.2 Swin Transformer Block

2.2 Shifted-Window-Based Self-Attention

2.2.1 Self-Attention in Non-Overlapping Local Windows

2.2.2 Shifted Window Partitioning in Successive Blocks

2.2.3 Efficient Batch Computation for the Shifted Configuration

2.2.4 Relative Position Bias

2.3 Architecture Variants

3. Source Code

3.1 Swin Transformer

3.2 Patch Embedding

3.3 Patch Merging

3.4 Window Partition

3.5 Window Reverse

3.6 MLP

3.7 Window Attention (W-MSA Module) ☆

3.8 Swin Transformer Block ☆

3.8.1 Shift Window Attention

3.8.2 Attention Mask

3.9 Basic Layer



Abstract

        This paper presents a new vision Transformer, called Swin Transformer, that can serve as a general-purpose backbone for computer vision. Challenges in adapting the Transformer from language to vision arise from differences between the two domains, such as large variations in the scale of visual entities and the high resolution of image pixels compared with words in text. To address these differences, we propose a hierarchical Transformer whose representation is computed with Shifted Windows. The shifted-window scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while still allowing cross-window connections. The hierarchical architecture has the flexibility to model at various scales and has linear computational complexity with respect to image size. These qualities of Swin Transformer make it compatible with a broad range of vision tasks, including image classification (87.3 top-1 accuracy on ImageNet-1K) and dense prediction tasks such as object detection (58.7 box AP and 51.1 mask AP on COCO test-dev) and semantic segmentation (53.5 mIoU on ADE20K val). Its performance surpasses the previous state of the art by a large margin of +2.7 box AP and +2.6 mask AP on COCO and +3.2 mIoU on ADE20K, demonstrating the potential of Transformer-based models as vision backbones. The hierarchical design and the shifted-window approach also prove beneficial for all-MLP architectures.


1. Introduction

        Modeling in computer vision has long been dominated by CNNs. Beginning with AlexNet and its revolutionary performance on the image-classification challenge, CNN architectures have grown increasingly powerful through greater scale, wider connections, and more sophisticated forms of convolution. With CNNs serving as backbone networks for a variety of vision tasks, these architectural advances have brought performance improvements that have broadly lifted the whole field. In NLP, on the other hand, network architectures have taken a different path: the prevalent architecture today is the Transformer. Designed for sequence modeling and transduction tasks, the Transformer is notable for its use of attention to model long-range dependencies in the data. Its tremendous success in the language domain has led researchers to investigate its adaptation to computer vision, where it has recently shown promising results on certain tasks, specifically image classification and joint vision-language modeling. This paper seeks to expand the applicability of the Transformer so that it can serve as a general-purpose backbone for computer vision, as it does in NLP and as CNNs do in vision. We observe that the significant challenges in transferring its high performance in the language domain to the visual domain can be explained by differences between the two modalities.

        One of these differences involves scale. Unlike the word tokens that serve as the basic elements of processing in language Transformers, visual elements can vary substantially in scale, a problem that receives attention in tasks such as object detection. In existing Transformer-based models, tokens are all of a fixed scale, a property unsuited to these vision applications. Another difference is the much higher resolution of pixels in images compared with words in passages of text. Many vision tasks, such as semantic segmentation, require dense prediction at the pixel level, which would be intractable for a Transformer on high-resolution images, because the computational complexity of its self-attention is quadratic in image size.

Figure 1

        To overcome these issues, Swin Transformer constructs hierarchical feature maps and has computational complexity linear in image size. As illustrated in Figure 1(a), Swin Transformer builds a hierarchical representation by starting from small-sized patches (outlined in gray) and gradually merging neighboring patches in deeper Transformer layers. With these hierarchical feature maps, the Swin Transformer model can conveniently leverage advanced techniques for dense prediction, such as feature pyramid networks (FPN) or U-Net. The linear computational complexity is achieved by computing self-attention locally within non-overlapping windows that partition the image (outlined in red), rather than over all patches of the whole image. The number of patches in each window is fixed, so the complexity becomes linear with respect to image size. These merits make Swin Transformer suitable as a general-purpose backbone for various vision tasks, in contrast to previous Transformer-based architectures, which produce feature maps of a single resolution and have quadratic complexity.

Figure 2

        A key design element of Swin Transformer is the shift of the window partition between consecutive self-attention layers, as illustrated in Figure 2. The shifted windows bridge the windows of the preceding layer, providing connections between them that significantly enhance modeling power (see Table 4). This strategy is also efficient in regard to real-world latency: all query patches within a local window share the same key set, which facilitates memory access in hardware. In contrast, earlier sliding-window-based self-attention approaches suffer from low latency on general hardware because different query pixels have different key sets. Our experiments show that the proposed shifted-window approach has much lower latency than the sliding-window method while its modeling power is similar (see Tables 5 and 6). The shifted-window approach also proves beneficial for all-MLP architectures.

        The proposed Swin Transformer achieves strong performance on the recognition tasks of image classification, object detection, and semantic segmentation. It outperforms ViT / DeiT and ResNe(X)t models significantly on the three tasks with similar latency. We believe that a unified architecture across CV and NLP could benefit both fields, since it would facilitate joint modeling of visual and textual signals and allow modeling knowledge from both domains to be shared more deeply. We hope that Swin Transformer's strong performance on various vision problems can drive this belief deeper in the community and encourage unified modeling of vision and language signals.


2. Method

2.1 Overall Architecture

2.1.1 Architecture 

Figure 3

        Figure 3 shows an overview of the Swin Transformer architecture (the tiny version, Swin-T). A Patch Partition module (as in ViT) first splits the input H×W×3 RGB image into N non-overlapping, equally sized patches. Each patch of P²×3 values is treated as a patch token, and N such tokens are produced in total (i.e., the effective input sequence length of the Transformer).

        More concretely, with patches of size P×P = 4×4 and C = 3 input channels, each flattened patch has feature dimension P×P×C = 4×4×3 = 48, and there are N = (H/4)×(W/4) = HW/16 patch tokens in total. In other words, each H×W×3 image is turned into (H/4)×(W/4) image patches, each flattened into a 48-dimensional token vector (similar to ViT's flattened patches); overall this is a flattened 2D patch sequence of shape ((H/4)×(W/4)) × 48.

        A Linear Embedding layer (i.e., a fully connected layer) then projects this ((H/4)×(W/4)) × 48 tensor to an arbitrary dimension C, yielding a linear embedding of shape ((H/4)×(W/4)) × C.

        These patch tokens (now linear embeddings) are then fed into several Swin Transformer blocks with modified self-attention. The first Swin Transformer blocks keep the number of input/output tokens fixed at (H/4)×(W/4), and together with the Linear Embedding layer they are referred to as Stage 1 (the first dashed box in Figure 3).

        To produce a hierarchical representation, the number of tokens is progressively reduced by Patch Merging layers as the network gets deeper. The first Patch Merging layer concatenates the features of each group of 2×2 neighboring patches, so the number of patch tokens drops to 1/4 of the original, i.e., (H/8)×(W/8), while the patch token dimension grows 4×, i.e., to 4C. A linear layer is then applied to the 4C-dimensional concatenated features, reducing the output dimension to 2C. Swin Transformer blocks are applied afterwards for feature transformation, with the resolution kept at (H/8)×(W/8). This first Patch Merging layer together with these feature-transformation Swin Transformer blocks is denoted Stage 2 (the second dashed box in Figure 3). The same procedure is repeated twice more, giving Stage 3 and Stage 4 (the third and fourth dashed boxes in Figure 3), with output resolutions / patch-token counts of (H/16)×(W/16) and (H/32)×(W/32), respectively. Each Stage changes the tensor's dimensions, which jointly produce a hierarchical representation. As a result, the architecture can conveniently replace the backbone networks of existing methods for various vision tasks.
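        To make these shapes concrete, here is a minimal sketch (assuming the Swin-T defaults: a 224×224 input, patch size 4, C = 96) that prints the token-grid resolution and channel width of each Stage:

H = W = 224
patch_size = 4
C = 96

h, w, c = H // patch_size, W // patch_size, C   # after Patch Partition + Linear Embedding
for stage in range(1, 5):
    print(f"Stage {stage}: tokens = {h}x{w} = {h * w}, dim = {c}")
    if stage < 4:                                # Patch Merging halves the resolution and doubles the dim
        h, w, c = h // 2, w // 2, c * 2

        Running it gives 56×56×96, 28×28×192, 14×14×384, and 7×7×768 for Stages 1 through 4.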

2.1.2 Swin Transformer Block

W-MSA: regular (non-shifted) window MSA    -    SW-MSA: shifted window MSA

        Compared with a standard Transformer block (e.g., in ViT), a Swin Transformer block replaces the standard multi-head self-attention (MSA) module with a shifted-window-based multi-head self-attention module (W-MSA / SW-MSA) and keeps the other parts unchanged (described in Section 3.2 of the paper). As shown in Figure 3(b) (or the figure above), a Swin Transformer block consists of a shifted-window-based MSA module, followed by a 2-layer MLP with a GELU non-linearity in between. A LayerNorm (LN) layer is applied before each MSA module and each MLP, and a residual connection is applied after each module.


2.2 Shifted-Window-Based Self-Attention

        The standard Transformer architecture and its adaptation to image classification both perform global self-attention, which computes the relationships between each token and all other tokens (the attention map). Global self-attention has quadratic computational complexity in the number of tokens, O(N²D) (where N is the number of tokens / sequence length and D is the token vector length / embedding dimension), making it unsuitable for many vision problems that require a large number of tokens, such as dense prediction or high-resolution image representation.

         Computing O(MSA) (or equivalently O(MHA)):

         When n ≫ d, O(MHA) = O(n²d); in the notation above, O(MSA) = O(N²D).

2.2.1 Self-Attention in Non-Overlapping Local Windows

        For efficient modeling, we propose to compute self-attention within non-overlapping local windows instead of globally. The windows are arranged to evenly partition the image in a non-overlapping manner. Supposing each local window contains M × M patch tokens, the computational complexities of a global MSA module over an image of h × w patch tokens and of the window-based W-MSA module are, respectively:
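        The formula image is not reproduced above; as given in the paper, with C the embedding dimension:

        Ω(MSA)   = 4·hw·C² + 2·(hw)²·C
        Ω(W-MSA) = 4·hw·C² + 2·M²·hw·C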

        Here, MSA is quadratic in the number of patch tokens h×w (there are hw patch tokens, and each attends to all hw tokens globally), whereas W-MSA is linear when M is fixed (default M = 7): there are hw patch tokens, but each attends to only M² tokens within its own local window. Global self-attention computation is generally unaffordable for a large h×w, while window-based self-attention (W-MSA) scales well.
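        For a rough sense of the gap, evaluating the two formulas above at the Stage-1 scale of Swin-T (h = w = 56, C = 96, M = 7):

h = w = 56; C = 96; M = 7                            # Stage-1 token grid of Swin-T, default window size

msa   = 4 * h * w * C ** 2 + 2 * (h * w) ** 2 * C    # global MSA operation count
w_msa = 4 * h * w * C ** 2 + 2 * M ** 2 * h * w * C  # window-based W-MSA operation count

print(msa, w_msa, round(msa / w_msa, 1))             # ~2.00e9 vs ~1.45e8, roughly a 14x reduction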

2.2.2 Shifted Window Partitioning in Successive Blocks

        Although the window-based self-attention module (W-MSA) reduces the computational complexity from quadratic to linear, the lack of connections across windows limits its modeling power. To introduce cross-window connections while maintaining the efficient computation of non-overlapping windows, we propose a shifted window partitioning approach that alternates between two partitioning configurations in consecutive Swin Transformer blocks.

Figure 2

        As shown in Figure 2, the first module uses a regular window partitioning strategy that starts from the top-left pixel and evenly partitions the 8×8 feature map into 2×2 windows of size 4×4 (the local window size is M = 4, indicated by the red boxes). The next module then adopts a window configuration shifted from that of the preceding layer: the regularly partitioned windows are displaced by (⌊M/2⌋, ⌊M/2⌋) pixels towards the top-left (a cyclic shift), as indicated by the change in the red box positions in the figure above.

        With the shifted window partitioning approach, two consecutive Swin Transformer blocks (as in the figure above) are computed as:
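        The equations (omitted as an image above) are, as in the paper:

        ẑ^l     = W-MSA(LN(z^{l−1})) + z^{l−1}
        z^l     = MLP(LN(ẑ^l)) + ẑ^l
        ẑ^{l+1} = SW-MSA(LN(z^l)) + z^l
        z^{l+1} = MLP(LN(ẑ^{l+1})) + ẑ^{l+1}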

        where ẑ^l and z^l denote the output features of the (S)W-MSA module and of the MLP module of the l-th block, respectively (see Figure 3(b)).

        The shifted window partitioning approach introduces connections between neighboring non-overlapping windows of the previous layer and is found effective for image classification, object detection, and semantic segmentation, as shown in Table 4.

2.2.3 Efficient Batch Computation for the Shifted Configuration

         One issue with shifted window partitioning is that it produces more windows, going from ⌈h/M⌉×⌈w/M⌉ to (⌈h/M⌉+1)×(⌈w/M⌉+1), and some of those windows are smaller than M×M. A naive solution is to pad the smaller windows to M×M and mask out the padded values when computing attention. When the number of windows in the regular partitioning is small, e.g., 2×2, the extra computation introduced by this naive approach is considerable (2×2 → 3×3 is 2.25× larger).

        Here, we propose a more efficient batch computation approach that cyclically shifts the feature map towards the top-left, as shown in Figure 4. After this shift, a batched window may be composed of several sub-windows that are not adjacent in the feature map, so a masking mechanism is employed to limit self-attention computation to within each sub-window. With the cyclic shift, the number of batched windows remains the same as with regular partitioning (e.g., 4 windows before the shift and still 4 windows after shifting towards the top-left, as indicated by regions A, B, C, D in the figure above). The approach is therefore efficient; its low latency is shown in Table 5.
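        A minimal sketch of the cyclic shift and its inverse (the official implementation uses torch.roll in the same way; the shapes here are toy values):

import torch

B, H, W, C = 1, 8, 8, 96            # toy feature map, layout (B, H, W, C)
shift_size = 2                       # = window_size // 2 for M = 4
x = torch.randn(B, H, W, C)

# cyclic shift towards the top-left before SW-MSA
shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))

# ... window partition + masked attention + window reverse happen here ...

# reverse cyclic shift to restore the original spatial arrangement
x_restored = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
assert torch.equal(x_restored, x)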

        After the cyclic shift, a single window may contain content from several different original windows. A masked MSA mechanism is therefore employed to limit the self-attention computation to within each sub-window, and the per-window attention results are finally mapped back with a reverse cyclic shift. For example, an illustration with 9 numbered sub-windows is shown below:

        Partitioning by sub-window would directly give the self-attention result for sub-window 5, but computing attention naively over the shifted window mixes the attention of sub-windows 5 / 6 / 4; similar mixing also occurs for sub-windows 5 / 8 / 2 and for sub-windows 9 / 7 / 3 / 1, vertically or horizontally. The masked MSA mechanism therefore first computes self-attention normally and then applies a mask that zeroes out the unwanted entries of the attention map, restricting the self-attention computation to within each sub-window.

         For example, sub-windows 6 and 4 together form a square region of 4 patches, as shown below, so a 4×4 attention map is computed.

        To avoid mixing the attention computations of different sub-windows, the appropriate attention map should look like the following:

         The corresponding mask is therefore:

         Similarly, sub-windows 9 / 7 / 3 / 1 together form a square region of 4 patches, as shown below:

         Again, the appropriate mask is as follows:
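        Below is a minimal sketch of how such a mask can be built, following the img_mask approach used in the official implementation (the grid size and window/shift sizes here are toy values):

import torch

# Toy settings: an 8x8 token grid, window size 4, shift size 2
H, W, window_size, shift_size = 8, 8, 4, 2

# Label each token with the index of the sub-window it belongs to
img_mask = torch.zeros(1, H, W, 1)
h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
w_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

# Partition the label map into windows and compare labels pairwise
mask_windows = img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
mask_windows = mask_windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)   # (num_windows, M*M, M*M)
# Entries connecting different sub-windows get a large negative value (masked out before softmax)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
print(attn_mask.shape)   # torch.Size([4, 16, 16])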

2.2.4 Relative Position Bias

        When computing self-attention, a relative position bias B ∈ R^{M²×M²} is added to each head in the similarity computation, as follows:
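        The similarity computation (omitted as an image above) is, as in the paper:

        Attention(Q, K, V) = SoftMax(QKᵀ / √d + B) V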

        where Q, K, V ∈ R^{M²×d} are the Query, Key, and Value matrices, d is the query/key dimension, and M² is the number of patches in a (local) window. Since the relative position along each axis lies in the range [−M+1, M−1], we parameterize a smaller-sized bias matrix B̂ ∈ R^{(2M−1)×(2M−1)}, and the values in B are taken from B̂.

        As the experiments in Table 4 show, using this relative position bias is significantly better than using no position bias or using absolute position embeddings. Further adding absolute position embeddings to the input drops performance slightly, so it is not adopted in our implementation.

        In addition, the relative position bias learned during pre-training can also be used to initialize a fine-tuned model with a different window size, via bicubic interpolation.
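        A minimal sketch of that idea (resize_rel_pos_bias_table is a hypothetical helper written for illustration, not part of the official repository): the (2M−1)×(2M−1) bias table is treated as a small per-head image and resized with bicubic interpolation.

import torch
import torch.nn.functional as F

def resize_rel_pos_bias_table(table: torch.Tensor, old_window: int, new_window: int) -> torch.Tensor:
    """table: ((2*old_window-1)**2, num_heads) -> ((2*new_window-1)**2, num_heads)."""
    num_heads = table.shape[1]
    old_size, new_size = 2 * old_window - 1, 2 * new_window - 1
    # (L, nH) -> (1, nH, old_size, old_size) so each head's table is treated as a tiny image
    bias = table.permute(1, 0).reshape(1, num_heads, old_size, old_size)
    bias = F.interpolate(bias, size=(new_size, new_size), mode='bicubic', align_corners=False)
    return bias.reshape(num_heads, new_size * new_size).permute(1, 0)

# e.g., adapt a table pre-trained with M = 7 to fine-tuning with M = 12
new_table = resize_rel_pos_bias_table(torch.randn(13 * 13, 4), old_window=7, new_window=12)
print(new_table.shape)   # torch.Size([529, 4])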


2.3 Architecture Variants

        Our base model, Swin-B, has model size and computational complexity similar to ViT-B/DeiT-B. We also introduce Swin-T, Swin-S, and Swin-L, whose model sizes and computational complexities are about 0.25×, 0.5×, and 2× those of Swin-B, respectively. Note that the complexities of Swin-T and Swin-S are similar to those of ResNet-50 (DeiT-S) and ResNet-101. The window size is set to M = 7 by default for every architecture. The query dimension of each head is d = 32 and the expansion ratio of each MLP is α = 4, for all experiments. The per-Stage layer configurations of the architectures are listed below:
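        The configuration image is not reproduced here; from the paper, the hyper-parameters of the variants are:

        Swin-T: C = 96,  layer numbers = {2, 2, 6, 2}
        Swin-S: C = 96,  layer numbers = {2, 2, 18, 2}
        Swin-B: C = 128, layer numbers = {2, 2, 18, 2}
        Swin-L: C = 192, layer numbers = {2, 2, 18, 2}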

        where C is the number of channels of the hidden layers in the first stage. The model size, theoretical computational complexity (FLOPs), and throughput of the model variants for ImageNet image classification are listed in Table 1.

Table 1: ImageNet image classification performance comparison

Table 7: Detailed architecture specifications


3. Source Code

        Code: GitHub - microsoft/Swin-Transformer: This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows".

3.1 Swin Transformer

        We start with the overall SwinTransformer module, so the big picture is clear before the individual pieces.

 
class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
          https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops


3.2 Patch Embedding

        Before the image can enter the Swin Transformer blocks, it must be split into patch tokens and projected into embedding vectors. More concretely, the raw input image is divided into patch tokens of size patch_size × patch_size, which are then projected and embedded. This projection can be implemented with a 2D convolution whose stride and kernel_size are both set to patch_size and whose number of output channels is set to embed_dim. Finally, the result is flattened and its dimensions are transposed.

 
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)  # patch projection / embedding
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        '''
        # With the default parameters    # input  (B, C, H, W) = (B, 3, 224, 224)
        x = self.proj(x)                 # output (B, 96, 224/4, 224/4) = (B, 96, 56, 56)
        x = torch.flatten(x, 2)          # flatten the H, W dims, output (B, 96, 56*56)
        x = torch.transpose(x, 1, 2)     # move the channel dim last, output (B, 56*56, 96)
        '''
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # shape = (B, P_h*P_w, C)
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops


3.3 Patch Merging

        Before each Stage (after the first), Patch Merging downsamples the resolution, and its linear layer then halves the concatenated channel count (4C → 2C), which builds the hierarchical design and reduces computation (similar to an inverse Pixel Shuffle / space-to-depth operation). Schematic and implementation:

 
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        # reshape
        x = x.view(B, H, W, C)

        # sample every other row/column (stride = 2) to downsample the resolution by 1/2
        x0 = x[:, 0::2, 0::2, :]  # shape = (B, H/2, W/2, C)
        x1 = x[:, 1::2, 0::2, :]  # shape = (B, H/2, W/2, C)
        x2 = x[:, 0::2, 1::2, :]  # shape = (B, H/2, W/2, C)
        x3 = x[:, 1::2, 1::2, :]  # shape = (B, H/2, W/2, C)

        # concatenate, quadrupling the channel dim
        x = torch.cat([x0, x1, x2, x3], -1)  # shape = (B, H/2, W/2, 4*C)
        x = x.view(B, -1, 4 * C)             # shape = (B, H*W/4, 4*C)

        # FC halves the channel dim (4*C -> 2*C)
        x = self.norm(x)
        x = self.reduction(x)                # shape = (B, H*W/4, 2*C)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops


3.4 Window Partition

        Reshape an input tensor of shape (B, H, W, C) into a window tensor of shape (B × H/M × W/M, M, M, C), where M is the window size. This yields N = B × H/M × W/M windows, each of shape (M, M, C). This function is used by Window Attention.

 
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


3.5 Window Reverse

        This is the inverse of window partitioning: reshape a window tensor of shape (B × H/M × W/M, M, M, C) back into a tensor of shape (B, H, W, C). This function is used by Window Attention.

 
def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


3.6 MLP

        A two-layer FC network with a GELU activation and Dropout.

 
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


3.7 Window Attention (W-MSA Module) ☆

        On one hand, computing self-attention within local windows rather than over the whole image reduces the computational complexity from quadratic to linear.

        On the other hand, adding a relative position bias B to the Query-Key similarity of the original attention improves performance.

        More specifically, the Query and Key are first multiplied to obtain the attention map, of shape (numWindows*B, num_heads, window_size*window_size, window_size*window_size). Within the attention map, taking a different pixel as the origin changes the relative position/coordinate of every other pixel accordingly.

        Since each non-overlapping local window contains N = M × M patch tokens, take window_size = M = 2 as an example: the relative position encodings with the top-left pixel and the top-right pixel as origin, respectively, are shown below (the coordinate axes follow matrix-index convention):

         Next, torch.arange is used to generate evenly spaced row and column indices, and torch.meshgrid is used to generate the grid coordinate indices.

        Still with window_size = M = 2 as the example, generate the grid coordinates:

 
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.meshgrid([coords_h, coords_w])  # -> 2*(wh, ww)
"""
(tensor([[0, 0],
         [1, 1]]),
 tensor([[0, 1],
         [0, 1]]))
"""

         Stack them and flatten into 2D vectors:

 
coords = torch.stack(coords)               # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
"""
tensor([[0, 0, 1, 1],
        [0, 1, 0, 1]])
"""

         Insert a new dimension at positions 1 and 2 respectively, and subtract using broadcasting, giving a tensor of shape (2, wh*ww, wh*ww):

 
relative_coords_first = coords_flatten[:, :, None]   # 2, wh*ww, 1
relative_coords_second = coords_flatten[:, None, :]  # 2, 1, wh*ww
relative_coords = relative_coords_first - relative_coords_second  # 2, wh*ww, wh*ww

        Since the indices obtained from the subtraction start from negative values, an offset is added so that they start from 0:

 
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1

        Next, the 2D coordinates need to be flattened into 1D offsets.

        For two different coordinates (x, y) such as (1, 2) and (2, 1) in row 0, simply summing the coordinates to get a 1D offset x+y would make their offsets relative to the origin collide (1+2 = 2+1 = 3):

As can be seen, the raw offsets x+y in row 0 are 2, 3, 3, 4: different positions end up with the same offset, which reduces the relative discriminability.

         To avoid this erroneous collision of equal offsets, the coordinates (more precisely, the x coordinates) additionally undergo a multiplicative transform (offset multiply) to increase discriminability:

relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1)  # multiply every x coordinate by (2 * 2 - 1) = 3

Applying the multiplicative transform to the x coordinate gives (x', y); recomputing the offsets x'+y then yields distinct values for the different coordinate positions.

         Finally, sum over the last dimension (x+y), flattening into a 1D coordinate (the relative position index), and register it as a non-learnable buffer relative_position_index. Its role is to look up the corresponding learnable relative position bias according to the final relative position index.

 
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

         The complete code is as follows:

 
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])  # coordinates along the window height
        coords_w = torch.arange(self.window_size[1])  # coordinates along the window width
        # coordinate grid of the local window
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        # relative positions
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # Query, Key, Value
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # scale the Query
        q = q * self.scale
        # Query * Key
        attn = (q @ k.transpose(-2, -1))  # @ denotes matrix multiplication

        # relative position bias B
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        # Attention Map = Softmax(Q * K / sqrt(d) + B)
        attn = attn + relative_position_bias.unsqueeze(0)

        # mask the local-window attention map + Softmax
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)  # final Attention Map
        else:
            attn = self.softmax(attn)  # final Attention Map

        attn = self.attn_drop(attn)

        # Attention Map * V
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # @ denotes matrix multiplication
        # linear projection FC
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        # used in the module's printed repr
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops

        A nice schematic of the overall flow:

Relative position encoding matrix: each column represents the position of one coordinate relative to ("as seen by") every coordinate.

        The same module, annotated with the main shape changes:

 
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # e.g. wh = ww = w = 4 in the walkthrough below
        self.num_heads = num_heads      # number of MHA heads
        head_dim = dim // num_heads     # dim is split evenly across the heads
        self.scale = qk_scale or head_dim ** -0.5  # scale denominator in MHA: a custom qk_scale, or sqrt(head_dim)

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # (2*wh-1 * 2*ww-1, num_heads)

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])  # wh
        coords_w = torch.arange(self.window_size[1])  # ww
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # (2, wh, ww)
        coords_flatten = torch.flatten(coords, 1)  # (2, wh*ww)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # (2, wh*ww, wh*ww)
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # (wh*ww, wh*ww, 2)
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # (wh*ww, wh*ww)
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        # here N = wh*ww = w*w = 16 in the walkthrough
        # num_windows = (H*W)//(wh*ww) = (H*W)//16
        # C = dim, the embedding dimension of the current stage
        # (num_windows*B, N, C) = (num_windows*B, wh*ww, C)
        B_, N, C = x.shape
        # (num_windows*B, N, 3, num_heads, C//num_heads) -> (3, num_windows*B, num_heads, wh*ww, C//num_heads)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # (num_windows*B, num_heads, wh*ww, C//num_heads)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # (num_windows*B, num_heads, wh*ww, C//num_heads)
        q = q * self.scale
        # (num_windows*B, num_heads, wh*ww, C//num_heads) * (num_windows*B, num_heads, C//num_heads, wh*ww) = (num_windows*B, num_heads, wh*ww, wh*ww)
        attn = (q @ k.transpose(-2, -1))

        # (wh*ww, wh*ww, num_heads)
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        # (num_heads, wh*ww, wh*ww)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        # (num_heads, wh*ww, wh*ww) -> (1, num_heads, wh*ww, wh*ww), broadcast to (num_windows*B, num_heads, wh*ww, wh*ww)
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        # (num_windows*B, num_heads, wh*ww, wh*ww)
        attn = self.attn_drop(attn)

        # (num_windows*B, num_heads, wh*ww, wh*ww) * (num_windows*B, num_heads, wh*ww, C//num_heads) = (num_windows*B, num_heads, wh*ww, C//num_heads)
        # (num_windows*B, num_heads, wh*ww, C//num_heads) -> (num_windows*B, wh*ww, num_heads, C//num_heads) -> (num_windows*B, wh*ww, C)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


        Walkthrough of the main shape changes:

 
import torch
import torch.nn as nn

# take a 4x4 window as the example
window_size = (4, 4)

coords_h = torch.arange(window_size[0])  # wh
coords_w = torch.arange(window_size[1])  # ww
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # (2, wh, ww)
coords, coords.shape

 
(tensor([[[0, 0, 0, 0],
          [1, 1, 1, 1],
          [2, 2, 2, 2],
          [3, 3, 3, 3]],

         [[0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3]]]),
 torch.Size([2, 4, 4]))

 
coords_flatten = torch.flatten(coords, 1)  # (2, wh*ww)
coords_flatten, coords_flatten.shape

 
(tensor([[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
         [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]),
 torch.Size([2, 16]))

 
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # (2, wh*ww, wh*ww)
relative_coords, relative_coords.shape

 
  1. (tensor([[[ 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2, -3, -3, -3, -3],

  2. [ 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2, -3, -3, -3, -3],

  3. [ 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2, -3, -3, -3, -3],

  4. [ 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2, -3, -3, -3, -3],

  5. [ 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2],

  6. [ 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2],

  7. [ 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2],

  8. [ 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2],

  9. [ 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1],

  10. [ 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1],

  11. [ 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1],

  12. [ 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1],

  13. [ 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0],

  14. [ 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0],

  15. [ 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0],

  16. [ 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0]],

  17. [[ 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3],

  18. [ 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2],

  19. [ 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1],

  20. [ 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0],

  21. [ 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3],

  22. [ 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2],

  23. [ 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1],

  24. [ 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0],

  25. [ 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3],

  26. [ 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2],

  27. [ 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1],

  28. [ 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0],

  29. [ 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3],

  30. [ 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2],

  31. [ 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1],

  32. [ 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0]]]),

  33. torch.Size([2, 16, 16]))

 
# display the row and column coordinates as (x, y) pairs
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # (wh*ww, wh*ww, 2)
relative_coords, relative_coords.shape

 
  1. (tensor([[[ 0, 0],

  2. [ 0, -1],

  3. [ 0, -2],

  4. [ 0, -3],

  5. [-1, 0],

  6. [-1, -1],

  7. [-1, -2],

  8. [-1, -3],

  9. [-2, 0],

  10. [-2, -1],

  11. [-2, -2],

  12. [-2, -3],

  13. [-3, 0],

  14. [-3, -1],

  15. [-3, -2],

  16. [-3, -3]],

  17. [[ 0, 1],

  18. [ 0, 0],

  19. [ 0, -1],

  20. [ 0, -2],

  21. [-1, 1],

  22. [-1, 0],

  23. [-1, -1],

  24. [-1, -2],

  25. [-2, 1],

  26. [-2, 0],

  27. [-2, -1],

  28. [-2, -2],

  29. [-3, 1],

  30. [-3, 0],

  31. [-3, -1],

  32. [-3, -2]],

  33. [[ 0, 2],

  34. [ 0, 1],

  35. [ 0, 0],

  36. [ 0, -1],

  37. [-1, 2],

  38. [-1, 1],

  39. [-1, 0],

  40. [-1, -1],

  41. [-2, 2],

  42. [-2, 1],

  43. [-2, 0],

  44. [-2, -1],

  45. [-3, 2],

  46. [-3, 1],

  47. [-3, 0],

  48. [-3, -1]],

  49. [[ 0, 3],

  50. [ 0, 2],

  51. [ 0, 1],

  52. [ 0, 0],

  53. [-1, 3],

  54. [-1, 2],

  55. [-1, 1],

  56. [-1, 0],

  57. [-2, 3],

  58. [-2, 2],

  59. [-2, 1],

  60. [-2, 0],

  61. [-3, 3],

  62. [-3, 2],

  63. [-3, 1],

  64. [-3, 0]],

  65. [[ 1, 0],

  66. [ 1, -1],

  67. [ 1, -2],

  68. [ 1, -3],

  69. [ 0, 0],

  70. [ 0, -1],

  71. [ 0, -2],

  72. [ 0, -3],

  73. [-1, 0],

  74. [-1, -1],

  75. [-1, -2],

  76. [-1, -3],

  77. [-2, 0],

  78. [-2, -1],

  79. [-2, -2],

  80. [-2, -3]],

  81. [[ 1, 1],

  82. [ 1, 0],

  83. [ 1, -1],

  84. [ 1, -2],

  85. [ 0, 1],

  86. [ 0, 0],

  87. [ 0, -1],

  88. [ 0, -2],

  89. [-1, 1],

  90. [-1, 0],

  91. [-1, -1],

  92. [-1, -2],

  93. [-2, 1],

  94. [-2, 0],

  95. [-2, -1],

  96. [-2, -2]],

  97. [[ 1, 2],

  98. [ 1, 1],

  99. [ 1, 0],

  100. [ 1, -1],

  101. [ 0, 2],

  102. [ 0, 1],

  103. [ 0, 0],

  104. [ 0, -1],

  105. [-1, 2],

  106. [-1, 1],

  107. [-1, 0],

  108. [-1, -1],

  109. [-2, 2],

  110. [-2, 1],

  111. [-2, 0],

  112. [-2, -1]],

  113. [[ 1, 3],

  114. [ 1, 2],

  115. [ 1, 1],

  116. [ 1, 0],

  117. [ 0, 3],

  118. [ 0, 2],

  119. [ 0, 1],

  120. [ 0, 0],

  121. [-1, 3],

  122. [-1, 2],

  123. [-1, 1],

  124. [-1, 0],

  125. [-2, 3],

  126. [-2, 2],

  127. [-2, 1],

  128. [-2, 0]],

  129. [[ 2, 0],

  130. [ 2, -1],

  131. [ 2, -2],

  132. [ 2, -3],

  133. [ 1, 0],

  134. [ 1, -1],

  135. [ 1, -2],

  136. [ 1, -3],

  137. [ 0, 0],

  138. [ 0, -1],

  139. [ 0, -2],

  140. [ 0, -3],

  141. [-1, 0],

  142. [-1, -1],

  143. [-1, -2],

  144. [-1, -3]],

  145. [[ 2, 1],

  146. [ 2, 0],

  147. [ 2, -1],

  148. [ 2, -2],

  149. [ 1, 1],

  150. [ 1, 0],

  151. [ 1, -1],

  152. [ 1, -2],

  153. [ 0, 1],

  154. [ 0, 0],

  155. [ 0, -1],

  156. [ 0, -2],

  157. [-1, 1],

  158. [-1, 0],

  159. [-1, -1],

  160. [-1, -2]],

  161. [[ 2, 2],

  162. [ 2, 1],

  163. [ 2, 0],

  164. [ 2, -1],

  165. [ 1, 2],

  166. [ 1, 1],

  167. [ 1, 0],

  168. [ 1, -1],

  169. [ 0, 2],

  170. [ 0, 1],

  171. [ 0, 0],

  172. [ 0, -1],

  173. [-1, 2],

  174. [-1, 1],

  175. [-1, 0],

  176. [-1, -1]],

  177. [[ 2, 3],

  178. [ 2, 2],

  179. [ 2, 1],

  180. [ 2, 0],

  181. [ 1, 3],

  182. [ 1, 2],

  183. [ 1, 1],

  184. [ 1, 0],

  185. [ 0, 3],

  186. [ 0, 2],

  187. [ 0, 1],

  188. [ 0, 0],

  189. [-1, 3],

  190. [-1, 2],

  191. [-1, 1],

  192. [-1, 0]],

  193. [[ 3, 0],

  194. [ 3, -1],

  195. [ 3, -2],

  196. [ 3, -3],

  197. [ 2, 0],

  198. [ 2, -1],

  199. [ 2, -2],

  200. [ 2, -3],

  201. [ 1, 0],

  202. [ 1, -1],

  203. [ 1, -2],

  204. [ 1, -3],

  205. [ 0, 0],

  206. [ 0, -1],

  207. [ 0, -2],

  208. [ 0, -3]],

  209. [[ 3, 1],

  210. [ 3, 0],

  211. [ 3, -1],

  212. [ 3, -2],

  213. [ 2, 1],

  214. [ 2, 0],

  215. [ 2, -1],

  216. [ 2, -2],

  217. [ 1, 1],

  218. [ 1, 0],

  219. [ 1, -1],

  220. [ 1, -2],

  221. [ 0, 1],

  222. [ 0, 0],

  223. [ 0, -1],

  224. [ 0, -2]],

  225. [[ 3, 2],

  226. [ 3, 1],

  227. [ 3, 0],

  228. [ 3, -1],

  229. [ 2, 2],

  230. [ 2, 1],

  231. [ 2, 0],

  232. [ 2, -1],

  233. [ 1, 2],

  234. [ 1, 1],

  235. [ 1, 0],

  236. [ 1, -1],

  237. [ 0, 2],

  238. [ 0, 1],

  239. [ 0, 0],

  240. [ 0, -1]],

  241. [[ 3, 3],

  242. [ 3, 2],

  243. [ 3, 1],

  244. [ 3, 0],

  245. [ 2, 3],

  246. [ 2, 2],

  247. [ 2, 1],

  248. [ 2, 0],

  249. [ 1, 3],

  250. [ 1, 2],

  251. [ 1, 1],

  252. [ 1, 0],

  253. [ 0, 3],

  254. [ 0, 2],

  255. [ 0, 1],

  256. [ 0, 0]]]),

  257. torch.Size([16, 16, 2]))

 
# additive offset for the first (x) coordinate (+= 3)
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords

 
  1. tensor([[[ 3, 0],

  2. [ 3, -1],

  3. [ 3, -2],

  4. [ 3, -3],

  5. [ 2, 0],

  6. [ 2, -1],

  7. [ 2, -2],

  8. [ 2, -3],

  9. [ 1, 0],

  10. [ 1, -1],

  11. [ 1, -2],

  12. [ 1, -3],

  13. [ 0, 0],

  14. [ 0, -1],

  15. [ 0, -2],

  16. [ 0, -3]],

  17. [[ 3, 1],

  18. [ 3, 0],

  19. [ 3, -1],

  20. [ 3, -2],

  21. [ 2, 1],

  22. [ 2, 0],

  23. [ 2, -1],

  24. [ 2, -2],

  25. [ 1, 1],

  26. [ 1, 0],

  27. [ 1, -1],

  28. [ 1, -2],

  29. [ 0, 1],

  30. [ 0, 0],

  31. [ 0, -1],

  32. [ 0, -2]],

  33. [[ 3, 2],

  34. [ 3, 1],

  35. [ 3, 0],

  36. [ 3, -1],

  37. [ 2, 2],

  38. [ 2, 1],

  39. [ 2, 0],

  40. [ 2, -1],

  41. [ 1, 2],

  42. [ 1, 1],

  43. [ 1, 0],

  44. [ 1, -1],

  45. [ 0, 2],

  46. [ 0, 1],

  47. [ 0, 0],

  48. [ 0, -1]],

  49. [[ 3, 3],

  50. [ 3, 2],

  51. [ 3, 1],

  52. [ 3, 0],

  53. [ 2, 3],

  54. [ 2, 2],

  55. [ 2, 1],

  56. [ 2, 0],

  57. [ 1, 3],

  58. [ 1, 2],

  59. [ 1, 1],

  60. [ 1, 0],

  61. [ 0, 3],

  62. [ 0, 2],

  63. [ 0, 1],

  64. [ 0, 0]],

  65. [[ 4, 0],

  66. [ 4, -1],

  67. [ 4, -2],

  68. [ 4, -3],

  69. [ 3, 0],

  70. [ 3, -1],

  71. [ 3, -2],

  72. [ 3, -3],

  73. [ 2, 0],

  74. [ 2, -1],

  75. [ 2, -2],

  76. [ 2, -3],

  77. [ 1, 0],

  78. [ 1, -1],

  79. [ 1, -2],

  80. [ 1, -3]],

  81. [[ 4, 1],

  82. [ 4, 0],

  83. [ 4, -1],

  84. [ 4, -2],

  85. [ 3, 1],

  86. [ 3, 0],

  87. [ 3, -1],

  88. [ 3, -2],

  89. [ 2, 1],

  90. [ 2, 0],

  91. [ 2, -1],

  92. [ 2, -2],

  93. [ 1, 1],

  94. [ 1, 0],

  95. [ 1, -1],

  96. [ 1, -2]],

  97. [[ 4, 2],

  98. [ 4, 1],

  99. [ 4, 0],

  100. [ 4, -1],

  101. [ 3, 2],

  102. [ 3, 1],

  103. [ 3, 0],

  104. [ 3, -1],

  105. [ 2, 2],

  106. [ 2, 1],

  107. [ 2, 0],

  108. [ 2, -1],

  109. [ 1, 2],

  110. [ 1, 1],

  111. [ 1, 0],

  112. [ 1, -1]],

  113. [[ 4, 3],

  114. [ 4, 2],

  115. [ 4, 1],

  116. [ 4, 0],

  117. [ 3, 3],

  118. [ 3, 2],

  119. [ 3, 1],

  120. [ 3, 0],

  121. [ 2, 3],

  122. [ 2, 2],

  123. [ 2, 1],

  124. [ 2, 0],

  125. [ 1, 3],

  126. [ 1, 2],

  127. [ 1, 1],

  128. [ 1, 0]],

  129. [[ 5, 0],

  130. [ 5, -1],

  131. [ 5, -2],

  132. [ 5, -3],

  133. [ 4, 0],

  134. [ 4, -1],

  135. [ 4, -2],

  136. [ 4, -3],

  137. [ 3, 0],

  138. [ 3, -1],

  139. [ 3, -2],

  140. [ 3, -3],

  141. [ 2, 0],

  142. [ 2, -1],

  143. [ 2, -2],

  144. [ 2, -3]],

  145. [[ 5, 1],

  146. [ 5, 0],

  147. [ 5, -1],

  148. [ 5, -2],

  149. [ 4, 1],

  150. [ 4, 0],

  151. [ 4, -1],

  152. [ 4, -2],

  153. [ 3, 1],

  154. [ 3, 0],

  155. [ 3, -1],

  156. [ 3, -2],

  157. [ 2, 1],

  158. [ 2, 0],

  159. [ 2, -1],

  160. [ 2, -2]],

  161. [[ 5, 2],

  162. [ 5, 1],

  163. [ 5, 0],

  164. [ 5, -1],

  165. [ 4, 2],

  166. [ 4, 1],

  167. [ 4, 0],

  168. [ 4, -1],

  169. [ 3, 2],

  170. [ 3, 1],

  171. [ 3, 0],

  172. [ 3, -1],

  173. [ 2, 2],

  174. [ 2, 1],

  175. [ 2, 0],

  176. [ 2, -1]],

  177. [[ 5, 3],

  178. [ 5, 2],

  179. [ 5, 1],

  180. [ 5, 0],

  181. [ 4, 3],

  182. [ 4, 2],

  183. [ 4, 1],

  184. [ 4, 0],

  185. [ 3, 3],

  186. [ 3, 2],

  187. [ 3, 1],

  188. [ 3, 0],

  189. [ 2, 3],

  190. [ 2, 2],

  191. [ 2, 1],

  192. [ 2, 0]],

  193. [[ 6, 0],

  194. [ 6, -1],

  195. [ 6, -2],

  196. [ 6, -3],

  197. [ 5, 0],

  198. [ 5, -1],

  199. [ 5, -2],

  200. [ 5, -3],

  201. [ 4, 0],

  202. [ 4, -1],

  203. [ 4, -2],

  204. [ 4, -3],

  205. [ 3, 0],

  206. [ 3, -1],

  207. [ 3, -2],

  208. [ 3, -3]],

  209. [[ 6, 1],

  210. [ 6, 0],

  211. [ 6, -1],

  212. [ 6, -2],

  213. [ 5, 1],

  214. [ 5, 0],

  215. [ 5, -1],

  216. [ 5, -2],

  217. [ 4, 1],

  218. [ 4, 0],

  219. [ 4, -1],

  220. [ 4, -2],

  221. [ 3, 1],

  222. [ 3, 0],

  223. [ 3, -1],

  224. [ 3, -2]],

  225. [[ 6, 2],

  226. [ 6, 1],

  227. [ 6, 0],

  228. [ 6, -1],

  229. [ 5, 2],

  230. [ 5, 1],

  231. [ 5, 0],

  232. [ 5, -1],

  233. [ 4, 2],

  234. [ 4, 1],

  235. [ 4, 0],

  236. [ 4, -1],

  237. [ 3, 2],

  238. [ 3, 1],

  239. [ 3, 0],

  240. [ 3, -1]],

  241. [[ 6, 3],

  242. [ 6, 2],

  243. [ 6, 1],

  244. [ 6, 0],

  245. [ 5, 3],

  246. [ 5, 2],

  247. [ 5, 1],

  248. [ 5, 0],

  249. [ 4, 3],

  250. [ 4, 2],

  251. [ 4, 1],

  252. [ 4, 0],

  253. [ 3, 3],

  254. [ 3, 2],

  255. [ 3, 1],

  256. [ 3, 0]]])

 
# additive offset for the second (y) coordinate (+= 3)
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords

 
  1. tensor([[[3, 3],

  2. [3, 2],

  3. [3, 1],

  4. [3, 0],

  5. [2, 3],

  6. [2, 2],

  7. [2, 1],

  8. [2, 0],

  9. [1, 3],

  10. [1, 2],

  11. [1, 1],

  12. [1, 0],

  13. [0, 3],

  14. [0, 2],

  15. [0, 1],

  16. [0, 0]],

  17. [[3, 4],

  18. [3, 3],

  19. [3, 2],

  20. [3, 1],

  21. [2, 4],

  22. [2, 3],

  23. [2, 2],

  24. [2, 1],

  25. [1, 4],

  26. [1, 3],

  27. [1, 2],

  28. [1, 1],

  29. [0, 4],

  30. [0, 3],

  31. [0, 2],

  32. [0, 1]],

  33. [[3, 5],

  34. [3, 4],

  35. [3, 3],

  36. [3, 2],

  37. [2, 5],

  38. [2, 4],

  39. [2, 3],

  40. [2, 2],

  41. [1, 5],

  42. [1, 4],

  43. [1, 3],

  44. [1, 2],

  45. [0, 5],

  46. [0, 4],

  47. [0, 3],

  48. [0, 2]],

  49. [[3, 6],

  50. [3, 5],

  51. [3, 4],

  52. [3, 3],

  53. [2, 6],

  54. [2, 5],

  55. [2, 4],

  56. [2, 3],

  57. [1, 6],

  58. [1, 5],

  59. [1, 4],

  60. [1, 3],

  61. [0, 6],

  62. [0, 5],

  63. [0, 4],

  64. [0, 3]],

  65. [[4, 3],

  66. [4, 2],

  67. [4, 1],

  68. [4, 0],

  69. [3, 3],

  70. [3, 2],

  71. [3, 1],

  72. [3, 0],

  73. [2, 3],

  74. [2, 2],

  75. [2, 1],

  76. [2, 0],

  77. [1, 3],

  78. [1, 2],

  79. [1, 1],

  80. [1, 0]],

  81. [[4, 4],

  82. [4, 3],

  83. [4, 2],

  84. [4, 1],

  85. [3, 4],

  86. [3, 3],

  87. [3, 2],

  88. [3, 1],

  89. [2, 4],

  90. [2, 3],

  91. [2, 2],

  92. [2, 1],

  93. [1, 4],

  94. [1, 3],

  95. [1, 2],

  96. [1, 1]],

  97. [[4, 5],

  98. [4, 4],

  99. [4, 3],

  100. [4, 2],

  101. [3, 5],

  102. [3, 4],

  103. [3, 3],

  104. [3, 2],

  105. [2, 5],

  106. [2, 4],

  107. [2, 3],

  108. [2, 2],

  109. [1, 5],

  110. [1, 4],

  111. [1, 3],

  112. [1, 2]],

  113. [[4, 6],

  114. [4, 5],

  115. [4, 4],

  116. [4, 3],

  117. [3, 6],

  118. [3, 5],

  119. [3, 4],

  120. [3, 3],

  121. [2, 6],

  122. [2, 5],

  123. [2, 4],

  124. [2, 3],

  125. [1, 6],

  126. [1, 5],

  127. [1, 4],

  128. [1, 3]],

  129. [[5, 3],

  130. [5, 2],

  131. [5, 1],

  132. [5, 0],

  133. [4, 3],

  134. [4, 2],

  135. [4, 1],

  136. [4, 0],

  137. [3, 3],

  138. [3, 2],

  139. [3, 1],

  140. [3, 0],

  141. [2, 3],

  142. [2, 2],

  143. [2, 1],

  144. [2, 0]],

  145. [[5, 4],

  146. [5, 3],

  147. [5, 2],

  148. [5, 1],

  149. [4, 4],

  150. [4, 3],

  151. [4, 2],

  152. [4, 1],

  153. [3, 4],

  154. [3, 3],

  155. [3, 2],

  156. [3, 1],

  157. [2, 4],

  158. [2, 3],

  159. [2, 2],

  160. [2, 1]],

  161. [[5, 5],

  162. [5, 4],

  163. [5, 3],

  164. [5, 2],

  165. [4, 5],

  166. [4, 4],

  167. [4, 3],

  168. [4, 2],

  169. [3, 5],

  170. [3, 4],

  171. [3, 3],

  172. [3, 2],

  173. [2, 5],

  174. [2, 4],

  175. [2, 3],

  176. [2, 2]],

  177. [[5, 6],

  178. [5, 5],

  179. [5, 4],

  180. [5, 3],

  181. [4, 6],

  182. [4, 5],

  183. [4, 4],

  184. [4, 3],

  185. [3, 6],

  186. [3, 5],

  187. [3, 4],

  188. [3, 3],

  189. [2, 6],

  190. [2, 5],

  191. [2, 4],

  192. [2, 3]],

  193. [[6, 3],

  194. [6, 2],

  195. [6, 1],

  196. [6, 0],

  197. [5, 3],

  198. [5, 2],

  199. [5, 1],

  200. [5, 0],

  201. [4, 3],

  202. [4, 2],

  203. [4, 1],

  204. [4, 0],

  205. [3, 3],

  206. [3, 2],

  207. [3, 1],

  208. [3, 0]],

  209. [[6, 4],

  210. [6, 3],

  211. [6, 2],

  212. [6, 1],

  213. [5, 4],

  214. [5, 3],

  215. [5, 2],

  216. [5, 1],

  217. [4, 4],

  218. [4, 3],

  219. [4, 2],

  220. [4, 1],

  221. [3, 4],

  222. [3, 3],

  223. [3, 2],

  224. [3, 1]],

  225. [[6, 5],

  226. [6, 4],

  227. [6, 3],

  228. [6, 2],

  229. [5, 5],

  230. [5, 4],

  231. [5, 3],

  232. [5, 2],

  233. [4, 5],

  234. [4, 4],

  235. [4, 3],

  236. [4, 2],

  237. [3, 5],

  238. [3, 4],

  239. [3, 3],

  240. [3, 2]],

  241. [[6, 6],

  242. [6, 5],

  243. [6, 4],

  244. [6, 3],

  245. [5, 6],

  246. [5, 5],

  247. [5, 4],

  248. [5, 3],

  249. [4, 6],

  250. [4, 5],

  251. [4, 4],

  252. [4, 3],

  253. [3, 6],

  254. [3, 5],

  255. [3, 4],

  256. [3, 3]]])

 
# multiplicative transform of the first (x) coordinate (*= 7)
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_coords

 
  1. tensor([[[21, 3],

  2. [21, 2],

  3. [21, 1],

  4. [21, 0],

  5. [14, 3],

  6. [14, 2],

  7. [14, 1],

  8. [14, 0],

  9. [ 7, 3],

  10. [ 7, 2],

  11. [ 7, 1],

  12. [ 7, 0],

  13. [ 0, 3],

  14. [ 0, 2],

  15. [ 0, 1],

  16. [ 0, 0]],

  17. [[21, 4],

  18. [21, 3],

  19. [21, 2],

  20. [21, 1],

  21. [14, 4],

  22. [14, 3],

  23. [14, 2],

  24. [14, 1],

  25. [ 7, 4],

  26. [ 7, 3],

  27. [ 7, 2],

  28. [ 7, 1],

  29. [ 0, 4],

  30. [ 0, 3],

  31. [ 0, 2],

  32. [ 0, 1]],

  33. [[21, 5],

  34. [21, 4],

  35. [21, 3],

  36. [21, 2],

  37. [14, 5],

  38. [14, 4],

  39. [14, 3],

  40. [14, 2],

  41. [ 7, 5],

  42. [ 7, 4],

  43. [ 7, 3],

  44. [ 7, 2],

  45. [ 0, 5],

  46. [ 0, 4],

  47. [ 0, 3],

  48. [ 0, 2]],

  49. [[21, 6],

  50. [21, 5],

  51. [21, 4],

  52. [21, 3],

  53. [14, 6],

  54. [14, 5],

  55. [14, 4],

  56. [14, 3],

  57. [ 7, 6],

  58. [ 7, 5],

  59. [ 7, 4],

  60. [ 7, 3],

  61. [ 0, 6],

  62. [ 0, 5],

  63. [ 0, 4],

  64. [ 0, 3]],

  65. [[28, 3],

  66. [28, 2],

  67. [28, 1],

  68. [28, 0],

  69. [21, 3],

  70. [21, 2],

  71. [21, 1],

  72. [21, 0],

  73. [14, 3],

  74. [14, 2],

  75. [14, 1],

  76. [14, 0],

  77. [ 7, 3],

  78. [ 7, 2],

  79. [ 7, 1],

  80. [ 7, 0]],

  81. [[28, 4],

  82. [28, 3],

  83. [28, 2],

  84. [28, 1],

  85. [21, 4],

  86. [21, 3],

  87. [21, 2],

  88. [21, 1],

  89. [14, 4],

  90. [14, 3],

  91. [14, 2],

  92. [14, 1],

  93. [ 7, 4],

  94. [ 7, 3],

  95. [ 7, 2],

  96. [ 7, 1]],

  97. [[28, 5],

  98. [28, 4],

  99. [28, 3],

  100. [28, 2],

  101. [21, 5],

  102. [21, 4],

  103. [21, 3],

  104. [21, 2],

  105. [14, 5],

  106. [14, 4],

  107. [14, 3],

  108. [14, 2],

  109. [ 7, 5],

  110. [ 7, 4],

  111. [ 7, 3],

  112. [ 7, 2]],

  113. [[28, 6],

  114. [28, 5],

  115. [28, 4],

  116. [28, 3],

  117. [21, 6],

  118. [21, 5],

  119. [21, 4],

  120. [21, 3],

  121. [14, 6],

  122. [14, 5],

  123. [14, 4],

  124. [14, 3],

  125. [ 7, 6],

  126. [ 7, 5],

  127. [ 7, 4],

  128. [ 7, 3]],

  129. [[35, 3],

  130. [35, 2],

  131. [35, 1],

  132. [35, 0],

  133. [28, 3],

  134. [28, 2],

  135. [28, 1],

  136. [28, 0],

  137. [21, 3],

  138. [21, 2],

  139. [21, 1],

  140. [21, 0],

  141. [14, 3],

  142. [14, 2],

  143. [14, 1],

  144. [14, 0]],

  145. [[35, 4],

  146. [35, 3],

  147. [35, 2],

  148. [35, 1],

  149. [28, 4],

  150. [28, 3],

  151. [28, 2],

  152. [28, 1],

  153. [21, 4],

  154. [21, 3],

  155. [21, 2],

  156. [21, 1],

  157. [14, 4],

  158. [14, 3],

  159. [14, 2],

  160. [14, 1]],

  161. [[35, 5],

  162. [35, 4],

  163. [35, 3],

  164. [35, 2],

  165. [28, 5],

  166. [28, 4],

  167. [28, 3],

  168. [28, 2],

  169. [21, 5],

  170. [21, 4],

  171. [21, 3],

  172. [21, 2],

  173. [14, 5],

  174. [14, 4],

  175. [14, 3],

  176. [14, 2]],

  177. [[35, 6],

  178. [35, 5],

  179. [35, 4],

  180. [35, 3],

  181. [28, 6],

  182. [28, 5],

  183. [28, 4],

  184. [28, 3],

  185. [21, 6],

  186. [21, 5],

  187. [21, 4],

  188. [21, 3],

  189. [14, 6],

  190. [14, 5],

  191. [14, 4],

  192. [14, 3]],

  193. [[42, 3],

  194. [42, 2],

  195. [42, 1],

  196. [42, 0],

  197. [35, 3],

  198. [35, 2],

  199. [35, 1],

  200. [35, 0],

  201. [28, 3],

  202. [28, 2],

  203. [28, 1],

  204. [28, 0],

  205. [21, 3],

  206. [21, 2],

  207. [21, 1],

  208. [21, 0]],

  209. [[42, 4],

  210. [42, 3],

  211. [42, 2],

  212. [42, 1],

  213. [35, 4],

  214. [35, 3],

  215. [35, 2],

  216. [35, 1],

  217. [28, 4],

  218. [28, 3],

  219. [28, 2],

  220. [28, 1],

  221. [21, 4],

  222. [21, 3],

  223. [21, 2],

  224. [21, 1]],

  225. [[42, 5],

  226. [42, 4],

  227. [42, 3],

  228. [42, 2],

  229. [35, 5],

  230. [35, 4],

  231. [35, 3],

  232. [35, 2],

  233. [28, 5],

  234. [28, 4],

  235. [28, 3],

  236. [28, 2],

  237. [21, 5],

  238. [21, 4],

  239. [21, 3],

  240. [21, 2]],

  241. [[42, 6],

  242. [42, 5],

  243. [42, 4],

  244. [42, 3],

  245. [35, 6],

  246. [35, 5],

  247. [35, 4],

  248. [35, 3],

  249. [28, 6],

  250. [28, 5],

  251. [28, 4],

  252. [28, 3],

  253. [21, 6],

  254. [21, 5],

  255. [21, 4],

  256. [21, 3]]])

 
# compute the 1D offsets (x+y)
relative_position_index = relative_coords.sum(-1)  # (wh*ww, wh*ww)
relative_position_index, relative_position_index.shape
# note how the offset values spread out perpendicular to the main diagonal
# the 16 columns correspond one-to-one to the 4x4 coordinate positions

 
  1. (tensor([[24, 23, 22, 21, 17, 16, 15, 14, 10, 9, 8, 7, 3, 2, 1, 0],

  2. [25, 24, 23, 22, 18, 17, 16, 15, 11, 10, 9, 8, 4, 3, 2, 1],

  3. [26, 25, 24, 23, 19, 18, 17, 16, 12, 11, 10, 9, 5, 4, 3, 2],

  4. [27, 26, 25, 24, 20, 19, 18, 17, 13, 12, 11, 10, 6, 5, 4, 3],

  5. [31, 30, 29, 28, 24, 23, 22, 21, 17, 16, 15, 14, 10, 9, 8, 7],

  6. [32, 31, 30, 29, 25, 24, 23, 22, 18, 17, 16, 15, 11, 10, 9, 8],

  7. [33, 32, 31, 30, 26, 25, 24, 23, 19, 18, 17, 16, 12, 11, 10, 9],

  8. [34, 33, 32, 31, 27, 26, 25, 24, 20, 19, 18, 17, 13, 12, 11, 10],

  9. [38, 37, 36, 35, 31, 30, 29, 28, 24, 23, 22, 21, 17, 16, 15, 14],

  10. [39, 38, 37, 36, 32, 31, 30, 29, 25, 24, 23, 22, 18, 17, 16, 15],

  11. [40, 39, 38, 37, 33, 32, 31, 30, 26, 25, 24, 23, 19, 18, 17, 16],

  12. [41, 40, 39, 38, 34, 33, 32, 31, 27, 26, 25, 24, 20, 19, 18, 17],

  13. [45, 44, 43, 42, 38, 37, 36, 35, 31, 30, 29, 28, 24, 23, 22, 21],

  14. [46, 45, 44, 43, 39, 38, 37, 36, 32, 31, 30, 29, 25, 24, 23, 22],

  15. [47, 46, 45, 44, 40, 39, 38, 37, 33, 32, 31, 30, 26, 25, 24, 23],

  16. [48, 47, 46, 45, 41, 40, 39, 38, 34, 33, 32, 31, 27, 26, 25, 24]]),

  17. torch.Size([16, 16]))

 
# suppose MHA uses 3 heads
num_heads = 3
relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
relative_position_bias_table, relative_position_bias_table.shape

 
  1. (Parameter containing:

  2. tensor([[0., 0., 0.],

  3. [0., 0., 0.],

  4. [0., 0., 0.],

  5. [0., 0., 0.],

  6. [0., 0., 0.],

  7. [0., 0., 0.],

  8. [0., 0., 0.],

  9. [0., 0., 0.],

  10. [0., 0., 0.],

  11. [0., 0., 0.],

  12. [0., 0., 0.],

  13. [0., 0., 0.],

  14. [0., 0., 0.],

  15. [0., 0., 0.],

  16. [0., 0., 0.],

  17. [0., 0., 0.],

  18. [0., 0., 0.],

  19. [0., 0., 0.],

  20. [0., 0., 0.],

  21. [0., 0., 0.],

  22. [0., 0., 0.],

  23. [0., 0., 0.],

  24. [0., 0., 0.],

  25. [0., 0., 0.],

  26. [0., 0., 0.],

  27. [0., 0., 0.],

  28. [0., 0., 0.],

  29. [0., 0., 0.],

  30. [0., 0., 0.],

  31. [0., 0., 0.],

  32. [0., 0., 0.],

  33. [0., 0., 0.],

  34. [0., 0., 0.],

  35. [0., 0., 0.],

  36. [0., 0., 0.],

  37. [0., 0., 0.],

  38. [0., 0., 0.],

  39. [0., 0., 0.],

  40. [0., 0., 0.],

  41. [0., 0., 0.],

  42. [0., 0., 0.],

  43. [0., 0., 0.],

  44. [0., 0., 0.],

  45. [0., 0., 0.],

  46. [0., 0., 0.],

  47. [0., 0., 0.],

  48. [0., 0., 0.],

  49. [0., 0., 0.],

  50. [0., 0., 0.]], requires_grad=True),

  51. torch.Size([49, 3]))

 
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)].view(
    window_size[0] * window_size[1], window_size[0] * window_size[1], -1)
relative_position_bias, relative_position_bias.shape

 
(tensor([[[0., 0., 0.],
          [0., 0., 0.],
          ...
          [0., 0., 0.]],

         ...

         [[0., 0., 0.],
          [0., 0., 0.],
          ...
          [0., 0., 0.]]], grad_fn=<ViewBackward>),
 torch.Size([16, 16, 3]))

 
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
relative_position_bias, relative_position_bias.shape

 
(tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          ...
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         ...

         [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          ...
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
        grad_fn=<CopyBackwards>),
 torch.Size([3, 16, 16]))

        The above essentially walks through how the relative position index and the relative position bias are generated.
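        To connect this to the attention computation: in WindowAttention.forward (Section 3.7) the per-head bias is simply added to the scaled query-key scores before the softmax. A minimal sketch (shapes follow the toy example above; q and k are assumed to already be split per head):

import torch

num_heads, N, head_dim = 3, 16, 8            # N = window_size[0] * window_size[1]
q = torch.rand(1, num_heads, N, head_dim)    # (B*nW, heads, N, head_dim)
k = torch.rand(1, num_heads, N, head_dim)

attn = (q * head_dim ** -0.5) @ k.transpose(-2, -1)   # (B*nW, heads, N, N)
# relative_position_bias has shape (heads, N, N) after the permute above
attn = attn + relative_position_bias.unsqueeze(0)     # broadcast over the batch of windows
attn = attn.softmax(dim=-1)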


3.8 Swin Transformer Block 

3.8.1 Shift Window Attention

        The basic window attention (W-MSA) is computed independently inside each window. To let windows exchange information with one another, Swin Transformer additionally introduces the shifted window operation.

        On the left is the non-overlapping basic window partition (W-MSA); on the right, the partition after shifting the windows along the diagonal (SW-MSA). The shifted windows now contain elements from what were previously adjacent windows, but this also increases the number of windows: the 4 windows become 9.

        In the implementation, shifted window attention (SW-MSA) is realized indirectly by cyclically shifting the feature map and applying a mask to the attention, so that the window count stays the same as before while the final result remains equivalent.

         In the code, the feature map is shifted with torch.roll, as sketched below.

         To perform the reverse cyclic shift, simply set the shifts argument to the corresponding positive values.
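        A minimal sketch of the cyclic shift and its reverse on a toy tensor (shift_size = window_size // 2 = 2, matching SwinTransformerBlock.forward below):

import torch

x = torch.arange(16.).view(1, 4, 4, 1)   # (B, H, W, C) toy feature map
shift_size = 2

# cyclic shift: the top-left shift_size rows/cols wrap around to the bottom-right
shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))

# reverse cyclic shift: positive shifts restore the original layout
restored = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
assert torch.equal(restored, x)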

3.8.2 Attention Mask

        By setting the mask appropriately, shifted window attention (SW-MSA) can produce a result equivalent to the shifted partition while using the same number of windows as basic window attention (W-MSA).

        First, every region produced by the shifted partition is assigned an index, and the roll operation is performed (here with window_size=2, shift_size=1), as shown below:

        When computing the attention map, we want to keep only the entries where the Query and the Key carry the same index, and to suppress the entries where the Query and Key indices differ, as shown below:

        Note that Value has the same shape as Query (4×1) in this toy setting, so when the (4×4) QK product, and hence the masked attention map, is multiplied by Value, the results still land in the correct positions: (4×4) · (4×1) = (4×1).

        To obtain the correct result while still computing over the original four windows, a mask must therefore be added to the attention map (the grey patches in the figure above); the relevant code is as follows:

 
if self.shift_size > 0:
    # calculate attention mask for SW-MSA
    H, W = self.input_resolution
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        With the settings from the figure above, this code produces the following mask:

 
tensor([[[[[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]]],

         [[[   0., -100.,    0., -100.],
           [-100.,    0., -100.,    0.],
           [   0., -100.,    0., -100.],
           [-100.,    0., -100.,    0.]]],

         [[[   0.,    0., -100., -100.],
           [   0.,    0., -100., -100.],
           [-100., -100.,    0.,    0.],
           [-100., -100.,    0.,    0.]]],

         [[[   0., -100., -100., -100.],
           [-100.,    0., -100., -100.],
           [-100., -100.,    0., -100.],
           [-100., -100., -100.,    0.]]]]])

        The forward pass of the Window Attention module contains the following snippet:

 
if mask is not None:
    nW = mask.shape[0]
    attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
    attn = attn.view(-1, self.num_heads, N, N)
    attn = self.softmax(attn)

        It adds the mask, whose blocked entries are set to -100, directly onto the attention map; after the reshape, the softmax then drives those entries to (approximately) zero, effectively ignoring them.
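        As a quick numerical check (a small sketch, not from the original code), adding -100 to a logit pushes its softmax weight down to essentially zero:

import torch

logits = torch.tensor([2.0, 1.0, 2.0 - 100.0, 1.0 - 100.0])  # last two entries "masked" with -100
print(torch.softmax(logits, dim=-1))
# the two masked entries come out ~0 (on the order of 1e-44), the others keep their relative weights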

        Finally, here are the schematic of the Swin Transformer Block and its code:
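        Written as equations (following the paper), two consecutive blocks apply W-MSA and SW-MSA respectively, each followed by a 2-layer MLP, with LayerNorm before every module and a residual connection after it:

$\hat{z}^{l} = \text{W-MSA}(\text{LN}(z^{l-1})) + z^{l-1}, \quad z^{l} = \text{MLP}(\text{LN}(\hat{z}^{l})) + \hat{z}^{l}$
$\hat{z}^{l+1} = \text{SW-MSA}(\text{LN}(z^{l})) + z^{l}, \quad z^{l+1} = \text{MLP}(\text{LN}(\hat{z}^{l+1})) + \hat{z}^{l+1}$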

 
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        ##################### mask for cyclically-shifted local window self-attention #####################
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
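        As a small smoke test of the block end to end (the hyperparameters below are illustrative, not from the original post; window_size=4 on an 8×8 resolution matches the toy example that follows, and the classes from the previous subsections are assumed to be in scope):

import torch

# a hypothetical SW-MSA block: 8x8 token grid, 96 channels, 3 heads, shifted by window_size // 2
block = SwinTransformerBlock(dim=96, input_resolution=(8, 8), num_heads=3,
                             window_size=4, shift_size=2)
x = torch.rand(2, 8 * 8, 96)   # (B, L, C) with L = H * W
out = block(x)
print(out.shape)               # torch.Size([2, 64, 96]) -- same shape as the input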

         A step-by-step simulated example:

 
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size - here typically wh = ww = w = 4
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    # (B, H, W, C) -> (B, H//wh, wh, W//ww, ww, C) -> (B, H//wh, W//ww, wh, ww, C) -> ((B*H*W)//(wh*ww), wh, ww, C)
    B, H, W, C = x.shape
    # print(x.shape)
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # print(x.shape)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows  # ((B*H*W)//(wh*ww), wh, ww, C) = (N, wh, ww, C)


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size - here typically wh = ww = w = 4
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    # ((B*H*W)//(wh*ww), w, w, C) -> (B, H//wh, W//ww, wh, ww, C) -> (B, H//wh, wh, W//ww, ww, C) -> (B, H, W, C)
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

 
batch_size = 1
num_channel = 3
input_resolution = (8, 8)
image = torch.rand(batch_size, num_channel, input_resolution[0], input_resolution[1])
image.shape

torch.Size([1, 3, 8, 8])
 
# the side lengths of the local window regions are made up entirely of the two lengths window_size and shift_size
window_size = 4
shift_size = window_size // 2  # 2

# calculate attention mask for SW-MSA
H, W = input_resolution
img_mask = torch.zeros((1, H, W, 1))  # (1, H, W, 1)

# local window index ranges
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))

# label each local window region with idx (cnt in the original code)
idx = 0
for h in h_slices:
    for w in w_slices:
        print(h, w, idx)
        img_mask[:, h, w, :] = idx
        idx += 1
print(f"num local windows: {idx}")
img_mask.shape, img_mask[0, ..., 0]

 
slice(0, -4, None) slice(0, -4, None) 0
slice(0, -4, None) slice(-4, -2, None) 1
slice(0, -4, None) slice(-2, None, None) 2
slice(-4, -2, None) slice(0, -4, None) 3
slice(-4, -2, None) slice(-4, -2, None) 4
slice(-4, -2, None) slice(-2, None, None) 5
slice(-2, None, None) slice(0, -4, None) 6
slice(-2, None, None) slice(-4, -2, None) 7
slice(-2, None, None) slice(-2, None, None) 8
num local windows: 9
(torch.Size([1, 8, 8, 1]),
 tensor([[0., 0., 0., 0., 1., 1., 2., 2.],
         [0., 0., 0., 0., 1., 1., 2., 2.],
         [0., 0., 0., 0., 1., 1., 2., 2.],
         [0., 0., 0., 0., 1., 1., 2., 2.],
         [3., 3., 3., 3., 4., 4., 5., 5.],
         [3., 3., 3., 3., 4., 4., 5., 5.],
         [6., 6., 6., 6., 7., 7., 8., 8.],
         [6., 6., 6., 6., 7., 7., 8., 8.]]))

 
# (1, 8, 8, 1) = (1, H, W, 1) = (B, H, W, C) -> ((B*H*W)//(wh*ww), wh, ww, C) = (N, wh, ww, C) = (4, 4, 4, 1)
mask_windows = window_partition(img_mask, window_size)
mask_windows.shape, mask_windows[..., 0]

 
(torch.Size([4, 4, 4, 1]),
 tensor([[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[1., 1., 2., 2.],
          [1., 1., 2., 2.],
          [1., 1., 2., 2.],
          [1., 1., 2., 2.]],

         [[3., 3., 3., 3.],
          [3., 3., 3., 3.],
          [6., 6., 6., 6.],
          [6., 6., 6., 6.]],

         [[4., 4., 5., 5.],
          [4., 4., 5., 5.],
          [7., 7., 8., 8.],
          [7., 7., 8., 8.]]]))

 
mask_windows = mask_windows.view(-1, window_size * window_size)
mask_windows.shape, mask_windows

 
(torch.Size([4, 16]),
 tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 2.],
         [3., 3., 3., 3., 3., 3., 3., 3., 6., 6., 6., 6., 6., 6., 6., 6.],
         [4., 4., 5., 5., 4., 4., 5., 5., 7., 7., 8., 8., 7., 7., 8., 8.]]))

 
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask.shape, attn_mask[..., 0]

 
  1. (torch.Size([4, 4, 4, 4, 1]),

  2. tensor([[[[ 0., 0., 0., 0.],

  3. [ 0., 0., 0., 0.],

  4. [ 0., 0., 0., 0.],

  5. [ 0., 0., 0., 0.]],

  6. [[ 0., 0., 0., 0.],

  7. [ 0., 0., 0., 0.],

  8. [ 0., 0., 0., 0.],

  9. [ 0., 0., 0., 0.]],

  10. [[ 0., 0., 0., 0.],

  11. [ 0., 0., 0., 0.],

  12. [ 0., 0., 0., 0.],

  13. [ 0., 0., 0., 0.]],

  14. [[ 0., 0., 0., 0.],

  15. [ 0., 0., 0., 0.],

  16. [ 0., 0., 0., 0.],

  17. [ 0., 0., 0., 0.]]],

  18. [[[ 0., 0., 0., 0.],

  19. [ 0., 0., 0., 0.],

  20. [ 0., 0., 0., 0.],

  21. [ 0., 0., 0., 0.]],

  22. [[ 0., 0., 0., 0.],

  23. [ 0., 0., 0., 0.],

  24. [ 0., 0., 0., 0.],

  25. [ 0., 0., 0., 0.]],

  26. [[ 0., 0., 0., 0.],

  27. [ 0., 0., 0., 0.],

  28. [ 0., 0., 0., 0.],

  29. [ 0., 0., 0., 0.]],

  30. [[ 0., 0., 0., 0.],

  31. [ 0., 0., 0., 0.],

  32. [ 0., 0., 0., 0.],

  33. [ 0., 0., 0., 0.]]],

  34. [[[ 0., 0., 0., 0.],

  35. [ 0., 0., 0., 0.],

  36. [ 3., 3., 3., 3.],

  37. [ 3., 3., 3., 3.]],

  38. [[ 0., 0., 0., 0.],

  39. [ 0., 0., 0., 0.],

  40. [ 3., 3., 3., 3.],

  41. [ 3., 3., 3., 3.]],

  42. [[-3., -3., -3., -3.],

  43. [-3., -3., -3., -3.],

  44. [ 0., 0., 0., 0.],

  45. [ 0., 0., 0., 0.]],

  46. [[-3., -3., -3., -3.],

  47. [-3., -3., -3., -3.],

  48. [ 0., 0., 0., 0.],

  49. [ 0., 0., 0., 0.]]],

  50. [[[ 0., 0., 0., 0.],

  51. [ 0., 0., 0., 0.],

  52. [ 3., 3., 3., 3.],

  53. [ 3., 3., 3., 3.]],

  54. [[ 0., 0., 0., 0.],

  55. [ 0., 0., 0., 0.],

  56. [ 3., 3., 3., 3.],

  57. [ 3., 3., 3., 3.]],

  58. [[-3., -3., -3., -3.],

  59. [-3., -3., -3., -3.],

  60. [ 0., 0., 0., 0.],

  61. [ 0., 0., 0., 0.]],

  62. [[-3., -3., -3., -3.],

  63. [-3., -3., -3., -3.],

  64. [ 0., 0., 0., 0.],

  65. [ 0., 0., 0., 0.]]]]))

 
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn_mask.shape, attn_mask[..., 0]

 
  1. (torch.Size([4, 4, 4, 4, 1]),

  2. tensor([[[[ 0., 0., 0., 0.],

  3. [ 0., 0., 0., 0.],

  4. [ 0., 0., 0., 0.],

  5. [ 0., 0., 0., 0.]],

  6. [[ 0., 0., 0., 0.],

  7. [ 0., 0., 0., 0.],

  8. [ 0., 0., 0., 0.],

  9. [ 0., 0., 0., 0.]],

  10. [[ 0., 0., 0., 0.],

  11. [ 0., 0., 0., 0.],

  12. [ 0., 0., 0., 0.],

  13. [ 0., 0., 0., 0.]],

  14. [[ 0., 0., 0., 0.],

  15. [ 0., 0., 0., 0.],

  16. [ 0., 0., 0., 0.],

  17. [ 0., 0., 0., 0.]]],

  18. [[[ 0., 0., 0., 0.],

  19. [ 0., 0., 0., 0.],

  20. [ 0., 0., 0., 0.],

  21. [ 0., 0., 0., 0.]],

  22. [[ 0., 0., 0., 0.],

  23. [ 0., 0., 0., 0.],

  24. [ 0., 0., 0., 0.],

  25. [ 0., 0., 0., 0.]],

  26. [[ 0., 0., 0., 0.],

  27. [ 0., 0., 0., 0.],

  28. [ 0., 0., 0., 0.],

  29. [ 0., 0., 0., 0.]],

  30. [[ 0., 0., 0., 0.],

  31. [ 0., 0., 0., 0.],

  32. [ 0., 0., 0., 0.],

  33. [ 0., 0., 0., 0.]]],

  34. [[[ 0., 0., 0., 0.],

  35. [ 0., 0., 0., 0.],

  36. [-100., -100., -100., -100.],

  37. [-100., -100., -100., -100.]],

  38. [[ 0., 0., 0., 0.],

  39. [ 0., 0., 0., 0.],

  40. [-100., -100., -100., -100.],

  41. [-100., -100., -100., -100.]],

  42. [[-100., -100., -100., -100.],

  43. [-100., -100., -100., -100.],

  44. [ 0., 0., 0., 0.],

  45. [ 0., 0., 0., 0.]],

  46. [[-100., -100., -100., -100.],

  47. [-100., -100., -100., -100.],

  48. [ 0., 0., 0., 0.],

  49. [ 0., 0., 0., 0.]]],

  50. [[[ 0., 0., 0., 0.],

  51. [ 0., 0., 0., 0.],

  52. [-100., -100., -100., -100.],

  53. [-100., -100., -100., -100.]],

  54. [[ 0., 0., 0., 0.],

  55. [ 0., 0., 0., 0.],

  56. [-100., -100., -100., -100.],

  57. [-100., -100., -100., -100.]],

  58. [[-100., -100., -100., -100.],

  59. [-100., -100., -100., -100.],

  60. [ 0., 0., 0., 0.],

  61. [ 0., 0., 0., 0.]],

  62. [[-100., -100., -100., -100.],

  63. [-100., -100., -100., -100.],

  64. [ 0., 0., 0., 0.],

  65. [ 0., 0., 0., 0.]]]]))

 
x = image.clone().permute(0, 2, 3, 1)  # (B, H, W, C)
B, H, W, C = x.shape
x[0, 0, 0, 0] = 0.     # marker value
x[0, -1, -1, -1] = 1.  # marker value
x.shape, x[..., 0], x[..., 1], x[..., 2]

 
  1. (torch.Size([1, 8, 8, 3]),

  2. tensor([[[0.0000, 0.1060, 0.1015, 0.2196, 0.3544, 0.8485, 0.3509, 0.7353],

  3. [0.3751, 0.5690, 0.9630, 0.1945, 0.8999, 0.4977, 0.8007, 0.1598],

  4. [0.1991, 0.0081, 0.2861, 0.4539, 0.1620, 0.8776, 0.5298, 0.3748],

  5. [0.8627, 0.0345, 0.2435, 0.7224, 0.9310, 0.8621, 0.4113, 0.8057],

  6. [0.4686, 0.3372, 0.3429, 0.6660, 0.7115, 0.8560, 0.7055, 0.4709],

  7. [0.5104, 0.2789, 0.8015, 0.8737, 0.6784, 0.6677, 0.8233, 0.6589],

  8. [0.9013, 0.6980, 0.1548, 0.9066, 0.0334, 0.1617, 0.2747, 0.6150],

  9. [0.6927, 0.0394, 0.4180, 0.0387, 0.4488, 0.1339, 0.3340, 0.6178]]]),

  10. tensor([[[0.6409, 0.0384, 0.7775, 0.3550, 0.6754, 0.9210, 0.0923, 0.0691],

  11. [0.7684, 0.0422, 0.1605, 0.9409, 0.0397, 0.7786, 0.8475, 0.8495],

  12. [0.3684, 0.9283, 0.0569, 0.7790, 0.6074, 0.1229, 0.3138, 0.5926],

  13. [0.1865, 0.5766, 0.2497, 0.2391, 0.4254, 0.7249, 0.3116, 0.1666],

  14. [0.0800, 0.3956, 0.7351, 0.8919, 0.1177, 0.7949, 0.0028, 0.2635],

  15. [0.9967, 0.1205, 0.1785, 0.0886, 0.8664, 0.8412, 0.1258, 0.4302],

  16. [0.5620, 0.9326, 0.6767, 0.2432, 0.5963, 0.7276, 0.2273, 0.4879],

  17. [0.9367, 0.9096, 0.8327, 0.1795, 0.0361, 0.7189, 0.9292, 0.3822]]]),

  18. tensor([[[0.8273, 0.7008, 0.2891, 0.1136, 0.4981, 0.2119, 0.8096, 0.6342],

  19. [0.0062, 0.8495, 0.1382, 0.8667, 0.2436, 0.6408, 0.0238, 0.4742],

  20. [0.4363, 0.1852, 0.6110, 0.2923, 0.6231, 0.0668, 0.9430, 0.2830],

  21. [0.5139, 0.5424, 0.3008, 0.5251, 0.3518, 0.0882, 0.3335, 0.7853],

  22. [0.4835, 0.3571, 0.2530, 0.8452, 0.2010, 0.3866, 0.5673, 0.1172],

  23. [0.4750, 0.7820, 0.4262, 0.4426, 0.4161, 0.1199, 0.0477, 0.4646],

  24. [0.7525, 0.2686, 0.6829, 0.6753, 0.4345, 0.6609, 0.8414, 0.9140],

  25. [0.4540, 0.7455, 0.4464, 0.6749, 0.9152, 0.6936, 0.7035, 1.0000]]]))

 
# the markers 0.0000 and 1.0000 are each cyclically shifted toward the top-left by shift_size = window_size // 2 = 2 pixels within their own channel
shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))  # (B, H, W, C)
shifted_x.shape, shifted_x[..., 0], shifted_x[..., 1], shifted_x[..., 2]

 
  1. (torch.Size([1, 8, 8, 3]),

  2. tensor([[[0.2861, 0.4539, 0.1620, 0.8776, 0.5298, 0.3748, 0.1991, 0.0081],

  3. [0.2435, 0.7224, 0.9310, 0.8621, 0.4113, 0.8057, 0.8627, 0.0345],

  4. [0.3429, 0.6660, 0.7115, 0.8560, 0.7055, 0.4709, 0.4686, 0.3372],

  5. [0.8015, 0.8737, 0.6784, 0.6677, 0.8233, 0.6589, 0.5104, 0.2789],

  6. [0.1548, 0.9066, 0.0334, 0.1617, 0.2747, 0.6150, 0.9013, 0.6980],

  7. [0.4180, 0.0387, 0.4488, 0.1339, 0.3340, 0.6178, 0.6927, 0.0394],

  8. [0.1015, 0.2196, 0.3544, 0.8485, 0.3509, 0.7353, 0.0000, 0.1060],

  9. [0.9630, 0.1945, 0.8999, 0.4977, 0.8007, 0.1598, 0.3751, 0.5690]]]),

  10. tensor([[[0.0569, 0.7790, 0.6074, 0.1229, 0.3138, 0.5926, 0.3684, 0.9283],

  11. [0.2497, 0.2391, 0.4254, 0.7249, 0.3116, 0.1666, 0.1865, 0.5766],

  12. [0.7351, 0.8919, 0.1177, 0.7949, 0.0028, 0.2635, 0.0800, 0.3956],

  13. [0.1785, 0.0886, 0.8664, 0.8412, 0.1258, 0.4302, 0.9967, 0.1205],

  14. [0.6767, 0.2432, 0.5963, 0.7276, 0.2273, 0.4879, 0.5620, 0.9326],

  15. [0.8327, 0.1795, 0.0361, 0.7189, 0.9292, 0.3822, 0.9367, 0.9096],

  16. [0.7775, 0.3550, 0.6754, 0.9210, 0.0923, 0.0691, 0.6409, 0.0384],

  17. [0.1605, 0.9409, 0.0397, 0.7786, 0.8475, 0.8495, 0.7684, 0.0422]]]),

  18. tensor([[[0.6110, 0.2923, 0.6231, 0.0668, 0.9430, 0.2830, 0.4363, 0.1852],

  19. [0.3008, 0.5251, 0.3518, 0.0882, 0.3335, 0.7853, 0.5139, 0.5424],

  20. [0.2530, 0.8452, 0.2010, 0.3866, 0.5673, 0.1172, 0.4835, 0.3571],

  21. [0.4262, 0.4426, 0.4161, 0.1199, 0.0477, 0.4646, 0.4750, 0.7820],

  22. [0.6829, 0.6753, 0.4345, 0.6609, 0.8414, 0.9140, 0.7525, 0.2686],

  23. [0.4464, 0.6749, 0.9152, 0.6936, 0.7035, 1.0000, 0.4540, 0.7455],

  24. [0.2891, 0.1136, 0.4981, 0.2119, 0.8096, 0.6342, 0.8273, 0.7008],

  25. [0.1382, 0.8667, 0.2436, 0.6408, 0.0238, 0.4742, 0.0062, 0.8495]]]))

 
x_windows = window_partition(shifted_x, window_size)  # (B*N, wh, ww, C) = (B*(H*W)//(wh*ww), wh, ww, C)
x_windows.shape, x_windows[..., 0], x_windows[..., 1], x_windows[..., 2]

 
  1. (torch.Size([4, 4, 4, 3]),

  2. tensor([[[0.2861, 0.4539, 0.1620, 0.8776],

  3. [0.2435, 0.7224, 0.9310, 0.8621],

  4. [0.3429, 0.6660, 0.7115, 0.8560],

  5. [0.8015, 0.8737, 0.6784, 0.6677]],

  6. [[0.5298, 0.3748, 0.1991, 0.0081],

  7. [0.4113, 0.8057, 0.8627, 0.0345],

  8. [0.7055, 0.4709, 0.4686, 0.3372],

  9. [0.8233, 0.6589, 0.5104, 0.2789]],

  10. [[0.1548, 0.9066, 0.0334, 0.1617],

  11. [0.4180, 0.0387, 0.4488, 0.1339],

  12. [0.1015, 0.2196, 0.3544, 0.8485],

  13. [0.9630, 0.1945, 0.8999, 0.4977]],

  14. [[0.2747, 0.6150, 0.9013, 0.6980],

  15. [0.3340, 0.6178, 0.6927, 0.0394],

  16. [0.3509, 0.7353, 0.0000, 0.1060],

  17. [0.8007, 0.1598, 0.3751, 0.5690]]]),

  18. tensor([[[0.0569, 0.7790, 0.6074, 0.1229],

  19. [0.2497, 0.2391, 0.4254, 0.7249],

  20. [0.7351, 0.8919, 0.1177, 0.7949],

  21. [0.1785, 0.0886, 0.8664, 0.8412]],

  22. [[0.3138, 0.5926, 0.3684, 0.9283],

  23. [0.3116, 0.1666, 0.1865, 0.5766],

  24. [0.0028, 0.2635, 0.0800, 0.3956],

  25. [0.1258, 0.4302, 0.9967, 0.1205]],

  26. [[0.6767, 0.2432, 0.5963, 0.7276],

  27. [0.8327, 0.1795, 0.0361, 0.7189],

  28. [0.7775, 0.3550, 0.6754, 0.9210],

  29. [0.1605, 0.9409, 0.0397, 0.7786]],

  30. [[0.2273, 0.4879, 0.5620, 0.9326],

  31. [0.9292, 0.3822, 0.9367, 0.9096],

  32. [0.0923, 0.0691, 0.6409, 0.0384],

  33. [0.8475, 0.8495, 0.7684, 0.0422]]]),

  34. tensor([[[0.6110, 0.2923, 0.6231, 0.0668],

  35. [0.3008, 0.5251, 0.3518, 0.0882],

  36. [0.2530, 0.8452, 0.2010, 0.3866],

  37. [0.4262, 0.4426, 0.4161, 0.1199]],

  38. [[0.9430, 0.2830, 0.4363, 0.1852],

  39. [0.3335, 0.7853, 0.5139, 0.5424],

  40. [0.5673, 0.1172, 0.4835, 0.3571],

  41. [0.0477, 0.4646, 0.4750, 0.7820]],

  42. [[0.6829, 0.6753, 0.4345, 0.6609],

  43. [0.4464, 0.6749, 0.9152, 0.6936],

  44. [0.2891, 0.1136, 0.4981, 0.2119],

  45. [0.1382, 0.8667, 0.2436, 0.6408]],

  46. [[0.8414, 0.9140, 0.7525, 0.2686],

  47. [0.7035, 1.0000, 0.4540, 0.7455],

  48. [0.8096, 0.6342, 0.8273, 0.7008],

  49. [0.0238, 0.4742, 0.0062, 0.8495]]]))

 
x_windows = x_windows.view(-1, window_size * window_size, C)  # (B*N, wh*ww, C) = (B*(H*W)//(wh*ww), wh*ww, C)
x_windows.shape, x_windows[..., 0], x_windows[..., 1], x_windows[..., 2]

 
  1. (torch.Size([4, 16, 3]),

  2. tensor([[0.2861, 0.4539, 0.1620, 0.8776, 0.2435, 0.7224, 0.9310, 0.8621, 0.3429,

  3. 0.6660, 0.7115, 0.8560, 0.8015, 0.8737, 0.6784, 0.6677],

  4. [0.5298, 0.3748, 0.1991, 0.0081, 0.4113, 0.8057, 0.8627, 0.0345, 0.7055,

  5. 0.4709, 0.4686, 0.3372, 0.8233, 0.6589, 0.5104, 0.2789],

  6. [0.1548, 0.9066, 0.0334, 0.1617, 0.4180, 0.0387, 0.4488, 0.1339, 0.1015,

  7. 0.2196, 0.3544, 0.8485, 0.9630, 0.1945, 0.8999, 0.4977],

  8. [0.2747, 0.6150, 0.9013, 0.6980, 0.3340, 0.6178, 0.6927, 0.0394, 0.3509,

  9. 0.7353, 0.0000, 0.1060, 0.8007, 0.1598, 0.3751, 0.5690]]),

  10. tensor([[0.0569, 0.7790, 0.6074, 0.1229, 0.2497, 0.2391, 0.4254, 0.7249, 0.7351,

  11. 0.8919, 0.1177, 0.7949, 0.1785, 0.0886, 0.8664, 0.8412],

  12. [0.3138, 0.5926, 0.3684, 0.9283, 0.3116, 0.1666, 0.1865, 0.5766, 0.0028,

  13. 0.2635, 0.0800, 0.3956, 0.1258, 0.4302, 0.9967, 0.1205],

  14. [0.6767, 0.2432, 0.5963, 0.7276, 0.8327, 0.1795, 0.0361, 0.7189, 0.7775,

  15. 0.3550, 0.6754, 0.9210, 0.1605, 0.9409, 0.0397, 0.7786],

  16. [0.2273, 0.4879, 0.5620, 0.9326, 0.9292, 0.3822, 0.9367, 0.9096, 0.0923,

  17. 0.0691, 0.6409, 0.0384, 0.8475, 0.8495, 0.7684, 0.0422]]),

  18. tensor([[0.6110, 0.2923, 0.6231, 0.0668, 0.3008, 0.5251, 0.3518, 0.0882, 0.2530,

  19. 0.8452, 0.2010, 0.3866, 0.4262, 0.4426, 0.4161, 0.1199],

  20. [0.9430, 0.2830, 0.4363, 0.1852, 0.3335, 0.7853, 0.5139, 0.5424, 0.5673,

  21. 0.1172, 0.4835, 0.3571, 0.0477, 0.4646, 0.4750, 0.7820],

  22. [0.6829, 0.6753, 0.4345, 0.6609, 0.4464, 0.6749, 0.9152, 0.6936, 0.2891,

  23. 0.1136, 0.4981, 0.2119, 0.1382, 0.8667, 0.2436, 0.6408],

  24. [0.8414, 0.9140, 0.7525, 0.2686, 0.7035, 1.0000, 0.4540, 0.7455, 0.8096,

  25. 0.6342, 0.8273, 0.7008, 0.0238, 0.4742, 0.0062, 0.8495]]))

 
# W-MSA/SW-MSA
# attn_windows = self.attn(x_windows, mask=self.attn_mask)  # the original op: (B*N, wh*ww, C) = (B*(H*W)//(wh*ww), wh*ww, C)
attn_windows = x_windows.clone()  # identity stand-in, for demonstration only

# merge windows
attn_windows = attn_windows.view(-1, window_size, window_size, C)
attn_windows.shape

torch.Size([4, 4, 4, 3])
 
# reverse cyclic shift
shifted_x = window_reverse(attn_windows, window_size, H, W)
shifted_x.shape, shifted_x[..., 0], shifted_x[..., 1], shifted_x[..., 2]

 
  1. (torch.Size([1, 8, 8, 3]),

  2. tensor([[[0.2861, 0.4539, 0.1620, 0.8776, 0.5298, 0.3748, 0.1991, 0.0081],

  3. [0.2435, 0.7224, 0.9310, 0.8621, 0.4113, 0.8057, 0.8627, 0.0345],

  4. [0.3429, 0.6660, 0.7115, 0.8560, 0.7055, 0.4709, 0.4686, 0.3372],

  5. [0.8015, 0.8737, 0.6784, 0.6677, 0.8233, 0.6589, 0.5104, 0.2789],

  6. [0.1548, 0.9066, 0.0334, 0.1617, 0.2747, 0.6150, 0.9013, 0.6980],

  7. [0.4180, 0.0387, 0.4488, 0.1339, 0.3340, 0.6178, 0.6927, 0.0394],

  8. [0.1015, 0.2196, 0.3544, 0.8485, 0.3509, 0.7353, 0.0000, 0.1060],

  9. [0.9630, 0.1945, 0.8999, 0.4977, 0.8007, 0.1598, 0.3751, 0.5690]]]),

  10. tensor([[[0.0569, 0.7790, 0.6074, 0.1229, 0.3138, 0.5926, 0.3684, 0.9283],

  11. [0.2497, 0.2391, 0.4254, 0.7249, 0.3116, 0.1666, 0.1865, 0.5766],

  12. [0.7351, 0.8919, 0.1177, 0.7949, 0.0028, 0.2635, 0.0800, 0.3956],

  13. [0.1785, 0.0886, 0.8664, 0.8412, 0.1258, 0.4302, 0.9967, 0.1205],

  14. [0.6767, 0.2432, 0.5963, 0.7276, 0.2273, 0.4879, 0.5620, 0.9326],

  15. [0.8327, 0.1795, 0.0361, 0.7189, 0.9292, 0.3822, 0.9367, 0.9096],

  16. [0.7775, 0.3550, 0.6754, 0.9210, 0.0923, 0.0691, 0.6409, 0.0384],

  17. [0.1605, 0.9409, 0.0397, 0.7786, 0.8475, 0.8495, 0.7684, 0.0422]]]),

  18. tensor([[[0.6110, 0.2923, 0.6231, 0.0668, 0.9430, 0.2830, 0.4363, 0.1852],

  19. [0.3008, 0.5251, 0.3518, 0.0882, 0.3335, 0.7853, 0.5139, 0.5424],

  20. [0.2530, 0.8452, 0.2010, 0.3866, 0.5673, 0.1172, 0.4835, 0.3571],

  21. [0.4262, 0.4426, 0.4161, 0.1199, 0.0477, 0.4646, 0.4750, 0.7820],

  22. [0.6829, 0.6753, 0.4345, 0.6609, 0.8414, 0.9140, 0.7525, 0.2686],

  23. [0.4464, 0.6749, 0.9152, 0.6936, 0.7035, 1.0000, 0.4540, 0.7455],

  24. [0.2891, 0.1136, 0.4981, 0.2119, 0.8096, 0.6342, 0.8273, 0.7008],

  25. [0.1382, 0.8667, 0.2436, 0.6408, 0.0238, 0.4742, 0.0062, 0.8495]]]))

 
# the markers 0.0000 and 1.0000 are each rolled back toward the bottom-right by shift_size = window_size // 2 = 2 pixels within their own channel
x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
x.shape, x[..., 0], x[..., 1], x[..., 2]

 
  1. (torch.Size([1, 8, 8, 3]),

  2. tensor([[[0.0000, 0.1060, 0.1015, 0.2196, 0.3544, 0.8485, 0.3509, 0.7353],

  3. [0.3751, 0.5690, 0.9630, 0.1945, 0.8999, 0.4977, 0.8007, 0.1598],

  4. [0.1991, 0.0081, 0.2861, 0.4539, 0.1620, 0.8776, 0.5298, 0.3748],

  5. [0.8627, 0.0345, 0.2435, 0.7224, 0.9310, 0.8621, 0.4113, 0.8057],

  6. [0.4686, 0.3372, 0.3429, 0.6660, 0.7115, 0.8560, 0.7055, 0.4709],

  7. [0.5104, 0.2789, 0.8015, 0.8737, 0.6784, 0.6677, 0.8233, 0.6589],

  8. [0.9013, 0.6980, 0.1548, 0.9066, 0.0334, 0.1617, 0.2747, 0.6150],

  9. [0.6927, 0.0394, 0.4180, 0.0387, 0.4488, 0.1339, 0.3340, 0.6178]]]),

  10. tensor([[[0.6409, 0.0384, 0.7775, 0.3550, 0.6754, 0.9210, 0.0923, 0.0691],

  11. [0.7684, 0.0422, 0.1605, 0.9409, 0.0397, 0.7786, 0.8475, 0.8495],

  12. [0.3684, 0.9283, 0.0569, 0.7790, 0.6074, 0.1229, 0.3138, 0.5926],

  13. [0.1865, 0.5766, 0.2497, 0.2391, 0.4254, 0.7249, 0.3116, 0.1666],

  14. [0.0800, 0.3956, 0.7351, 0.8919, 0.1177, 0.7949, 0.0028, 0.2635],

  15. [0.9967, 0.1205, 0.1785, 0.0886, 0.8664, 0.8412, 0.1258, 0.4302],

  16. [0.5620, 0.9326, 0.6767, 0.2432, 0.5963, 0.7276, 0.2273, 0.4879],

  17. [0.9367, 0.9096, 0.8327, 0.1795, 0.0361, 0.7189, 0.9292, 0.3822]]]),

  18. tensor([[[0.8273, 0.7008, 0.2891, 0.1136, 0.4981, 0.2119, 0.8096, 0.6342],

  19. [0.0062, 0.8495, 0.1382, 0.8667, 0.2436, 0.6408, 0.0238, 0.4742],

  20. [0.4363, 0.1852, 0.6110, 0.2923, 0.6231, 0.0668, 0.9430, 0.2830],

  21. [0.5139, 0.5424, 0.3008, 0.5251, 0.3518, 0.0882, 0.3335, 0.7853],

  22. [0.4835, 0.3571, 0.2530, 0.8452, 0.2010, 0.3866, 0.5673, 0.1172],

  23. [0.4750, 0.7820, 0.4262, 0.4426, 0.4161, 0.1199, 0.0477, 0.4646],

  24. [0.7525, 0.2686, 0.6829, 0.6753, 0.4345, 0.6609, 0.8414, 0.9140],

  25. [0.4540, 0.7455, 0.4464, 0.6749, 0.9152, 0.6936, 0.7035, 1.0000]]]))

 
x = x.view(B, H * W, C)
x.shape

torch.Size([1, 64, 3])

3.9 Basic Layer

        A Basic Layer is one stage of the Swin Transformer: it contains several Swin Transformer Blocks and, optionally, a downsampling (Patch Merging) layer at the end.

        Note that the number of Swin Transformer Blocks in one stage must be even, because the blocks alternate between a layer using Window Attention (W-MSA) and a layer using Shifted Window Attention (SW-MSA).

 
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
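        As a usage sketch (the numbers below are illustrative and correspond to stage 1 of Swin-T; PatchMerging is the downsample module from Section 3.3):

import torch

# hypothetical stage-1 configuration of Swin-T: 56x56 tokens, C=96, 2 blocks (W-MSA + SW-MSA)
layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3,
                   window_size=7, downsample=PatchMerging)
x = torch.rand(1, 56 * 56, 96)   # (B, L, C)
out = layer(x)
print(out.shape)                 # torch.Size([1, 784, 192]): Patch Merging halves H, W and doubles C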

Original article: https://blog.csdn.net/qq_39478403/article/details/120042232
