Swin Transformer:面向视觉任务的层次化Transformer(代码实现)

Swin Transformer:面向视觉任务的层次化Transformer

随着Transformer在自然语言处理(NLP)领域的巨大成功,研究者们开始探索其在计算机视觉中的应用。传统的Vision Transformer(ViT)虽然在图像分类任务上取得了不错的效果,但其全局自注意力机制导致计算复杂度随图像尺寸呈平方增长,且缺乏对视觉任务中多尺度特征的建模能力。为了解决这些问题,微软亚洲研究院提出了Swin Transformer,一种层次化的视觉Transformer,能够高效处理多种视觉任务,包括图像分类、目标检测和语义分割。本文将面向深度学习研究者,详细介绍Swin Transformer的核心思想、实现方法,并给出相关数学公式。

下文中图片来自于原论文:https://arxiv.org/pdf/2103.14030


Swin Transformer的核心思想

Swin Transformer(Shifted Window Transformer)通过引入层次化设计和移位窗口(Shifted Window)机制,克服了传统ViT的局限性。其核心创新点包括:

  1. 层次化特征表示
    Swin Transformer通过逐步合并相邻图像块(Patch Merging),构建层次化的特征图。这种设计类似于卷积神经网络(CNN)中的特征金字塔,能够适应不同尺度的视觉实体,适用于密集预测任务(如目标检测和语义分割)。

  2. 移位窗口自注意力
    为了降低计算复杂度,Swin Transformer将自注意力限制在局部非重叠窗口内计算。同时,通过在连续层之间移位窗口位置,实现跨窗口的连接,增强模型的建模能力。

  3. 线性计算复杂度
    与ViT的全局自注意力(复杂度为O(N²),N为图像块数量)不同,Swin Transformer的窗口自注意力复杂度与图像尺寸呈线性关系(O(N)),使其能够处理高分辨率图像。

这些特性使Swin Transformer成为一种通用的视觉主干网络,在ImageNet-1K分类(87.3% Top-1准确率)、COCO目标检测(58.7 box AP)和ADE20K语义分割(53.5 mIoU)等任务上显著超越了之前的最佳方法。


Swin Transformer的实现方法

Swin Transformer的架构可以分为四个阶段(Stage),每个阶段通过Patch Merging减少特征图分辨率,并通过Swin Transformer Block提取特征。以下是其具体实现步骤:

  1. 图像分块与嵌入(Patch Partition & Embedding)
    输入RGB图像(尺寸H×W×3)被划分为非重叠的4×4小块(Patch),每个小块包含16个像素。将每个小块展平并通过线性层映射到一个固定维度C(例如96)。初始特征图分辨率为H/4×W/4。

  2. Swin Transformer Block
    每个Swin Transformer Block包含窗口多头自注意力(W-MSA)和移位窗口多头自注意力(SW-MSA),以及多层感知机(MLP)。具体计算过程如下:

    • 窗口划分:将特征图划分为M×M(默认M=7)的非重叠窗口,每个窗口内的自注意力独立计算。
    • 移位窗口:在连续的Block之间,窗口位置按(M/2, M/2)像素移位,连接前一层窗口间的特征。
    • 自注意力计算:在每个窗口内计算多头自注意力,并添加相对位置偏置(Relative Position Bias)以增强空间信息。
  3. Patch Merging
    在每个阶段结束时,将2×2的相邻Patch合并,特征维度从C变为4C,再通过线性层降维至2C,分辨率减半(如H/4×W/4变为H/8×W/8)。

  4. 层次化输出
    经过四个阶段,Swin Transformer生成分辨率分别为H/4×W/4、H/8×W/8、H/16×W/16和H/32×W/32的特征图,类似于CNN的主干网络,可直接用于下游任务。

在这里插入图片描述


数学公式

以下是Swin Transformer中关键部分的数学表示:

  1. 全局自注意力(MSA)复杂度
    传统ViT的全局多头自注意力复杂度为:
    Ω ( MSA ) = 4 h w C 2 + 2 ( h w ) 2 C \Omega(\text{MSA}) = 4hwC^2 + 2(hw)^2C Ω(MSA)=4hwC2+2(hw)2C
    其中,h×w是图像块数量,C是特征维度。复杂度随hw呈平方增长。

  2. 窗口自注意力(W-MSA)复杂度
    Swin Transformer将自注意力限制在M×M的窗口内,复杂度变为:
    Ω ( W-MSA ) = 4 h w C 2 + 2 M 2 h w C \Omega(\text{W-MSA}) = 4hwC^2 + 2M^2hwC Ω(W-MSA)=4hwC2+2M2hwC
    由于M固定(通常为7),复杂度与hw呈线性关系。

  3. 移位窗口自注意力计算
    Swin Transformer Block的计算分为两步:

    • 第一步(W-MSA):
      z ^ l = W-MSA ( LN ( z l − 1 ) ) + z l − 1 \hat{\mathbf{z}}^l = \text{W-MSA}(\text{LN}(\mathbf{z}^{l-1})) + \mathbf{z}^{l-1} z^l=W-MSA(LN(zl1))+zl1
      z l = MLP ( LN ( z ^ l ) ) + z ^ l \mathbf{z}^l = \text{MLP}(\text{LN}(\hat{\mathbf{z}}^l)) + \hat{\mathbf{z}}^l zl=MLP(LN(z^l))+z^l
    • 第二步(SW-MSA):
      z ^ l + 1 = SW-MSA ( LN ( z l ) ) + z l \hat{\mathbf{z}}^{l+1} = \text{SW-MSA}(\text{LN}(\mathbf{z}^l)) + \mathbf{z}^l z^l+1=SW-MSA(LN(zl))+zl
      z l + 1 = MLP ( LN ( z ^ l + 1 ) ) + z ^ l + 1 \mathbf{z}^{l+1} = \text{MLP}(\text{LN}(\hat{\mathbf{z}}^{l+1})) + \hat{\mathbf{z}}^{l+1} zl+1=MLP(LN(z^l+1))+z^l+1
      其中,LN表示LayerNorm,W-MSA和SW-MSA分别表示常规和移位窗口的多头自注意力。
  4. 带相对位置偏置的自注意力
    窗口内的自注意力计算公式为:
    Attention ( Q , K , V ) = SoftMax ( Q K T / d + B ) V \text{Attention}(Q, K, V) = \text{SoftMax}(QK^T / \sqrt{d} + B)V Attention(Q,K,V)=SoftMax(QKT/d +B)V
    其中,Q、K、V分别是查询、键和值矩阵,d是查询/键维度, B ∈ R ( M 2 × M 2 ) B∈R^{(M²×M²)} BR(M2×M2)是相对位置偏置矩阵,通过参数化一个较小的矩阵( B ^ ∈ R ( 2 M − 1 ) × ( 2 M − 1 ) \hat{B}∈R^{(2M-1)×(2M-1)} B^R(2M1)×(2M1))生成。

  5. 高效批计算
    为处理移位窗口产生的不规则窗口,Swin Transformer采用循环移位(Cyclic Shift)方法,将特征图向左上角移位后,通过掩码机制限制自注意力计算在原始子窗口内,避免填充带来的额外计算。


Swin Transformer的优势与实验结果
  1. 高效性
    移位窗口机制显著降低了计算复杂度,使Swin Transformer能够处理高分辨率图像。实验表明,其在V100 GPU上的吞吐量远超基于滑动窗口的自注意力方法。

  2. 优越的性能

    • ImageNet-1K:Swin-L模型在预训练后达到87.3% Top-1准确率。
    • COCO:Swin-L在HTC++框架下实现58.7 box AP和51.1 mask AP。
    • ADE20K:Swin-L取得53.5 mIoU,超越之前的最佳结果3.2 mIoU。
  3. 通用性
    Swin Transformer不仅适用于视觉任务,其层次化设计和移位窗口机制也对全MLP架构(如MLP-Mixer)有益,展现了其普适性。


总结

Swin Transformer通过层次化特征表示和移位窗口自注意力机制,成功将Transformer的强大建模能力引入计算机视觉领域。其线性复杂度、高效实现和优异性能使其成为视觉任务的理想主干网络。对于深度学习研究者来说,Swin Transformer提供了一个值得深入研究的架构,不仅可以直接用于视觉任务,还可能启发NLP与多模态任务的创新。代码和模型已开源(https://github.com/microsoft/Swin-Transformer),欢迎大家尝试与探索!

窗口自注意力(W-MSA)的复杂度公式


1. 全局自注意力(MSA)复杂度(仅考虑乘法)

公式:
Ω ( MSA ) = 4 h w C 2 + ( h w ) 2 C \Omega(\text{MSA}) = 4hwC^2 + (hw)^2C Ω(MSA)=4hwC2+(hw)2C

自注意力机制的基本计算(仅乘法)

输入特征 ( X ∈ R N × C X \in \mathbb{R}^{N \times C} XRN×C),( N = h w N = hw N=hw)(图像块总数),( C C C) 是特征维度。计算过程如下:

  1. 线性变换生成 ( Q Q Q)、( K K K)、( V V V)

    • 查询(Query):( Q = X W Q Q = X W_Q Q=XWQ),( X ∈ R h w × C X \in \mathbb{R}^{hw \times C} XRhw×C),( W Q ∈ R C × C W_Q \in \mathbb{R}^{C \times C} WQRC×C)。
      • 矩阵乘法:( h w × C × C = h w C 2 hw \times C \times C = hwC^2 hw×C×C=hwC2) 次乘法。
    • 键(Key):( K = X W K K = X W_K K=XWK),复杂度同样为 ( h w C 2 hwC^2 hwC2) 次乘法。
    • 值(Value):( V = X W V V = X W_V V=XWV),复杂度为 ( h w C 2 hwC^2 hwC2) 次乘法。
    • 输出投影:( O = MSA ( X ) W O O = \text{MSA}(X) W_O O=MSA(X)WO)(多头自注意力后的线性投影),复杂度为 ( h w C 2 hwC^2 hwC2) 次乘法。
    • 总线性变换乘法次数:
      3 h w C 2 + h w C 2 = 4 h w C 2 3hwC^2 + hwC^2 = 4hwC^2 3hwC2+hwC2=4hwC2
  2. 注意力权重计算

    • 计算 ( Q K T Q K^T QKT):( Q ∈ R h w × C Q \in \mathbb{R}^{hw \times C} QRhw×C),( K T ∈ R C × h w K^T \in \mathbb{R}^{C \times hw} KTRC×hw),结果为 ( h w × h w hw \times hw hw×hw) 的矩阵。
      • 每个元素 (( Q K T ) i , j = ∑ k = 1 C Q i , k ⋅ K j , k Q K^T)_{i,j} = \sum_{k=1}^{C} Q_{i,k} \cdot K_{j,k} QKT)i,j=k=1CQi,kKj,k),需要 ( C C C) 次乘法。
      • 总元素数为 ( h w × h w = ( h w ) 2 hw \times hw = (hw)^2 hw×hw=(hw)2)。
      • 总乘法次数:
        ( h w ) 2 × C = ( h w ) 2 C (hw)^2 \times C = (hw)^2C (hw)2×C=(hw)2C
      • 不考虑加法,因此仅为 ( ( h w ) 2 C (hw)^2C (hw)2C)。
  3. 加权和计算

    • 计算 ( SoftMax ( Q K T ) V \text{SoftMax}(Q K^T) V SoftMax(QKT)V):( SoftMax ( Q K T ) ∈ R h w × h w \text{SoftMax}(Q K^T) \in \mathbb{R}^{hw \times hw} SoftMax(QKT)Rhw×hw),( V ∈ R h w × C V \in \mathbb{R}^{hw \times C} VRhw×C),结果为 ( h w × C hw \times C hw×C)。
      • 每个元素需要 ( h w hw hw) 次乘法:
        ( SoftMax ( Q K T ) V ) i , j = ∑ k = 1 h w SoftMax ( Q K T ) i , k ⋅ V k , j (\text{SoftMax}(Q K^T) V)_{i,j} = \sum_{k=1}^{hw} \text{SoftMax}(Q K^T)_{i,k} \cdot V_{k,j} (SoftMax(QKT)V)i,j=k=1hwSoftMax(QKT)i,kVk,j
      • 总元素数为 ( h w × C hw \times C hw×C)。
      • 总乘法次数:
        h w × C × h w = ( h w ) 2 C hw \times C \times hw = (hw)^2C hw×C×hw=(hw)2C
      • 不考虑加法,仅为 ( ( h w ) 2 C (hw)^2C (hw)2C)。
总复杂度(仅乘法)

Ω ( MSA ) = 4 h w C 2 + ( h w ) 2 C + ( h w ) 2 C = 4 h w C 2 + 2 ( h w ) 2 C \Omega(\text{MSA}) = 4hwC^2 + (hw)^2C + (hw)^2C = 4hwC^2 + 2(hw)^2C Ω(MSA)=4hwC2+(hw)2C+(hw)2C=4hwC2+2(hw)2C
这与论文中的公式一致:
Ω ( MSA ) = 4 h w C 2 + 2 ( h w ) 2 C \Omega(\text{MSA}) = 4hwC^2 + 2(hw)^2C Ω(MSA)=4hwC2+2(hw)2C

  • ( 4 h w C 2 4hwC^2 4hwC2) 来自线性变换(( Q Q Q)、( K K K)、( V V V) 和输出投影)。
  • ( 2 ( h w ) 2 C 2(hw)^2C 2(hw)2C) 来自 ( Q K T Q K^T QKT) 和 ( SoftMax ( Q K T ) V \text{SoftMax}(Q K^T) V SoftMax(QKT)V) 的乘法。

2. 窗口自注意力(W-MSA)复杂度(仅考虑乘法)

公式:
Ω ( W-MSA ) = 4 h w C 2 + M 2 h w C \Omega(\text{W-MSA}) = 4hwC^2 + M^2hwC Ω(W-MSA)=4hwC2+M2hwC

窗口自注意力的计算(仅乘法)

Swin Transformer将特征图划分为 ( h w M 2 \frac{hw}{M^2} M2hw) 个非重叠的 ( M × M M \times M M×M) 窗口,每个窗口内独立计算自注意力。

  1. 线性变换

    • 总Patch数量为 ( h w hw hw)。
    • ( Q = X W Q Q = X W_Q Q=XWQ)、( K = X W K K = X W_K K=XWK)、( V = X W V V = X W_V V=XWV)、输出投影 ( O = W-MSA ( X ) W O O = \text{W-MSA}(X) W_O O=W-MSA(X)WO),每个变换复杂度为 ( h w C 2 hwC^2 hwC2)。
    • 总乘法次数:
      4 h w C 2 4hwC^2 4hwC2
    • 与全局自注意力相同,因为线性变换在所有Patch上统一执行,与窗口划分无关。
  2. 注意力权重计算

    • 每个窗口内有 ( M 2 M^2 M2) 个Patch。
    • ( Q K T Q K^T QKT):( Q ∈ R M 2 × C Q \in \mathbb{R}^{M^2 \times C} QRM2×C),( K T ∈ R C × M 2 K^T \in \mathbb{R}^{C \times M^2} KTRC×M2),结果为 ( M 2 × M 2 M^2 \times M^2 M2×M2)。
      • 每个元素需要 ( C C C) 次乘法。
      • 总元素数为 ( M 2 × M 2 = M 4 M^2 \times M^2 = M^4 M2×M2=M4)。
      • 单个窗口的乘法次数:
        M 4 × C = M 4 C M^4 \times C = M^4C M4×C=M4C
    • 窗口数量为 ( h w M 2 \frac{hw}{M^2} M2hw)。
    • 总乘法次数:
      h w M 2 × M 4 C = M 2 h w C \frac{hw}{M^2} \times M^4C = M^2hwC M2hw×M4C=M2hwC
  3. 加权和计算

    • ( SoftMax ( Q K T ) V \text{SoftMax}(Q K^T) V SoftMax(QKT)V):( SoftMax ( Q K T ) ∈ R M 2 × M 2 \text{SoftMax}(Q K^T) \in \mathbb{R}^{M^2 \times M^2} SoftMax(QKT)RM2×M2),( V ∈ R M 2 × C V \in \mathbb{R}^{M^2 \times C} VRM2×C),结果为 ( M 2 × C M^2 \times C M2×C)。
      • 每个元素需要 ( M 2 M^2 M2) 次乘法。
      • 总元素数为 ( M 2 × C M^2 \times C M2×C)。
      • 单个窗口的乘法次数:
        M 2 × C × M 2 = M 4 C M^2 \times C \times M^2 = M^4C M2×C×M2=M4C
    • 总乘法次数:
      h w M 2 × M 4 C = M 2 h w C \frac{hw}{M^2} \times M^4C = M^2hwC M2hw×M4C=M2hwC
总复杂度(仅乘法)

Ω ( W-MSA ) = 4 h w C 2 + M 2 h w C + M 2 h w C = 4 h w C 2 + 2 M 2 h w C \Omega(\text{W-MSA}) = 4hwC^2 + M^2hwC + M^2hwC = 4hwC^2 + 2M^2hwC Ω(W-MSA)=4hwC2+M2hwC+M2hwC=4hwC2+2M2hwC
这与论文中的公式一致:
Ω ( W-MSA ) = 4 h w C 2 + 2 M 2 h w C \Omega(\text{W-MSA}) = 4hwC^2 + 2M^2hwC Ω(W-MSA)=4hwC2+2M2hwC

  • ( 4 h w C 2 4hwC^2 4hwC2) 来自线性变换。
  • ( 2 M 2 h w C 2M^2hwC 2M2hwC) 来自每个窗口的 ( Q K T Q K^T QKT) 和 ( SoftMax ( Q K T ) V \text{SoftMax}(Q K^T) V SoftMax(QKT)V) 的乘法。

对比与总结(仅乘法)

  1. 全局自注意力(MSA)

    • 复杂度:( 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2 + 2(hw)^2C 4hwC2+2(hw)2C)。
    • ( 2 ( h w ) 2 C 2(hw)^2C 2(hw)2C) 主导增长,随 ( h w hw hw) 呈平方增长。
  2. 窗口自注意力(W-MSA)

    • 复杂度:( 4 h w C 2 + 2 M 2 h w C 4hwC^2 + 2M^2hwC 4hwC2+2M2hwC)。
    • ( M M M) 固定(如7),( 2 M 2 h w C 2M^2hwC 2M2hwC) 随 ( h w hw hw) 线性增长。
关键差异
  • 不考虑加法时,( Q K T Q K^T QKT) 和 ( SoftMax ( Q K T ) V \text{SoftMax}(Q K^T) V SoftMax(QKT)V) 的乘法开销各占一半(分别贡献 ( ( h w ) 2 C (hw)^2C (hw)2C) 或 ( M 2 h w C M^2hwC M2hwC))。
  • 论文公式包含了两部分的乘法,表明其复杂度分析是全面的。
  • 线性复杂度(W-MSA)的核心在于窗口化限制了每次自注意力的范围,使乘法次数从 ( ( h w ) 2 (hw)^2 (hw)2) 降为 ( M 2 h w M^2hw M2hw)。

Patch Merging 的详细解释

在Swin Transformer中,Patch Merging 是一种关键机制,用于构建层次化的特征表示。通过逐步合并相邻的图像块(Patch),它实现了空间分辨率的降低和特征维度的调整,从而形成类似于卷积神经网络(CNN)中特征金字塔的多尺度特征图。这种设计不仅降低了计算复杂度,还增强了模型对不同尺度视觉实体的建模能力,使其适用于密集预测任务(如目标检测和语义分割)。下面我将详细解释Patch Merging的原理、实现步骤及其作用。


Patch Merging 的核心思想

Patch Merging 的灵感来源于CNN中常用的池化操作(如最大池化或平均池化),但它在Transformer架构中被重新设计,以适应图像块(Patch)的表示方式。其目标是:

  1. 降低空间分辨率:通过合并相邻的Patch,将特征图的空间尺寸缩小(例如,从 ( H / 4 × W / 4 H/4 \times W/4 H/4×W/4) 降至 ( H / 8 × W / 8 H/8 \times W/8 H/8×W/8))。
  2. 增加特征维度:合并后的Patch会包含更多信息,因此特征维度需要相应调整(通常翻倍),以保留丰富的语义信息。
  3. 构建层次化表示:通过多阶段的Patch Merging,生成一系列分辨率逐渐降低、语义信息逐渐增强的特征图,形成层次化的结构。

这种机制使得Swin Transformer能够像CNN一样,从低级局部特征逐步过渡到高级全局特征,非常适合需要多尺度特征的任务。


Patch Merging 的实现步骤

Patch Merging 是在Swin Transformer的每个阶段(Stage)之间执行的操作。假设输入特征图的分辨率为 ( H × W H \times W H×W)(例如,经过初始Patch Partition后的 ( H / 4 × W / 4 H/4 \times W/4 H/4×W/4)),特征维度为 ( C C C)(例如96)。以下是具体的实现步骤:

  1. 分组相邻Patch

    • 将特征图划分为 ( 2 × 2 2 \times 2 2×2) 的小块网格。即,每 ( 2 × 2 2 \times 2 2×2) 个相邻Patch(总共4个Patch)被视为一个组。
    • 假设输入特征图为 ( R H × W × C \mathbb{R}^{H \times W \times C} RH×W×C),分组后,每个 ( 2 × 2 2 \times 2 2×2) 的小块包含4个Patch,每个Patch的维度为 ( C C C)。
    • 分组的结果是,空间分辨率变为原来的 ( 1 / 2 1/2 1/2),即从 ( H × W H \times W H×W) 降至 ( H / 2 × W / 2 H/2 \times W/2 H/2×W/2)。
  2. 展平与拼接

    • 将每个 ( 2 × 2 2 \times 2 2×2) 小块中的4个Patch(每个维度为 ( C C C))沿通道维度拼接起来。
    • 具体操作:将4个Patch的特征向量(每个为 ( R C \mathbb{R}^{C} RC))展平并拼接成一个新的特征向量,维度变为 ( 4 C 4C 4C)。
    • 例如,若输入特征图为 ( H × W × C H \times W \times C H×W×C),则拼接后的特征图为 ( H / 2 × W / 2 × 4 C H/2 \times W/2 \times 4C H/2×W/2×4C)。
  3. 线性降维

    • 拼接后的特征维度 ( 4 C 4C 4C) 通常过高,直接使用会增加后续计算的复杂度。因此,通过一个全连接层(线性层)将维度从 ( 4 C 4C 4C) 降至 ( 2 C 2C 2C)(或其他预设值,通常是下一阶段的输入维度)。
    • 数学上,设输入为 ( X ∈ R H / 2 × W / 2 × 4 C X \in \mathbb{R}^{H/2 \times W/2 \times 4C} XRH/2×W/2×4C),线性层权重为 ( W ∈ R 4 C × 2 C W \in \mathbb{R}^{4C \times 2C} WR4C×2C),输出为:
      X ′ = X ⋅ W , X ′ ∈ R H / 2 × W / 2 × 2 C X' = X \cdot W, \quad X' \in \mathbb{R}^{H/2 \times W/2 \times 2C} X=XW,XRH/2×W/2×2C
    • 这个线性层不仅降低维度,还可以通过学习调整特征的表达能力。
  4. 结果

    • 经过Patch Merging后,特征图的分辨率从 ( H × W H \times W H×W) 变为 ( H / 2 × W / 2 H/2 \times W/2 H/2×W/2),特征维度从 ( C C C) 变为 ( 2 C 2C 2C)。
    • 例如,初始特征图为 ( H / 4 × W / 4 × 96 H/4 \times W/4 \times 96 H/4×W/4×96),经过Patch Merging后变为 ( H / 8 × W / 8 × 192 H/8 \times W/8 \times 192 H/8×W/8×192)。

Patch Merging 在 Swin Transformer 中的应用

Swin Transformer 包含四个阶段(Stage 1 到 Stage 4),每个阶段由若干 Swin Transformer Block 组成,并在阶段之间执行 Patch Merging。具体流程如下:

  • Stage 1:输入特征图为 ( H / 4 × W / 4 × C H/4 \times W/4 \times C H/4×W/4×C)(例如 ( 56 × 56 × 96 56 \times 56 \times 96 56×56×96)),经过若干Block处理后,输出仍为 ( H / 4 × W / 4 × C H/4 \times W/4 \times C H/4×W/4×C)。
  • Patch Merging:将 ( H / 4 × W / 4 × C H/4 \times W/4 \times C H/4×W/4×C) 合并为 ( H / 8 × W / 8 × 2 C H/8 \times W/8 \times 2C H/8×W/8×2C)(例如 ( 28 × 28 × 192 28 \times 28 \times 192 28×28×192))。
  • Stage 2:以 ( H / 8 × W / 8 × 2 C H/8 \times W/8 \times 2C H/8×W/8×2C) 作为输入,继续处理。
  • 依此类推,最终到 Stage 4 输出 ( H / 32 × W / 32 × 8 C H/32 \times W/32 \times 8C H/32×W/32×8C)(例如 ( 7 × 7 × 768 7 \times 7 \times 768 7×7×768))。

这种逐层合并的过程生成了分辨率逐渐降低的特征图:

  • ( H / 4 × W / 4 H/4 \times W/4 H/4×W/4)(细粒度特征)。
  • ( H / 8 × W / 8 H/8 \times W/8 H/8×W/8)。
  • ( H / 16 × W / 16 H/16 \times W/16 H/16×W/16)。
  • ( H / 32 × W / 32 H/32 \times W/32 H/32×W/32)(粗粒度特征)。

Patch Merging 的计算复杂度

假设输入特征图为 ( H × W × C H \times W \times C H×W×C):

  • 拼接:仅涉及重排操作,无额外乘法。
  • 线性降维
    • 输入为 ( H / 2 × W / 2 × 4 C H/2 \times W/2 \times 4C H/2×W/2×4C),权重为 ( 4 C × 2 C 4C \times 2C 4C×2C)。
    • 每个输出元素需要 ( 4 C 4C 4C) 次乘法,总输出元素数为 ( H / 2 × W / 2 × 2 C H/2 \times W/2 \times 2C H/2×W/2×2C)。
    • 总乘法次数:
      ( H / 2 × W / 2 ) × 2 C × 4 C = H ⋅ W ⋅ 4 C 2 (H/2 \times W/2) \times 2C \times 4C = H \cdot W \cdot 4C^2 (H/2×W/2)×2C×4C=HW4C2
    • 复杂度为 ( O ( H W ⋅ C 2 ) O(HW \cdot C^2) O(HWC2)),与特征图大小呈线性关系。

Patch Merging 的作用与优势
  1. 层次化特征表示

    • 类似CNN中的下采样(如池化或步幅卷积),Patch Merging 通过降低分辨率,逐步提取更高层次的语义信息。
    • 低分辨率特征图(如 ( H / 32 × W / 32 H/32 \times W/32 H/32×W/32))适合捕捉全局上下文,而高分辨率特征图(如 ( H / 4 × W / 4 H/4 \times W/4 H/4×W/4))保留细节信息。
  2. 多尺度适应性

    • 生成的多尺度特征图(类似于CNN的特征金字塔)直接适配下游任务。例如:
      • 目标检测:需要检测不同大小的目标。
      • 语义分割:需要像素级预测,融合多尺度特征。
  3. 计算效率

    • 空间分辨率减半后,后续 Swin Transformer Block 处理的Patch数量减少(例如从 ( H ⋅ W H \cdot W HW) 到 ( H ⋅ W / 4 H \cdot W / 4 HW/4)),降低了窗口自注意力的计算量。
    • 即使特征维度增加(( C → 2 C C \to 2C C2C)),整体复杂度仍可控,因为窗口大小 ( M M M) 固定,自注意力复杂度与Patch数量呈线性关系。
  4. 与CNN的相似性

    • Patch Merging 的设计借鉴了CNN的空间层次化思想,使得Swin Transformer在视觉任务中表现出与CNN相似的归纳偏置(Inductive Bias),如局部性与平移不变性。

与初始 Patch Partition 的对比
  • Patch Partition(初始分块):

    • 在模型开头,将输入图像(( H × W × 3 H \times W \times 3 H×W×3))划分为 ( 4 × 4 4 \times 4 4×4) 的Patch(每个Patch 16像素),通过线性层映射到维度 ( C C C)(如96)。
    • 输出:( H / 4 × W / 4 × C H/4 \times W/4 \times C H/4×W/4×C)。
    • 作用:将图像转换为Token序列。
  • Patch Merging

    • 在中间阶段,将已有特征图的 ( 2 × 2 2 \times 2 2×2) Patch 合并,维度从 ( C C C) 升至 ( 2 C 2C 2C)。
    • 输出:( H / 8 × W / 8 × 2 C H/8 \times W/8 \times 2C H/8×W/8×2C)(依阶段递推)。
    • 作用:构建多尺度特征。

总结

Patch Merging 是Swin Transformer实现层次化特征表示的核心操作,通过合并相邻Patch、降低分辨率并调整特征维度,形成了类似CNN特征金字塔的结构。它的设计既高效又灵活,使得Swin Transformer能够在图像分类、目标检测和语义分割等任务中表现出色。对于密集预测任务,Patch Merging 提供的多尺度特征尤为关键,弥补了传统ViT缺乏层次化的不足。

在 Patch Merging 的过程中,空间分辨率从 (H \times W) 变为 (H/2 \times W/2),即变为原来的 (1/2)(严格来说是面积变为 (1/4),但线性维度变为 (1/2))。这个变化的原因与分组和合并的机制直接相关。下面我将详细解释为什么空间分辨率会变成 (1/2)。


空间分辨率变为 (1/2) 的原因

空间分辨率通常指的是特征图在宽度(( W W W))和高度(( H H H))上的尺寸。在 Patch Merging 中,通过将 ( 2 × 2 2 \times 2 2×2) 的相邻 Patch 合并为一个新的单元,特征图的空间维度在每个方向上被压缩了一半。让我们逐步分析这个过程:

1. 输入特征图的结构
  • 假设输入特征图为 ( R H × W × C \mathbb{R}^{H \times W \times C} RH×W×C),其中:
    • ( H H H) 是高度方向的 Patch 数量。
    • ( W W W) 是宽度方向的 Patch 数量。
    • ( C C C) 是每个 Patch 的特征维度。
  • 这意味着特征图包含 ( H × W H \times W H×W) 个 Patch,每个 Patch 是一个 ( C C C) 维向量。
2. 分组为 ( 2 × 2 2 \times 2 2×2) 小块
  • Patch Merging 将特征图划分为 ( 2 × 2 2 \times 2 2×2) 的小块网格。也就是说,每 ( 2 × 2 2 \times 2 2×2) 个相邻 Patch(总共 4 个 Patch)被视为一个组。
  • 为了实现这种分组:
    • 高度方向:每 2 个 Patch 合并为 1 个新单元,因此高度从 ( H H H) 变为 ( H / 2 H/2 H/2)。
    • 宽度方向:每 2 个 Patch 合并为 1 个新单元,因此宽度从 ( W W W) 变为 ( W / 2 W/2 W/2)。
  • 前提是 ( H H H) 和 ( W W W) 必须是 2 的整数倍(在 Swin Transformer 中,输入尺寸通常经过设计满足这一条件,例如初始 Patch Partition 后为 ( H / 4 × W / 4 H/4 \times W/4 H/4×W/4))。
3. 合并后空间维度的变化
  • 每个 ( 2 × 2 2 \times 2 2×2) 的小块(包含 4 个 Patch)被合并为一个新的 Patch。
  • 合并后:
    • 高度方向的 Patch 数量从 ( H H H) 减少到 ( H / 2 H/2 H/2)。
    • 宽度方向的 Patch 数量从 ( W W W) 减少到 ( W / 2 W/2 W/2)。
  • 因此,新的特征图空间分辨率为 ( H / 2 × W / 2 H/2 \times W/2 H/2×W/2)。
  • 从数量上看,总的 Patch 数量从 ( H × W H \times W H×W) 变为:
    ( H / 2 ) × ( W / 2 ) = H × W 4 (H/2) \times (W/2) = \frac{H \times W}{4} (H/2)×(W/2)=4H×W
    • 空间分辨率的线性维度(高度和宽度)变为原来的 ( 1 / 2 1/2 1/2)。
    • 空间面积(总 Patch 数)变为原来的 ( 1 / 4 1/4 1/4)。
4. 直观理解:类比图像缩放
  • 可以把 Patch Merging 想象成图像的下采样过程。例如:
    • 如果输入是一个 ( 4 × 4 4 \times 4 4×4) 的特征图(16 个 Patch),划分为 ( 2 × 2 2 \times 2 2×2) 的小块后:
      • 总共有 ( 4 / 2 × 4 / 2 = 2 × 2 = 4 4 / 2 \times 4 / 2 = 2 \times 2 = 4 4/2×4/2=2×2=4) 个小块。
      • 每个小块包含 4 个 Patch,合并为 1 个新 Patch。
      • 输出特征图变为 ( 2 × 2 2 \times 2 2×2)。
    • 高度和宽度都从 4 变为 2,即 ( 1 / 2 1/2 1/2)。
  • 这种合并本质上是对空间维度的一次“池化”操作,只不过它不直接取平均或最大值,而是通过拼接和线性变换保留信息。
5. 数学上的验证
  • 输入特征图:( H × W × C H \times W \times C H×W×C)。
  • 分组后:
    • 高度方向有 ( H / 2 H/2 H/2) 个 ( 2 × 2 2 \times 2 2×2) 小块。
    • 宽度方向有 ( W / 2 W/2 W/2) 个 ( 2 × 2 2 \times 2 2×2) 小块。
    • 总小块数为 ( H / 2 × W / 2 H/2 \times W/2 H/2×W/2)。
  • 每个小块合并为 1 个新 Patch,因此输出特征图的空间分辨率为 ( H / 2 × W / 2 H/2 \times W/2 H/2×W/2)。

为什么说“空间分辨率变为 (1/2)”?

  • 在计算机视觉中,“空间分辨率”通常指特征图的高度和宽度(线性维度),而不是总面积。
  • Patch Merging 使 ( H H H) 变为 ( H / 2 H/2 H/2),( W W W) 变为 ( W / 2 W/2 W/2),因此每个维度的分辨率变为原来的 ( 1 / 2 1/2 1/2)。
  • 如果从面积(总 Patch 数)的角度看,分辨率变为 ( 1 / 4 1/4 1/4)(因为 ( H × W H \times W H×W) 变为 ( H × W 4 \frac{H \times W}{4} 4H×W)),但习惯上我们关注线性维度的变化,因此说是“空间分辨率变为 ( 1 / 2 1/2 1/2)”。

与特征维度的关系

  • 空间分辨率减半的同时,Patch Merging 将每个 ( 2 × 2 2 \times 2 2×2) 小块的 4 个 ( C C C) 维 Patch 拼接为一个 ( 4 C 4C 4C) 维向量,再通过线性层降维为 ( 2 C 2C 2C)。
  • 输出特征图变为 ( R H / 2 × W / 2 × 2 C \mathbb{R}^{H/2 \times W/2 \times 2C} RH/2×W/2×2C),空间变小但特征维度增加,保留了信息容量。

总结

空间分辨率变为 ( 1 / 2 1/2 1/2) 是因为 Patch Merging 将 ( 2 × 2 2 \times 2 2×2) 的相邻 Patch 合并为一个新 Patch,导致高度和宽度方向的 Patch 数量各减少一半(从 ( H × W H \times W H×W) 到 ( H / 2 × W / 2 H/2 \times W/2 H/2×W/2))。这种设计模仿了CNN的下采样,实现了层次化特征提取的核心步骤。

详解Figure 1

在这里插入图片描述
Figure 1 展示了 Swin Transformer 和 ViT(Vision Transformer)的架构对比,并标注了 ( 4 × 4\times 4×)、( 8 × 8\times 8×)、( 16 × 16\times 16×)。这些符号在图中表示特征图空间分辨率相对于原始输入图像的缩减倍数。下面将详细解释这些符号的含义,并结合 Swin Transformer 和 ViT 的架构特点进行分析。


Figure 1 的背景

Figure 1 包含两部分:

  • (a) Swin Transformer (ours):展示了 Swin Transformer 的层次化架构,通过逐步合并图像块(Patch Merging)构建多尺度特征图,并标注了 ( 4 × 4\times 4×)、( 8 × 8\times 8×)、( 16 × 16\times 16×)。
  • (b) ViT:展示了传统 Vision Transformer 的架构,特征图保持单一低分辨率(标注为 ( 16 × 16\times 16×))。

这些标注与特征图的空间分辨率直接相关,表示特征图的高度和宽度相对于原始输入图像的缩减倍数。


( 4 × 4\times 4×)、( 8 × 8\times 8×)、( 16 × 16\times 16×) 的含义

1. Swin Transformer (Figure 1a)

Swin Transformer 采用层次化设计,通过 Patch Partition 和 Patch Merging 逐步降低空间分辨率。假设输入图像尺寸为 ( H × W H \times W H×W)(例如 ( 224 × 224 224 \times 224 224×224),ImageNet 标准尺寸):

  • 初始 Patch Partition

    • 输入图像被划分为 ( 4 × 4 4 \times 4 4×4) 的 Patch(每个 Patch 包含 16 个像素)。
    • 空间分辨率从 ( H × W H \times W H×W) 变为 ( H / 4 × W / 4 H/4 \times W/4 H/4×W/4)。
    • 例如,( 224 × 224 224 \times 224 224×224) 变为 ( 56 × 56 56 \times 56 56×56)。
    • 这对应于图中的 ( 4 × 4\times 4×) 标注,表示每个维度(高度和宽度)缩小了 4 倍。
  • Stage 1 到 Stage 4 的 Patch Merging

    • Swin Transformer 包含 4 个阶段(Stage),每个阶段末尾执行 Patch Merging,将 ( 2 × 2 2 \times 2 2×2) 的相邻 Patch 合并为 1 个 Patch,空间分辨率减半。
    • Stage 1 输出:分辨率为 ($H/4 \times W/4)(例如 ( 56 × 56 56 \times 56 56×56)),标注为 ( 4 × 4\times 4×)。
    • Stage 2 输出(经过 1 次 Patch Merging):分辨率为 ( H / 8 × W / 8 H/8 \times W/8 H/8×W/8)(例如 ( 28 × 28 28 \times 28 28×28)),标注为 ( 8 × 8\times 8×)(高度和宽度各缩小 8 倍)。
    • Stage 3 输出(经过 2 次 Patch Merging):分辨率为 ( H / 16 × W / 16 H/16 \times W/16 H/16×W/16)(例如 ( 14 × 14 14 \times 14 14×14)),标注为 ( 16 × 16\times 16×)(高度和宽度各缩小 16 倍)。
    • Stage 4 输出(经过 3 次 Patch Merging):分辨率为 ( H / 32 × W / 32 H/32 \times W/32 H/32×W/32)(例如 ( 7 × 7 7 \times 7 7×7)),可能未在图中标注,但对应 ( 32 × 32\times 32×)。

总结

  • ( 4 × 4\times 4×):特征图分辨率为 ( H / 4 × W / 4 H/4 \times W/4 H/4×W/4)(Stage 1)。
  • ( 8 × 8\times 8×):特征图分辨率为 ( H / 8 × W / 8 H/8 \times W/8 H/8×W/8)(Stage 2)。
  • ( 16 × 16\times 16×):特征图分辨率为 ( H / 16 × W / 16 H/16 \times W/16 H/16×W/16)(Stage 3)。

这些倍数表示空间分辨率在每个阶段的缩减程度,体现了 Swin Transformer 的层次化设计,能够生成多尺度特征图(类似于 CNN 的特征金字塔),适用于分类、检测和分割等任务。

2. ViT (Figure 1b)

ViT(Vision Transformer)采用非层次化的设计,特征图分辨率在整个网络中保持不变:

  • 初始 Patch Partition

    • ViT 通常将输入图像划分为较大的 Patch,例如 ( 16 × 16 16 \times 16 16×16)(而不是 Swin Transformer 的 ( 4 × 4 4 \times 4 4×4))。
    • 空间分辨率从 ( H × W H \times W H×W) 直接变为 ( H / 16 × W / 16 H/16 \times W/16 H/16×W/16)。
    • 例如,( 224 × 224 224 \times 224 224×224) 变为 ( 14 × 14 14 \times 14 14×14)。
    • 这对应于图中的 ( 16 × 16\times 16×) 标注,表示每个维度缩小了 16 倍。
  • 后续层

    • ViT 不执行 Patch Merging,特征图分辨率始终保持 ( H / 16 × W / 16 H/16 \times W/16 H/16×W/16)(例如 ( 14 × 14 14 \times 14 14×14))。
    • 因此,图中所有特征图都标注为 ( 16 × 16\times 16×),表示空间分辨率没有变化。

总结

  • ( 16 × 16\times 16×):ViT 的特征图分辨率为 ( H / 16 × W / 16 H/16 \times W/16 H/16×W/16),从头到尾保持不变。

为什么有这些倍数?

这些倍数反映了空间分辨率的变化,与 Patch 分割和合并的步幅直接相关:

  • Swin Transformer

    • 初始 ( 4 × 4 4 \times 4 4×4) Patch 分割导致 ( 4 × 4\times 4×) 缩减。
    • 每次 Patch Merging 使分辨率再减半(步幅 2),因此依次为 ( 4 × → 8 × → 16 × → 32 × 4\times \to 8\times \to 16\times \to 32\times 4×8×16×32×)。
    • 这种逐步缩减形成了层次化特征表示,适合多尺度任务。
  • ViT

    • 直接使用 ( 16 × 16 16 \times 16 16×16) Patch 分割,导致 ( 16 × 16\times 16×) 缩减。
    • 没有层次化设计,分辨率固定,适合图像分类但不利于密集预测任务(如检测和分割)。

为什么 Swin Transformer 使用这种设计?

  • Swin Transformer 的层次化设计(( 4 × → 8 × → 16 × 4\times \to 8\times \to 16\times 4×8×16×)):

    • 通过逐步降低分辨率,生成多尺度特征图,类似于 CNN 的特征金字塔。
    • 窗口自注意力(W-MSA)限制在局部窗口内,计算复杂度从 ( O ( ( h w ) 2 ) O((hw)^2) O((hw)2)) 降为 ( O ( h w ) O(hw) O(hw)),更适合高分辨率图像。
    • 适用于多种任务:图像分类、目标检测、语义分割。
  • ViT 的单一分辨率设计(固定 ( 16 × 16\times 16×)):

    • 全局自注意力导致计算复杂度高(( O ( ( h w ) 2 ) O((hw)^2) O((hw)2))),难以处理高分辨率图像。
    • 缺乏多尺度特征,限制了其在密集预测任务中的表现。

移位窗口自注意力详细解释

移位窗口自注意力(Shifted Window Multi-Head Self-Attention, SW-MSA)是 Swin Transformer 的核心创新之一,它通过在连续的 Swin Transformer Block 之间移位窗口位置,增强了跨窗口的特征交互,同时保持计算效率。以下将详细解析移位窗口的实现原理、具体操作步骤、以及其在 Swin Transformer 中的作用。


移位窗口自注意力的核心思想

Swin Transformer 的窗口自注意力(W-MSA)将特征图划分为非重叠的局部窗口(例如 ( M × M M \times M M×M),默认 ( M = 7 M=7 M=7)),并在每个窗口内独立计算多头自注意力。这种设计显著降低了计算复杂度(从全局自注意力的 ( O ( ( h w ) 2 ) O((hw)^2) O((hw)2)) 降为 ( O ( h w ) O(hw) O(hw))),但也带来了一个问题:窗口之间缺乏直接的特征交互,导致模型对全局信息的建模能力受限。

为了解决这个问题,Swin Transformer 引入了 移位窗口(Shifted Window) 机制:

  • 在连续的 Swin Transformer Block 之间,窗口位置会发生移位(Shift),使得原本属于不同窗口的 Patch 可以在下一层中属于同一个窗口,从而实现跨窗口的特征交互。
  • 这种移位操作在保持局部计算效率的同时,间接实现了全局连接的效果。

移位窗口的具体实现

Swin Transformer Block 的计算分为两步:W-MSA(常规窗口自注意力)和 SW-MSA(移位窗口自注意力)。移位窗口的实现主要发生在 SW-MSA 步骤中。以下是详细的实现步骤:

1. 常规窗口划分(W-MSA)
  • 输入特征图:假设输入特征图为 ( R H × W × C \mathbb{R}^{H \times W \times C} RH×W×C),其中 ( H × W H \times W H×W) 是空间分辨率(Patch 数量),( C C C) 是特征维度。
  • 窗口划分
    • 将特征图划分为非重叠的 ( M × M M \times M M×M) 窗口(例如 ( M = 7 M=7 M=7))。
    • 假设 ( H H H) 和 ( W W W) 是 ( M M M) 的整数倍(如果不是,会通过填充补齐),则窗口数量为:
      窗口数量 = H M × W M \text{窗口数量} = \frac{H}{M} \times \frac{W}{M} 窗口数量=MH×MW
    • 每个窗口包含 ( M × M M \times M M×M) 个 Patch(例如 ( 7 × 7 = 49 7 \times 7 = 49 7×7=49) 个 Patch)。
  • 自注意力计算
    • 在每个窗口内独立计算多头自注意力(W-MSA),公式为:
      z ^ l = W-MSA ( LN ( z l − 1 ) ) + z l − 1 \hat{\mathbf{z}}^l = \text{W-MSA}(\text{LN}(\mathbf{z}^{l-1})) + \mathbf{z}^{l-1} z^l=W-MSA(LN(zl1))+zl1
      z l = MLP ( LN ( z ^ l ) ) + z ^ l \mathbf{z}^l = \text{MLP}(\text{LN}(\hat{\mathbf{z}}^l)) + \hat{\mathbf{z}}^l zl=MLP(LN(z^l))+z^l
    • 每个窗口内的自注意力只关注窗口内的 Patch,窗口之间没有直接交互。
2. 移位窗口划分(SW-MSA)

在下一层(第 ( l + 1 l+1 l+1) 层),Swin Transformer 将窗口位置移位,然后再计算自注意力(SW-MSA)。具体步骤如下:

  • 窗口移位

    • 将整个特征图向 左上角 移位 (( M / 2 , M / 2 ) M/2, M/2) M/2,M/2)) 个 Patch。
    • 例如,若 ( M = 7 M=7 M=7),则移位 ( ( M / 2 , M / 2 ) = ( 3 , 3 ) (M/2, M/2) = (3, 3) (M/2,M/2)=(3,3)),即向上和向左各移 3 个 Patch。
    • 移位后,特征图的边界会“溢出”。为了保持特征图大小不变,Swin Transformer 采用 循环移位(Cyclic Shift)
      • 溢出的部分会被循环移动到特征图的另一侧。例如,向左移 3 个 Patch 后,最左边的 3 列会被移动到最右边;向上移 3 个 Patch 后,最上面的 3 行会被移动到最底部。
    • 循环移位后的特征图仍为 ( H × W × C H \times W \times C H×W×C),但 Patch 的空间位置发生了变化。
  • 新的窗口划分

    • 在移位后的特征图上,重新划分为 ( M × M M \times M M×M) 的非重叠窗口。
    • 由于移位,新的窗口会包含一些原本属于不同窗口的 Patch。例如:
      • 在 W-MSA 中,Patch A 和 Patch B 可能属于不同的窗口,无法直接交互。
      • 在 SW-MSA 中,经过移位后,Patch A 和 Patch B 可能被划分到同一个窗口内,从而可以通过自注意力机制交互。
  • 问题:不规则窗口

    • 移位后,新的窗口划分可能导致一些窗口跨越特征图的边界(由于循环移位,边界处的 Patch 可能来自特征图的另一侧)。
    • 例如,一个新窗口可能包含:
      • 左上角的 Patch(来自原始特征图的右下角)。
      • 右下角的 Patch(来自原始特征图的左上角)。
    • 这种跨越会导致不自然的交互(例如,原本空间上很远的 Patch 被错误地放在同一个窗口内)。
  • 解决方案:掩码机制(Masked Attention)

    • 为了避免不自然的跨边界交互,Swin Transformer 引入了掩码机制:
      • 在计算自注意力时,添加一个注意力掩码(Attention Mask),禁止某些 Patch 之间的交互。
      • 具体来说,对于一个窗口内的 Patch,如果它们在原始特征图(移位前)不属于同一个 ( M × M M \times M M×M) 区域,则在注意力计算中将它们之间的权重设为负无穷(即在 SoftMax 后权重为 0)。
    • 这样,移位后的窗口虽然包含了循环移位带来的 Patch,但通过掩码机制,仍然只计算原始空间上邻近的 Patch 之间的注意力。
  • 自注意力计算

    • 在移位后的窗口上计算多头自注意力(SW-MSA),公式为:
      z ^ l + 1 = SW-MSA ( LN ( z l ) ) + z l \hat{\mathbf{z}}^{l+1} = \text{SW-MSA}(\text{LN}(\mathbf{z}^l)) + \mathbf{z}^l z^l+1=SW-MSA(LN(zl))+zl
      z l + 1 = MLP ( LN ( z ^ l + 1 ) ) + z ^ l + 1 \mathbf{z}^{l+1} = \text{MLP}(\text{LN}(\hat{\mathbf{z}}^{l+1})) + \hat{\mathbf{z}}^{l+1} zl+1=MLP(LN(z^l+1))+z^l+1
    • 每个窗口内的自注意力仍然是局部的,但由于窗口位置的移位,Patch 之间的交互范围变大了。
3. 高效实现:循环移位与批处理
  • 循环移位
    • 循环移位操作可以通过简单的张量重排实现,无需额外的计算开销。
    • 例如,PyTorch 中可以通过 torch.roll 操作实现:
      shifted_feature = torch.roll(feature, shifts=(-M//2, -M//2), dims=(1, 2))
      
  • 批处理
    • 移位后,窗口数量可能增加(因为边界处可能产生较小的窗口,例如 ( M × 3 M \times 3 M×3) 或 ( 3 × 3 3 \times 3 3×3))。
    • Swin Transformer 通过填充(Padding)将所有窗口统一为 ( M × M M \times M M×M),然后使用批处理(Batch Computation)并行计算所有窗口的自注意力。
    • 掩码机制确保填充部分的 Patch 不影响注意力计算。

移位窗口的作用

  1. 跨窗口连接

    • 移位窗口机制使得原本属于不同窗口的 Patch 可以在下一层中属于同一个窗口,从而实现跨窗口的特征交互。
    • 例如,Patch A 和 Patch B 在 W-MSA 中无法交互,但在 SW-MSA 中可能被划分到同一个窗口内,间接建立了连接。
  2. 全局建模能力

    • 通过在多层之间交替使用 W-MSA 和 SW-MSA,Swin Transformer 可以在多层后实现全局连接。
    • 理论上,经过足够多的层,每个 Patch 都可以通过移位窗口与特征图中的任意其他 Patch 建立间接联系,类似于全局自注意力,但计算复杂度仍然是线性的。
  3. 保持计算效率

    • 移位窗口并没有增加自注意力的计算范围(仍然是 ( M × M M \times M M×M) 窗口内计算),因此复杂度保持为 (O(hw))。
    • 相比 ViT 的全局自注意力(( O ( ( h w ) 2 ) O((hw)^2) O((hw)2))),Swin Transformer 的设计更高效。

移位窗口的数学表示

在 SW-MSA 中,自注意力计算与 W-MSA 类似,但需要考虑移位和掩码:
Attention ( Q , K , V ) = SoftMax ( Q K T d + B + M ) V \text{Attention}(Q, K, V) = \text{SoftMax}\left(\frac{QK^T}{\sqrt{d}} + B + M\right)V Attention(Q,K,V)=SoftMax(d QKT+B+M)V

  • ( Q , K , V Q, K, V Q,K,V):查询、键、值矩阵,基于移位后的窗口计算。
  • ( B B B):相对位置偏置(Relative Position Bias),增强空间信息。
  • ( M M M):注意力掩码(Attention Mask),对于不应该交互的 Patch 对,( M i , j = − ∞ M_{i,j} = -\infty Mi,j=),确保 SoftMax 后权重为 0。

移位窗口的直观示例

假设特征图为 ( 8 × 8 8 \times 8 8×8)(64 个 Patch),窗口大小 ( M = 4 M=4 M=4):

  • W-MSA

    • 划分为 ( 8 / 4 × 8 / 4 = 2 × 2 = 4 8/4 \times 8/4 = 2 \times 2 = 4 8/4×8/4=2×2=4) 个窗口,每个窗口 ( 4 × 4 4 \times 4 4×4)。
    • 窗口 1:左上角 ((0,0)) 到 ((3,3))。
    • 窗口 2:右上角 ((0,4)) 到 ((3,7)),等等。
    • 每个窗口内的 16 个 Patch 独立计算自注意力。
  • SW-MSA

    • 移位 ( ( M / 2 , M / 2 ) = ( 2 , 2 ) (M/2, M/2) = (2, 2) (M/2,M/2)=(2,2)),特征图向左上移动 2 个 Patch。
    • 移位后,重新划分为 ( 4 × 4 4 \times 4 4×4) 窗口:
      • 新窗口 1:可能包含原始窗口 1、2、3、4 的部分 Patch。
    • 掩码机制确保新窗口内的 Patch 只与原始空间上邻近的 Patch 交互。

总结

移位窗口自注意力(SW-MSA)通过以下步骤实现:

  1. 循环移位:将特征图向左上移位 ((M/2, M/2)),溢出部分循环到另一侧。
  2. 窗口划分:在移位后的特征图上重新划分为 (M \times M) 窗口。
  3. 掩码机制:通过注意力掩码避免不自然的跨边界交互。
  4. 自注意力计算:在每个窗口内计算多头自注意力。

移位窗口机制在保持计算效率(线性复杂度)的同时,增强了跨窗口的特征交互能力,间接实现了全局建模的效果。这是 Swin Transformer 能够在多种视觉任务中表现出色的关键原因之一。

向左上角移位,那起始位置在哪?右下角吗?

关于 Swin Transformer 中移位窗口(Shifted Window)机制的问题,“向左上角移位,那起始位置在哪?右下角吗?”详细解释移位窗口的实现细节,特别是移位的方向和起始位置的含义。


移位窗口的核心操作

在这里插入图片描述

在 Swin Transformer 中,移位窗口(Shifted Window)机制通过将整个特征图向 左上角 移位 ( ( M / 2 , M / 2 ) (M/2, M/2) (M/2,M/2)) 个 Patch 来实现跨窗口的特征交互。以下是具体分析:

1. 移位方向:向左上角
  • “向左上角移位” 的含义:

    • “向左”:特征图在宽度方向(水平方向,( W W W) 轴)向左移动。
    • “向上”:特征图在高度方向(垂直方向,( H H H) 轴)向上移动。
    • 移位的距离为 ( ( M / 2 , M / 2 ) (M/2, M/2) (M/2,M/2))。例如,若窗口大小 ( M = 7 M=7 M=7),则 ( ( M / 2 , M / 2 ) = ( 3 , 3 ) (M/2, M/2) = (3, 3) (M/2,M/2)=(3,3)),即向左和向上各移动 3 个 Patch。
  • 在二维特征图的坐标系中:

    • 假设特征图为 ( R H × W × C \mathbb{R}^{H \times W \times C} RH×W×C),坐标 ( ( i , j ) (i, j) (i,j)) 表示第 ( i i i) 行、第 ( j j j) 列的 Patch。
    • 坐标 ( ( 0 , 0 ) (0, 0) (0,0)) 通常定义为特征图的 左上角(这是图像处理中的常见约定:左上角为原点,( i i i) 向下增加,( j j j) 向右增加)。
  • 移位操作

    • 向左移 3 个 Patch:( j → j − 3 j \to j - 3 jj3)。
    • 向上移 3 个 Patch:( i → i − 3 i \to i - 3 ii3)。
    • 因此,坐标 ( ( i , j ) (i, j) (i,j)) 变为 ( ( i − 3 , j − 3 ) (i - 3, j - 3) (i3,j3))。
2. 起始位置和循环移位

“向左上角移位,那起始位置在哪?右下角吗?”这个问题涉及到移位的实现方式和边界处理。

  • 起始位置

    • 这里的“起始位置”并不是指某个特定的 Patch 作为起点,而是指整个特征图作为一个整体进行移位。
    • 在图像处理中,移位操作通常是对整个张量(特征图)进行操作,而不是从某个特定位置开始。
    • 因此,没有一个明确的“起始位置”(如右下角)。移位是从特征图的当前状态(即所有 Patch 的当前位置)开始,向左上角移动。
  • 循环移位(Cyclic Shift)

    • 由于特征图是有限的(尺寸为 ( H × W H \times W H×W)),向左上角移位会导致部分 Patch “溢出”边界。
    • 例如:
      • 向左移 3 个 Patch:最左边的 3 列(( j = 0 , 1 , 2 j = 0, 1, 2 j=0,1,2))会溢出。
      • 向上移 3 个 Patch:最上面的 3 行(( i = 0 , 1 , 2 i = 0, 1, 2 i=0,1,2))会溢出。
    • 为了保持特征图大小不变,Swin Transformer 采用 循环移位
      • 溢出的部分会被移动到特征图的另一侧。
      • 向左移 3 个 Patch:最左边的 3 列被移动到最右边(( j = W − 3 , W − 2 , W − 1 j = W-3, W-2, W-1 j=W3,W2,W1))。
      • 向上移 3 个 Patch:最上面的 3 行被移动到最底部(( i = H − 3 , H − 2 , H − 1 i = H-3, H-2, H-1 i=H3,H2,H1))。
  • 循环移位的效果

    • 循环移位后,特征图的尺寸仍然是 ( H × W × C H \times W \times C H×W×C),但每个 Patch 的位置发生了变化。
    • 例如,原本在 (( 0 , 0 0, 0 0,0)) 的 Patch(左上角)移位后可能出现在 ( ( H − 3 , W − 3 ) (H-3, W-3) (H3,W3))(接近右下角)。
3. “右下角”是否是起始位置?
  • “右下角”可能是对循环移位的直观理解:如果向左上角移位,是否意味着右下角的 Patch 会被移动到左上角?

  • 答案是:部分正确,但不完全是“起始位置”的概念

    • 循环移位确实会导致右下角的 Patch 被移动到靠近左上角的位置,但这并不是因为右下角是“起始位置”,而是因为循环移位的边界处理。
    • 具体来说:
      • 向左移 3 个 Patch:最右边的 3 列(( j = W − 3 , W − 2 , W − 1 j = W-3, W-2, W-1 j=W3,W2,W1))会被移动到最左边(( j = 0 , 1 , 2 j = 0, 1, 2 j=0,1,2))。
      • 向上移 3 个 Patch:最底部的 3 行(( i = H − 3 , H − 2 , H − 1 i = H-3, H-2, H-1 i=H3,H2,H1))会被移动到最顶部(( i = 0 , 1 , 2 i = 0, 1, 2 i=0,1,2))。
    • 因此,右下角的 Patch(例如 ( ( H − 1 , W − 1 ) (H-1, W-1) (H1,W1)))在移位后可能会出现在 (( H − 4 , W − 4 ) H-4, W-4) H4,W4)),具体位置取决于 ( H H H)、( W W W) 和移位距离。
  • 没有固定的“起始位置”

    • 移位操作是对整个特征图的全局操作,所有 Patch 同时移动 (( M / 2 , M / 2 ) M/2, M/2) M/2,M/2)) 的距离。
    • 循环移位只是为了处理边界,确保特征图大小不变,而不是从某个特定位置(如右下角)开始移动。
4. 直观示例

假设特征图为 ( 8 × 8 8 \times 8 8×8)(64 个 Patch),窗口大小 ( M = 4 M=4 M=4),移位 ( ( M / 2 , M / 2 ) = ( 2 , 2 ) (M/2, M/2) = (2, 2) (M/2,M/2)=(2,2)):

  • 原始特征图

    • 坐标 ((0, 0)) 是左上角,((7, 7)) 是右下角。
    • 特征图表示为:
      [ P 0 , 0 P 0 , 1 ⋯ P 0 , 7 P 1 , 0 P 1 , 1 ⋯ P 1 , 7 ⋮ ⋮ ⋱ ⋮ P 7 , 0 P 7 , 1 ⋯ P 7 , 7 ] \begin{bmatrix} P_{0,0} & P_{0,1} & \cdots & P_{0,7} \\ P_{1,0} & P_{1,1} & \cdots & P_{1,7} \\ \vdots & \vdots & \ddots & \vdots \\ P_{7,0} & P_{7,1} & \cdots & P_{7,7} \end{bmatrix} P0,0P1,0P7,0P0,1P1,1P7,1P0,7P1,7P7,7
  • 向左上角移位 (2, 2)

    • 向上移 2 个 Patch:第 0、1 行移出,移动到最底部(第 6、7 行)。
    • 向左移 2 个 Patch:第 0、1 列移出,移动到最右边(第 6、7 列)。
    • 移位后的特征图:
      • 原始 ((0, 0)) 移到 ((6, 6))。
      • 原始 ((2, 2)) 移到 ((0, 0))。
      • 原始 ((7, 7)) 移到 ((5, 5))。
    • 新特征图:
      [ P 2 , 2 P 2 , 3 ⋯ P 2 , 7 P 2 , 0 P 2 , 1 P 3 , 2 P 3 , 3 ⋯ P 3 , 7 P 3 , 0 P 3 , 1 ⋮ ⋮ ⋱ ⋮ ⋮ ⋮ P 1 , 2 P 1 , 3 ⋯ P 1 , 7 P 1 , 0 P 1 , 1 ] \begin{bmatrix} P_{2,2} & P_{2,3} & \cdots & P_{2,7} & P_{2,0} & P_{2,1} \\ P_{3,2} & P_{3,3} & \cdots & P_{3,7} & P_{3,0} & P_{3,1} \\ \vdots & \vdots & \ddots & \vdots & \vdots & \vdots \\ P_{1,2} & P_{1,3} & \cdots & P_{1,7} & P_{1,0} & P_{1,1} \end{bmatrix} P2,2P3,2P1,2P2,3P3,3P1,3P2,7P3,7P1,7P2,0P3,0P1,0P2,1P3,1P1,1
  • 右下角的移动

    • 右下角的 Patch(如 ((7, 7)))移位后出现在 ((5, 5)),并不是直接移动到左上角。
    • 左上角 ((0, 0)) 反而被移动到接近右下角的位置 ((6, 6))。
    • 因此,移位并不是从右下角“开始”,而是所有 Patch 同时移动,循环移位处理了边界。
5. 代码实现(参考)

在 PyTorch 中,循环移位可以通过 torch.roll 实现:

# 特征图:(batch_size, H, W, C)
feature = torch.randn(1, 8, 8, 96)  # 示例:8x8 特征图
M = 4
shift = M // 2  # 移位距离:2
# 向左上角移位 (负号表示向左和向上)
shifted_feature = torch.roll(feature, shifts=(-shift, -shift), dims=(1, 2))
  • shifts=(-2, -2):表示在高度和宽度方向上分别向左和向上移 2 个 Patch。
  • 循环移位自动处理边界,溢出的部分被移动到另一侧。

总结

  • 向左上角移位:指的是整个特征图在高度和宽度方向上分别向上和向左移动 (( M / 2 , M / 2 ) M/2, M/2) M/2,M/2)) 个 Patch。
  • 起始位置:移位操作没有固定的“起始位置”,所有 Patch 同时移动。循环移位确保特征图大小不变,溢出的部分被移动到另一侧。
  • 右下角的误解:右下角的 Patch 确实会因为循环移位而移动到靠近左上角的位置,但这不是“起始位置”的概念,而是循环移位的自然结果。
  • 实际效果:移位后,特征图的 Patch 位置重新排列,新的窗口划分会包含原本属于不同窗口的 Patch,从而实现跨窗口的特征交互。

相对位置偏置被加到注意力计算的公式

在 Swin Transformer 的移位窗口自注意力(SW-MSA)中,相对位置偏置(Relative Position Bias, 记为 ( B B B))被加到注意力计算的公式中,而不是在输入时与嵌入(Embedding)相加。这种设计有其特定的原因,涉及到 Transformer 架构的特性、相对位置编码的作用,以及 Swin Transformer 的窗口化设计。以下将详细解析为什么相对位置偏置加在注意力计算中,而不是与输入嵌入相加。


1. 相对位置偏置的作用

在 Swin Transformer 中,相对位置偏置 ( B B B) 的作用是增强模型对空间位置信息的感知能力。以下是其核心作用:

  • 引入空间信息

    • 标准的 Transformer 自注意力机制是基于内容(Content-Based)的,计算 ( Q K T Q K^T QKT) 时只依赖于 ( Q Q Q) 和 ( K K K) 的特征内容,而不考虑 Patch 之间的空间位置关系。
    • 在视觉任务中,空间位置信息非常重要(例如,两个 Patch 的相对距离会影响它们的语义关系)。Swin Transformer 通过引入相对位置偏置 ( B B B),让模型显式地考虑 Patch 之间的空间关系。
  • 窗口化设计的需求

    • Swin Transformer 将特征图划分为局部窗口(例如 ( M × M M \times M M×M)),并在每个窗口内计算自注意力。
    • 在一个窗口内,Patch 之间的相对位置(例如 ( ( i 1 , j 1 ) (i_1, j_1) (i1,j1)) 和 ( ( i 2 , j 2 ) (i_2, j_2) (i2,j2)) 的相对坐标 ( ( i 1 − i 2 , j 1 − j 2 ) (i_1 - i_2, j_1 - j_2) (i1i2,j1j2)))对注意力权重有重要影响。相对位置偏置 ( B B B) 用来建模这种空间关系。
  • 增强局部性

    • 相对位置偏置使得模型更倾向于关注空间上更接近的 Patch(因为它们的 ( B B B) 值通常会使注意力权重更高),这与视觉任务中的局部性假设(Locality Bias)一致。

2. 为什么加在注意力计算中(( Q K T d + B \frac{QK^T}{\sqrt{d}} + B d QKT+B))?

Swin Transformer 将相对位置偏置 (B) 加到注意力计算的公式中,而不是在输入时与嵌入相加,主要有以下原因:

(1) 直接影响注意力权重
  • 注意力计算的核心

    • 自注意力的计算公式为:
      Attention ( Q , K , V ) = SoftMax ( Q K T d + B ) V \text{Attention}(Q, K, V) = \text{SoftMax}\left(\frac{QK^T}{\sqrt{d}} + B\right)V Attention(Q,K,V)=SoftMax(d QKT+B)V
    • 其中,( Q K T d \frac{QK^T}{\sqrt{d}} d QKT) 表示基于内容的相似性(Content-Based Similarity),而 ( B B B) 表示基于位置的偏置(Position-Based Bias)。
    • ( B B B) 直接加到 ( Q K T d \frac{QK^T}{\sqrt{d}} d QKT) 上,意味着它会直接影响注意力权重 ( SoftMax ( ⋅ ) \text{SoftMax}(\cdot) SoftMax()) 的分布。
  • 为什么要影响注意力权重?

    • 注意力权重决定了每个 Patch 对其他 Patch 的贡献程度。
    • 如果两个 Patch 在空间上更接近,它们的相对位置偏置 ( B i , j B_{i,j} Bi,j) 通常会更大(通过学习得到),从而增加它们的注意力权重。
    • 这种方式让模型在计算注意力时同时考虑内容相似性(( Q K T Q K^T QKT))和空间关系(( B B B)),更符合视觉任务的需求。
  • 对比输入嵌入

    • 如果将位置信息加到输入嵌入(例如 ( X + PosEmbedding X + \text{PosEmbedding} X+PosEmbedding)),位置信息会通过 ( Q = X W Q Q = X W_Q Q=XWQ)、( K = X W K K = X W_K K=XWK) 间接影响 ( Q K T Q K^T QKT)。这种方式虽然也能引入位置信息,但影响是间接的,且可能会被 ( W Q W_Q WQ)、( W K W_K WK) 的线性变换削弱。
    • 直接将 ( B B B) 加到 ( Q K T d \frac{QK^T}{\sqrt{d}} d QKT) 上,可以更显式地控制注意力权重,确保位置信息在注意力计算中起到关键作用。
(2) 相对位置偏置的动态性
  • 相对位置偏置是动态学习的

    • 在 Swin Transformer 中,( B B B) 不是固定的位置嵌入,而是一个可学习的参数矩阵。
    • 具体实现:
      • 每个窗口内有 ( M × M M \times M M×M) 个 Patch(例如 ( M = 7 M=7 M=7),则有 (49) 个 Patch)。
      • 任意两个 Patch (( i 1 , j 1 i_1, j_1 i1,j1)) 和 (( i 2 , j 2 i_2, j_2 i2,j2)) 的相对位置为 (( i 1 − i 2 , j 1 − j 2 i_1 - i_2, j_1 - j_2 i1i2,j1j2))。
      • 相对位置的范围为 ( [ − M + 1 , M − 1 ] [-M+1, M-1] [M+1,M1])(例如 ([-6, 6])),因此需要一个 ( 2 M − 1 × 2 M − 1 2M-1 \times 2M-1 2M1×2M1) 的偏置矩阵 ( B ^ ∈ R ( 2 M − 1 ) × ( 2 M − 1 ) \hat{B} \in \mathbb{R}^{(2M-1) \times (2M-1)} B^R(2M1)×(2M1))(例如 ( 13 × 13 13 \times 13 13×13))。
      • ( B i , j B_{i,j} Bi,j) 是通过查询 ( B ^ \hat{B} B^) 得到的:( B i , j = B ^ [ i 1 − i 2 + M − 1 , j 1 − j 2 + M − 1 ] B_{i,j} = \hat{B}[i_1 - i_2 + M-1, j_1 - j_2 + M-1] Bi,j=B^[i1i2+M1,j1j2+M1])。
    • 这个偏置矩阵 ( B ^ \hat{B} B^) 是模型的参数,会在训练过程中学习。
  • 为什么不加到输入?

    • 如果将位置信息加到输入嵌入,位置嵌入通常是固定的(例如正弦位置编码)或预定义的(例如 ViT 中的可学习位置嵌入)。
    • Swin Transformer 的 (B) 是动态的,针对每个窗口内 Patch 的相对位置学习不同的偏置值。将其加到注意力计算中,可以更灵活地建模位置关系,而不会受到输入嵌入的限制。
(3) 窗口化设计的适配
  • 窗口内的局部性

    • Swin Transformer 的窗口化设计(W-MSA 和 SW-MSA)限制了自注意力计算的范围(仅在 ( M × M M \times M M×M) 窗口内)。
    • 在一个窗口内,Patch 之间的相对位置是有限的(范围为 ( [ − M + 1 , M − 1 ] [-M+1, M-1] [M+1,M1]))。相对位置偏置 ( B B B) 专门为这种局部窗口设计,能够高效地建模窗口内的空间关系。
  • 对比全局位置嵌入

    • 如果在输入时加位置嵌入(例如 ViT 的做法),位置嵌入需要为整个特征图的每个 Patch 分配一个全局位置编码(例如 ( H × W H \times W H×W) 个位置嵌入)。
    • 这种全局位置编码不适合 Swin Transformer 的窗口化设计,因为:
      1. 窗口化自注意力只关注局部区域,全局位置编码可能会引入不必要的复杂性。
      2. 移位窗口(SW-MSA)会改变 Patch 的空间位置,全局位置编码难以适应这种动态变化。
  • 相对位置偏置的优势

    • 相对位置偏置只依赖于 Patch 之间的相对距离,不依赖于它们的绝对位置。
    • 在移位窗口(SW-MSA)中,Patch 的绝对位置会因为循环移位而改变,但相对位置(在窗口内)保持不变。因此,相对位置偏置可以无缝适应移位操作。
(4) 计算效率
  • 参数效率

    • 如果在输入时加位置嵌入,需要为每个 Patch 分配一个位置向量(维度为 ( C C C)),总参数量为 ( H × W × C H \times W \times C H×W×C)。
    • Swin Transformer 的相对位置偏置只需要一个 ( 2 M − 1 × 2 M − 1 2M-1 \times 2M-1 2M1×2M1) 的矩阵(例如 (M=7) 时为 ( 13 × 13 = 169 13 \times 13 = 169 13×13=169) 个参数),参数量与特征图大小无关,更加高效。
  • 计算效率

    • 将 ( B B B) 加到 ( Q K T d \frac{QK^T}{\sqrt{d}} d QKT) 上,只需在注意力计算时进行一次加法操作,计算开销很小。
    • 如果加到输入嵌入,位置信息需要通过 ( Q = X W Q Q = X W_Q Q=XWQ)、( K = X W K K = X W_K K=XWK) 的线性变换传播,增加了计算负担。

3. 为什么不加到输入嵌入?

在传统的 Transformer(如 ViT)中,位置信息通常通过位置嵌入(Positional Embedding)加到输入中,例如:
X = X + PosEmbedding X = X + \text{PosEmbedding} X=X+PosEmbedding
然后再计算 ( Q = X W Q Q = X W_Q Q=XWQ)、( K = X W K K = X W_K K=XWK)。这种方式有以下局限性,不适合 Swin Transformer 的设计:

  1. 全局位置嵌入不适合窗口化

    • ViT 使用全局位置嵌入,因为它的自注意力是全局的(所有 Patch 之间都计算注意力)。
    • Swin Transformer 的自注意力限制在局部窗口内,全局位置嵌入会引入不必要的复杂性,且难以适应移位窗口的动态变化。
  2. 绝对位置 vs 相对位置

    • 全局位置嵌入通常是绝对位置编码(每个 Patch 有一个固定的位置向量)。
    • Swin Transformer 使用相对位置偏置,关注 Patch 之间的相对距离。这种相对位置信息更适合视觉任务中的局部性假设(空间上更近的 Patch 通常更相关)。
  3. 移位窗口的动态性

    • 在 SW-MSA 中,特征图会通过循环移位改变 Patch 的绝对位置。如果使用全局位置嵌入,移位后需要重新调整位置嵌入,增加了复杂性。
    • 相对位置偏置只依赖于窗口内的相对位置,不受移位操作的影响,更加灵活。
  4. 信息融合的阶段

    • 在输入时加位置嵌入,位置信息会与内容信息(( X X X))混合,经过 ( W Q W_Q WQ)、( W K W_K WK) 变换后,位置信息可能会被削弱或扭曲。
    • 直接将 ( B B B) 加到 ( Q K T d \frac{QK^T}{\sqrt{d}} d QKT) 上,位置信息在注意力计算的最后阶段引入,确保其对注意力权重的直接影响。

4. 相对位置偏置的实现细节

  • 偏置矩阵 (B)

    • ( B ∈ R M 2 × M 2 B \in \mathbb{R}^{M^2 \times M^2} BRM2×M2),表示窗口内任意两个 Patch 之间的相对位置偏置。
    • 实际实现中,Swin Transformer 使用一个较小的可学习矩阵 ( B ^ ∈ R ( 2 M − 1 ) × ( 2 M − 1 ) \hat{B} \in \mathbb{R}^{(2M-1) \times (2M-1)} B^R(2M1)×(2M1)),通过索引映射生成 ( B B B)。
    • 例如,( M = 7 M=7 M=7),( B ^ \hat{B} B^) 是一个 ( 13 × 13 13 \times 13 13×13) 的矩阵,存储所有可能的相对位置偏置值。
  • 计算过程

    • 对于窗口内的两个 Patch (( i 1 , j 1 ) i_1, j_1) i1,j1)) 和 (( i 2 , j 2 i_2, j_2 i2,j2)),相对位置为 ( ( i 1 − i 2 , j 1 − j 2 ) (i_1 - i_2, j_1 - j_2) (i1i2,j1j2))。
    • 查询 ( B ^ \hat{B} B^) 得到偏置值:( B i , j = B ^ [ i 1 − i 2 + M − 1 , j 1 − j 2 + M − 1 ] B_{i,j} = \hat{B}[i_1 - i_2 + M-1, j_1 - j_2 + M-1] Bi,j=B^[i1i2+M1,j1j2+M1])。
    • 将 ( B i , j B_{i,j} Bi,j) 加到 ( ( Q K T ) i , j d \frac{(Q K^T)_{i,j}}{\sqrt{d}} d (QKT)i,j) 上,影响注意力权重。

5. 总结

Swin Transformer 将相对位置偏置 ( B B B) 加到注意力计算中(( Q K T d + B \frac{QK^T}{\sqrt{d}} + B d QKT+B)),而不是在输入时与嵌入相加,主要原因如下:

  1. 直接影响注意力权重:( B B B) 直接调整注意力权重,显式地引入空间位置信息,符合视觉任务的局部性需求。
  2. 动态性和灵活性:相对位置偏置是可学习的,针对窗口内的相对位置设计,能够适应移位窗口的动态变化。
  3. 窗口化设计的适配:相对位置偏置更适合局部窗口自注意力,避免了全局位置嵌入的复杂性。
  4. 计算和参数效率:加到注意力计算中更高效,参数量小且计算开销低。

这种设计使得 Swin Transformer 能够在保持计算效率的同时,高效地建模空间位置信息,从而在多种视觉任务中表现出色。

Swin Transformer 的完整代码实现

下面将提供 Swin Transformer 的完整代码实现,包括模型定义、训练代码和推理代码,并详细解释每个部分。Swin Transformer 是一种层次化的视觉 Transformer,广泛用于图像分类、目标检测等任务。我们将以图像分类任务为例,使用 PyTorch 实现 Swin Transformer(Swin-T 变体),并基于 ImageNet 数据集进行训练和推理。


1. Swin Transformer 模型代码实现

我们首先实现 Swin Transformer 的模型结构,包括 Patch Partition、Patch Merging、Swin Transformer Block(W-MSA 和 SW-MSA)、以及整体架构。

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# 工具函数:将图像划分为窗口
def window_partition(x, window_size):
    """
    将特征图划分为非重叠的窗口
    Args:
        x: (B, H, W, C)
        window_size (int): 窗口大小 (M)
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

# 工具函数:将窗口还原为特征图
def window_reverse(windows, window_size, H, W):
    """
    将窗口还原为特征图
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): 窗口大小 (M)
        H, W (int): 特征图的高度和宽度
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

# 实现多头自注意力(W-MSA 和 SW-MSA)
class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # M
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # 相对位置偏置表
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads))  # (2M-1)*(2M-1), num_heads

        # 相对位置索引
        coords_h = torch.arange(window_size)
        coords_w = torch.arange(window_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, M, M
        coords_flatten = torch.flatten(coords, 1)  # 2, M*M
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, M*M, M*M
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # M*M, M*M, 2
        relative_coords[:, :, 0] += window_size - 1  # 调整到 [0, 2M-2]
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1)  # M*M, M*M
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: (num_windows*B, N, C), N 是窗口内的 Patch 数量 (M*M)
            mask: (num_windows, M*M, M*M), 注意力掩码
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # (num_windows*B, num_heads, N, C/num_heads)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (num_windows*B, num_heads, N, N)

        # 相对位置偏置
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size * self.window_size, self.window_size * self.window_size, -1)  # M*M, M*M, num_heads
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # num_heads, M*M, M*M
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

# Swin Transformer Block
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(
            dim, window_size=window_size, num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = nn.Identity() if drop_path == 0 else nn.Dropout(drop_path)
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(drop)
        )

        # 如果需要移位,计算注意力掩码
        if self.shift_size > 0:
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # num_windows, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # 循环移位
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # 窗口划分
        x_windows = window_partition(shifted_x, self.window_size)  # num_windows*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        # W-MSA 或 SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        # 窗口还原
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # 逆移位
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

# Patch Merging
class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        # 合并 2x2 的 Patch
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B (H/2*W/2) 4*C

        x = self.norm(x)
        x = self.reduction(x)  # B (H/2*W/2) 2*C

        return x

# Swin Transformer 整体架构
class SwinTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

        # Patch Partition
        self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # 构建每个 Stage
        self.layers = nn.ModuleList()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # Drop Path Rate
        for i_layer in range(self.num_layers):
            layer = nn.ModuleList([
                SwinTransformerBlock(
                    dim=int(embed_dim * 2 ** i_layer),
                    input_resolution=(img_size // (4 * 2 ** i_layer), img_size // (4 * 2 ** i_layer)),
                    num_heads=num_heads[i_layer],
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[sum(depths[:i_layer]) + i],
                )
                for i in range(depths[i_layer])
            ])
            self.layers.append(layer)

            # Patch Merging(最后一个 Stage 不需要)
            if i_layer < self.num_layers - 1:
                self.layers.append(PatchMerging(
                    input_resolution=(img_size // (4 * 2 ** i_layer), img_size // (4 * 2 ** i_layer)),
                    dim=int(embed_dim * 2 ** i_layer)
                ))

        # 分类头
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # Patch Partition
        x = self.patch_embed(x)  # (B, embed_dim, H/4, W/4)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # (B, H/4*W/4, embed_dim)
        x = self.pos_drop(x)

        # 逐层处理
        for i, layer in enumerate(self.layers):
            if isinstance(layer, PatchMerging):
                x = layer(x)
                _, L, _ = x.shape
                H, W = int(math.sqrt(L)), int(math.sqrt(L))
            else:
                for blk in layer:
                    x = blk(x)

        # 分类
        x = self.norm(x)
        x = self.avgpool(x.transpose(1, 2))  # (B, C, 1)
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

# 实例化 Swin-T
def swin_tiny(**kwargs):
    model = SwinTransformer(
        embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
        window_size=7, mlp_ratio=4., drop_path_rate=0.1, **kwargs)
    return model

2. 训练代码

以下是基于 PyTorch 的训练代码,假设使用 ImageNet 数据集(需要自行准备数据)。我们将使用 torchvision 提供的预处理和数据加载功能。

import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

# 数据预处理
def get_data_loaders(batch_size=64, num_workers=4):
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # 假设 ImageNet 数据集路径为 'path_to_imagenet'
    train_dataset = datasets.ImageFolder('path_to_imagenet/train', transform=transform_train)
    test_dataset = datasets.ImageFolder('path_to_imagenet/val', transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

# 训练函数
def train(model, train_loader, criterion, optimizer, device, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {running_loss / (batch_idx + 1):.3f}, '
                  f'Acc: {100. * correct / total:.2f}%')

# 验证函数
def validate(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100. * correct / total
    print(f'Test Loss: {test_loss / len(test_loader):.3f}, Test Acc: {acc:.2f}%')
    return acc

# 主训练循环
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = swin_tiny(num_classes=1000).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.05)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

    train_loader, test_loader = get_data_loaders(batch_size=64)

    best_acc = 0.0
    for epoch in range(100):
        train(model, train_loader, criterion, optimizer, device, epoch)
        acc = validate(model, test_loader, criterion, device)
        scheduler.step()

        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), 'swin_tiny_best.pth')
            print(f'Saved best model with acc: {best_acc:.2f}%')

if __name__ == "__main__":
    main()

3. 推理代码

以下是推理代码,用于加载训练好的模型并对单张图像进行分类。

import torch
from PIL import Image
import torchvision.transforms as transforms

# 加载模型
def load_model(model_path, device):
    model = swin_tiny(num_classes=1000).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model

# 图像预处理
def preprocess_image(image_path):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)  # (1, 3, 224, 224)
    return image

# 推理
def infer(model, image_path, device):
    image = preprocess_image(image_path).to(device)
    with torch.no_grad():
        output = model(image)
        _, predicted = output.max(1)
        prob = torch.softmax(output, dim=1)[0, predicted].item()
    return predicted.item(), prob

# 主函数
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = 'swin_tiny_best.pth'
    image_path = 'path_to_image.jpg'

    model = load_model(model_path, device)
    class_idx, prob = infer(model, image_path, device)

    # 假设有 ImageNet 类别标签文件
    with open('imagenet_classes.txt', 'r') as f:
        classes = [line.strip() for line in f.readlines()]
    
    print(f'Predicted class: {classes[class_idx]}, Probability: {prob:.3f}')

if __name__ == "__main__":
    main()

4. 代码详细解释

4.1 模型代码(Swin Transformer)
  • WindowAttention

    • 实现了窗口化的多头自注意力(W-MSA 和 SW-MSA)。
    • 包含相对位置偏置(Relative Position Bias),通过一个可学习的偏置表 (B) 引入空间信息。
    • 支持注意力掩码(用于 SW-MSA),避免移位后不自然的跨边界交互。
  • SwinTransformerBlock

    • 包含两个子模块:窗口自注意力(W-MSA 或 SW-MSA)和 MLP。
    • 实现循环移位(torch.roll)和注意力掩码,用于 SW-MSA。
    • 交替使用 W-MSA 和 SW-MSA(通过 shift_size 控制)。
  • PatchMerging

    • 实现 ( 2 × 2 2 \times 2 2×2) Patch 的合并,将空间分辨率减半,特征维度从 (C) 变为 (2C)。
    • 通过拼接和线性降维完成。
  • SwinTransformer

    • 整体架构,包含 Patch Partition(初始分块)、多个 Stage(每个 Stage 包含若干 Block 和 Patch Merging)、以及分类头。
    • 参数配置(如 depthsnum_heads)对应 Swin-T 变体。
4.2 训练代码
  • 数据加载

    • 使用 torchvisionImageFolder 加载 ImageNet 数据集。
    • 应用数据增强(随机裁剪、翻转)和标准化(ImageNet 均值和方差)。
  • 训练函数

    • 标准 PyTorch 训练流程:前向传播、计算损失、反向传播、优化。
    • 使用 AdamW 优化器和余弦退火学习率调度(CosineAnnealingLR)。
  • 验证函数

    • 在验证集上评估模型,计算损失和准确率。
    • 保存最佳模型(基于验证集准确率)。
4.3 推理代码
  • 模型加载

    • 加载训练好的模型权重,设置为评估模式(model.eval())。
  • 图像预处理

    • 调整图像大小、裁剪、标准化,与训练时一致。
  • 推理

    • 前向传播,获取预测类别和置信度。
    • 使用 ImageNet 类别标签文件将类别索引转换为类别名称。

5. 使用说明

  1. 环境准备

    • 安装 PyTorch 和 torchvision:
      pip install torch torchvision
      
    • 准备 ImageNet 数据集(或替换为其他数据集)。
  2. 训练

    • 修改 path_to_imagenet 为你的 ImageNet 数据集路径。
    • 运行训练代码:
      python train_swin.py
      
  3. 推理

    • 准备 imagenet_classes.txt(ImageNet 类别标签文件)。
    • 修改 path_to_image.jpg 为测试图像路径。
    • 运行推理代码:
      python infer_swin.py
      

6. 扩展与优化

  • 多 GPU 训练
    • 使用 torch.nn.DataParalleltorch.distributed 实现多 GPU 训练。
  • 预训练模型
    • 官方 Swin Transformer 提供预训练权重,可以通过 timm 库加载:
      import timm
      model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True)
      
  • 下游任务
    • 目标检测:将 Swin Transformer 作为主干网络,接入 Faster R-CNN 或 Mask R-CNN。
    • 语义分割:接入 UPerNet 或其他分割头。

希望这个实现和解释对你有帮助!

后记

2025年3月21日16点23分于上海,在grok 3大模型辅助下完成。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值