Swin Transformer之Mask和相对位置编码代码详解

Swin Transformer的详细原理我已经在上一篇文章写过了,这回我来细细的写一篇它的代码原理。有朋友跟我反应Vit代码直接全贴上去光靠注释也不容易看懂,这会我用分总的方法介绍。

注:此代码支持多尺度训练。

文章仅供学习

先从最难的下手。

SW-MSA之mask

    def create_mask(self, x, H, W):
        # 第一部分:初始化
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) 
        # 第二部分:编号,分窗口
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        mask_windows = window_partition(img_mask, self.window_size)  
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        # 第三部分:生成蒙版 
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

第一部分:初始化

        因为蒙版要和attention中的矩阵相加,但attention中输入的特征图已经被padding到window_size的整数倍,所以蒙版的维度也要是window_size的整数倍。先将他初始化为(1,Hp,Wp,1)的0矩阵,之所以维度顺序如此是因为在window_partition中的维度顺序如此,我们稍后再讲。

第二部分:编号,分窗口

  

        如图,这是上一篇原理文章中shifted后再平移后对应的窗口,建议先对蒙版的原理聊熟于心,可以看我的上一篇文章。这里的“编号”就是将刚刚生成的全0的蒙版分成这样的区域。

        这里的前提条件是蒙版是window_size的整数倍。所以[0, -window_size; 0, -window_size]的所有窗口内都来自区域0。

        因为我们是移动了shift_size个单位,再补到其他地方去。因此[0, -window_size; -window_size, -shift_size]内都来自区域1。[0, -window_size; -shift_size, None]都来自区域2。以此类推。

        slice函数就是切片的作用,他的元素可以用于在张量中取值,如图

        紧接着我们就可以通过一个不断增加的变量cnt将这九个区域全部编上号

        经过window partition后的维度是[nW, Mh, Mw, 1],nW是窗口数,再将其展平,维度变为[nW, Mh*Mw]

第三部分:生成蒙版

        这里主要用了广播机制。

        我们知道,现在一个窗口的值代表对应patch的区域,一个窗口有Mh*Mw和值,每个值的维度为1.展平后就是取消最后一个维度,并且将二维窗口变为一维。如下图

        通过unsqueeze实现两个张量相减,维度如下[nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]。抛开第一个维度不看,于是我们取最后一行举例,这相当于一个Mh*Mw的行向量减去一个Mh*Mw的列向量,广播以后如下

A                                                         B  

        这么理解,A的每一行按红框的数字代表区域几进行attention,根据我们mask的原理只能和来自相同区域进行attention,来自不同的区域要被mask掉,所以减去B后,为0就代表不需要mask,不为0就代表需要mask。作者使用-100进行mask,通过softmax后就基本为0了。

        最终,每一个window的维度为[1, Mh*Mw, Mh*Mw],所以最终mask矩阵为[nW, Mh*Mw, Mh*Mw]

 相对位置编码

        # 第一部分:相对位置表
        self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  
        # 第二部分:相对位置索引
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  
        coords_flatten = torch.flatten(coords, 1)  
        relative_coords = coords_flatten.unsqueeze(2) - coords_flatten.unsqueeze(1) 
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  
        relative_coords[:, :, 0] += self.window_size[0] - 1  
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  
        self.register_buffer("relative_position_index", relative_position_index)


# forward方法中:
        ......
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
        self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  
        attn = attn + relative_position_bias.unsqueeze(0)
        ......

        相对位置编码我们没有专门做成一个模块,而是直接放在WindowAttention当中。

第一部分:相对位置表

        首先参数化一个[2*Mh-1 * 2*Mw-1, nH]的张量,nH的意思是每个head都有自己的可学习的相对位置表,至于为什么是2*Mh-1 * 2*Mw-1可以见原理篇。

第二部分:相对位置索引

        首先先构建一个绝对位置表,最后的coords如下。这里coords[0]代表行,[1]代表列。

        然后再对它进行展平

        最关键的来了,又是一句代码解决,又是用了广播机制

relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]

        维度是[2, Mh*Mw, 1] - [2, 1, Mh*Mw],再用一次广播机制。

        举个例子现在有四个位置,表格中是他们的绝对位置,现在要求他们的相对位置索引。在图A中,左右矩阵相减可以看做是各个绝对位置依次去减其他的绝对位置,最终得到相对位置索引矩阵

0,00,1
1,01,1
A    

         然后将二维索引转为一维索引,先给行列坐标都加上2M-1以消除负数,再给行标乘2M-1,再将行列相加,最终得到相对位置索引。

B
​​​​

        在整个训练过程中,窗口大小永远不变,而相对位置索引仅跟窗口有关,所以将我们得到的这个相对位置索引矩阵存在缓存之中。

 第三部分:使用它

        在forward方法中,首先把它的数据取出来。在张量a,b可以使用a[b],使用效果如下

         也就是说会根据b的内容取值,

relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
    self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)

        因此这一句我们先将relative_position_index变为一维,再根据他之内的索引在relative_position_bias_table中取值,然后变成[Mh*Mw,Mh*Mw,nH]的形状,通过permute变成[nH, Mh*Mw, Mh*Mw],最后再unsqueeze上batch的维度和attention相加,得到结果。

以上是最难理解的两部分,全部代码我会放在下一篇文章说明。

另:本篇文章参考自12.2 使用Pytorch搭建Swin-Transformer网络_哔哩哔哩_bilibili

仅供学习

<think>嗯,用户现在想了解T5模型中相对位置编码函数的定义、作用实现方式,以及在序列生成中的意义。之前已经回答过关于交叉熵损失函数位置编码的部分,可能需要回顾之前的对话确保一致性。用户提供的引用中有提到Hugging FaceGoogle Research的资源,可能需要结合这些信息。 首先,需要明确相对位置编码绝对位置编码的区别。在之前的回答里提到过T5使用相对位置偏置,添加到注意力分数中。现在需要更详细地解释这个机制的具体实现,比如偏置项是如何计算的,参数如何训练。 用户可能想知道为什么T5选择相对位置编码,它的优势在哪里。之前提到过长序列处理动态文本长度的适应性,可能需要进一步展开,比如在处理不同任务时的灵活性,或者对模型泛化能力的影响。 另外,用户提到实现方式,可能需要具体到公式代码层面的解释。例如,T5中的相对位置偏置矩阵的结构,如何与Q、K、V结合,是否在每一层都应用。同时,需要确保引用之前的回答中的相关部分,比如交叉注意力自注意力的区别。 还要注意用户提供的引用[3]中提到T5的预训练微调,可能需要联系到相对位置编码在预训练中的作用,比如在Span Corruption任务中如何处理被掩盖的文本片段的位置信息。 是否需要比较其他模型如BERT或GPT的位置编码方式,来突出T5的相对位置编码特点?不过用户的问题主要集中在T5本身,所以可能不需要过多比较,但简要提及可以增强理解。 需要确认技术细节的正确性,比如相对位置编码的具体公式,是否每个头都有独立的偏置,还是共享参数。根据T5论文,相对位置偏置是每个注意力头独立学习的,这可能值得说明。 最后,总结相对位置编码在序列生成中的意义,比如提升模型对位置关系的敏感度,增强长距离依赖的捕捉能力,从而改善生成文本的连贯性准确性。</think>### T5模型中相对位置编码的深度解析 --- #### **一、相对位置编码的定义与数学形式** T5采用的**相对位置编码(Relative Positional Encoding)**并非直接为每个位置赋予独立向量,而是通过**相对位置偏置项**动态调整注意力权重。其核心公式定义如下: 1. **偏置项生成** 给定序列中两个位置$i$$j$,其相对距离为$d = i - j$,对应的偏置项$b_d$通过可学习的参数矩阵生成: $$ b_d = \text{LearnableBias}(d) \quad \text{其中} \quad d \in [-k, k] $$ $k$为预设的最大偏移范围(T5默认$k=128$),超出范围的$d$会被截断到$[-k, k]$区间[^1][^4]。 2. **注意力分数修正** 在计算自注意力分数时,将偏置项添加到查询-键点积结果中: $$ \text{AttentionScore}_{i,j} = Q_i K_j^T + b_{i-j} $$ 其中$Q_i$$K_j$分别表示位置$i$的查询向量位置$j$的键向量[^4]。 --- #### **二、实现细节与代码映射** 1. **参数存储方式** T5将相对位置偏置存储为一个形状为$(2k+1, h)$的矩阵($h$为注意力头数),每个头独立学习不同距离的偏置值。例如,当$k=128$时,可覆盖257种相对位置关系(从$-128$到$128$)[^4]。 2. **Hugging Face实现代码片段** ```python # 伪代码示例(基于Hugging Face T5实现) class T5RelativePositionBias(nn.Module): def __init__(self, num_heads, max_distance=128): super().__init__() self.num_heads = num_heads self.max_distance = max_distance self.bias = nn.Parameter(torch.zeros(2 * max_distance + 1, num_heads)) def forward(self, seq_len): # 生成相对位置索引矩阵 context_position = torch.arange(seq_len)[:, None] memory_position = torch.arange(seq_len)[None, :] relative_position = memory_position - context_position # shape: (seq_len, seq_len) # 截断超出范围的相对位置 relative_position = torch.clamp(relative_position, -self.max_distance, self.max_distance) # 转换为参数矩阵的索引(将负数偏移映射到正数索引) relative_position_index = relative_position + self.max_distance # 提取偏置项并调整维度 bias = self.bias[relative_position_index].permute(2, 0, 1) # shape: (h, seq_len, seq_len) return bias ``` 该模块在计算注意力时被调用,结果直接叠加到注意力分数矩阵上[^2][^4]。 --- #### **三、作用与优势分析** 1. **动态位置感知能力** - **局部依赖增强**:相对位置编码更关注相邻词的位置关系(如"动词+宾语"的近距离依赖),而非绝对位置。 - **长程依赖支持**:通过截断机制,模型能有效处理长文本中跨度较大的依赖(如文档级指代消解)[^4]。 2. **技术优势对比绝对编码** | **特性** | **绝对位置编码(如BERT)** | **T5相对位置编码** | |-------------------|------------------------------------------|----------------------------------------| | **位置表示方式** | 固定位置ID嵌入叠加到输入向量 | 动态偏置项添加到注意力分数 | | **长度泛化性** | 受限于预训练最大长度(如512) | 支持任意长度输入(仅依赖偏置截断范围) | | **计算开销** | 低(静态参数) | 中等(需实时计算相对位置索引) | | **任务适配性** | 适合短文本分类 | 更适合生成任务长文本处理 | 3. **序列生成中的关键意义** - **连贯性保障**:在生成文本时,模型能更精准地捕捉当前词与历史词的位置关系(如保证代词与先行词的一致性)。 - **多任务统一**:相对位置编码不依赖任务特定的位置模式,适配翻译、摘要、问答等多种生成场景[^1][^3]。 --- #### **四、在预训练与微调中的具体应用** 1. **预训练阶段(Span Corruption)** 当输入文本被随机掩码时,相对位置编码帮助模型: - 识别被掩码片段内部的位置关系(如"Paris is the [MASK] of France"中推断[MASK]应为"capital")。 - 保持未被掩码部分的上下文连贯性[^1][^3]。 2. **微调阶段(如机器翻译)** 在英法翻译任务中,模型通过相对位置编码- 对齐源语言与目标语言的词序差异(如英语的形容词前置与法语后置)。 - 处理长句中的跨语言结构对应(如从句嵌套)[^3]。 --- ### **典型问题解析** **Q: 为什么相对位置编码能更好地处理长文本?** **A**: 绝对编码需要预先定义最大序列长度(如512),而相对编码通过偏置截断(如$k=128$)允许模型处理任意长度文本。即使序列超过训练时见过的长度,模型仍能通过局部位置关系进行推理[^4]。 **Q: 相对位置编码是否增加训练难度?** **A**: 是的,因为每个注意力头需额外学习$2k+1$个偏置参数。但实际训练中,这些参数通过大规模预训练数据可稳定收敛,最终提升模型泛化能力[^4]。 --- ### 相关问题 1. T5如何处理超出预设最大偏移范围(如$k=128$)的相对位置? 2. 相对位置编码在多头注意力机制中是否每个头独立学习偏置? 3. 如何通过可视化方法观察相对位置偏置项的学习效果?
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值