【SwinTransformer源码阅读二】Window Attention和Shifted Window Attention部分

本文详细解析了SwinTransformer中Window-based Multi-Head Self-Attention(W-MSA)和Shifted W-MSA(SW-MSA)的实现。W-MSA关注单个窗口内的信息,而SW-MSA通过窗口滑动和mask操作引入跨窗口信息。关键在于相对位置偏置表的构建和使用,以及在SW-MSA中通过特征图循环移位和mask注意力机制实现信息交互。SwinTransformer通过连续的W-MSA和SW-MSA层,逐渐增强模型对全局特征的捕获能力。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

先放一下SwinTransformer的整体结构,图片源于原论文,可以发现,在Transformer的Block中 W-MSA(Window based multi-head self attention) 和 SW-MSA是关键组成部分。W-MSA出现在某阶段的奇数层,SW-MSA出现在某阶段的偶数层,W-MSA考虑的是单个窗口的信息,SW-MSA考虑的是不同窗口间的信息。

在这里插入图片描述

虽然从网络架构图里看,W-MSA和SW-MSA为两个不同的模块,但是在代码层面,两者是同一个代码片段,只是在计算SW-MSA时候,在计算完W-MSA后,然后通过代码进行滑动窗口,即cyclic shift操作,多计算了一个mask的操作。下面将针对代码进行分析。

W-MSA的代码

【注意】注释第一句话:Window based multi-head self attention (W-MSA) module with relative position bias.It supports both of shifted and non-shifted window.
代码注释中的中文,是以配置文件中 swin-tiny 相关的量 来进行注释的。

#窗口注意力
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim#96*(2^layer_index 0,1,2,3...)
        self.window_size = window_size  # Wh, Ww (7,7)
        self.num_heads = num_heads#[3, 6, 12, 24]
        head_dim = dim // num_heads#(96//3=32,96*2^1 // 6=32,...)
        self.scale = qk_scale or head_dim ** -0.5#default:head_dim ** -0.5

        # define a parameter table of relative position bias
        #定义相对位置偏置表格
        #[(2*7-1)*(2*7-1),3]
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        #得到一对在窗口中的相对位置索引
        coords_h = torch.arange(self.window_size[0])#[0,1,2,3,4,5,6]
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        #让相对坐标从0开始
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        #relative_coords[:, :, 0] * (2*7-1)
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        #为位置偏置表中索引值,位置偏移表(13*13,nHeads)索引0-168
        #索引值为 (49,49) 值在0-168对应位置偏移表的索引 
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        #dim*(dim*3)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        #attn_drop=0.0
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        #初始化相对位置偏置值表(截断正态分布)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)
    #模块的前向传播
    def forward(self, x,
### W-MSA 模块的工作原理与架构 W-MSA(Window-based Multi-head Self Attention)是一种基于窗口的多头自注意力机制,主要用于减少全局计算复杂度并提高局部特征提取能力。以下是关于其工作原理架构的具体描述: #### 1. **基本概念** W-MSA 将输入特征图划分为多个不重叠的小窗口,在这些固定大小的窗口内部执行 multi-head self-attention (MHSA)[^1]。通过这种方式,可以显著降低计算量,因为 MHSA 的操作仅限于较小范围内的像素。 #### 2. **划分方式** 假设输入特征图为 \( H \times W \),将其分割成若干个尺寸为 \( M \times M \) 的窗口,则每个窗口中的 tokens 形成了独立子集用于后续处理过程。这种策略不仅简化了运算流程还增强了模型捕捉细粒度模式的能力。 #### 3. **具体实现细节** 在实际应用过程中,对于每一个单独定义好的窗口区域,按照标准形式化表达式来构建 key, query value 向量集合,并依据 scaled dot-product attention 方法完成最终输出向量矩阵生成步骤如下所示: ```python def window_attention(q, k, v): attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) attn_probs = nn.Softmax(dim=-1)(attn_scores) output = torch.matmul(attn_probs, v) return output ``` 上述代码片段展示了如何在一个特定窗口内利用 Query(Q), Key(K), Value(V)三者之间相互作用关系得到新的表示结果。 #### 4. **对比其他变体** 值得注意的是,除了常规版本外还有另一种叫做 Shifted Window MSAs(SW-MSAs)的技术方案被引入进来作为补充选项之一。它先经过一次周期位移变换(cyclic shift operation),然后再按前述方法实施相同类型的局部关注机制;最后再逆向恢复原始布局结构以便继续下一层级的操作[^3]。 综上所述,W-MSA 主要依靠限定区域内进行高效的信息交互从而达到优化性能的目的同时保留足够的空间分辨率优势。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值