一口气看完TransformerXL详细解释带代码一:相对位置编码

目录

一、相对位置编码概述

1. 为什么需要相对位置编码?​​

2. 核心思想​​

3. 公式​

4. Transformer-XL 的改进​​

 二、代码实现:

1、shift_right函数

​1. 功能说明​​

2. 代码逐行解析​​

​​(1) 补零列​​

(2) 重塑并截断​​

 3.完整代码

2.RelativeMultiHeadAttention​​ 类

​1. 核心功能​​

2. 关键组件解析​​

​​(1) 初始化参数​​

​​(2) 相对位置编码参数​​

3. 相对位置编码的原理​​

​​(1) 相对位置的作用​​

​​(2) 实现方式​​

​​(3) 数学形式​​

 4.完整代码

5.getsorce()函数,计算q和k相对位置编码得分

​1. 函数功能​​

2. 关键步骤解析​​

​​(1) 获取相对位置嵌入和偏置​​

​​(2) 查询位置偏置​​

​​(3) 计算内容相关分数 (ac)​​

(4) 计算相对位置相关分数 (b 和 d)​​

​(5) 合并相对位置分数 (bd)​​

​​(6) 最终分数​​

 3.完整代码

 三、完整代码带测试


这里的代码来自:Transformer XL,这里对代码做了详细解释

一、相对位置编码概述

相对位置编码(Relative Position Encoding)是Transformer模型中用于捕捉序列元素间相对位置关系的一种方法,与绝对位置编码(如BERT中的固定位置嵌入)不同,它直接建模元素之间的相对距离,从而更灵活地处理长序列或不同长度的序列。以下是详细解释:

1. 为什么需要相对位置编码?​

  • ​绝对位置编码的局限​​:传统Transformer使用正弦/余弦函数或可学习的绝对位置嵌入,为每个位置分配固定编码。但这种方式无法直接建模元素间的相对关系(例如距离为2的两个词的关系)。
  • ​相对位置的优势​​:自然语言中,词语的语义常依赖相对位置(如"相邻"或"距离为3"),相对位置编码能更直接地捕捉这种关系。

 

2. 核心思想​

相对位置编码通过以下方式改进注意力机制:

  • ​键-值分离​​:在计算注意力权重时,不仅考虑内容的匹配(Query和Key的点积),还显式加入相对位置的影响。
  • ​距离敏感​​:相对位置的权重随距离增大而衰减,符合语言中局部依赖更强的特性。

 

3. 公式

在论文《Self-Attention with Relative Position Representations》中,作者提出:

  • ​相对位置嵌入矩阵​​:定义一个可学习的矩阵 P∈R(2k+1)×d,其中 k 是最大相对距离(如窗口大小),d 是维度。
  • ​修改注意力计算​​:
    • ​绝对注意力​​:
    • ​加入相对位置​​:
    • 其中 Pi−j​ 是从 P 中查到的相对位置嵌入(若 ∣i−j∣>k,则截断)。

 

4. Transformer-XL 的改进​

在《Transformer-XL》中,作者进一步优化:

  • ​相对位置嵌入与内容解耦​​:将位置信息分离,避免与内容混淆:
    • u,v 是可学习参数,分别捕捉内容无关的位置偏置。
方法公式
绝对位置编码
Shaw相对位置
Transformer-XL

 二、代码实现:

1、shift_right函数

​1. 功能说明​

  • ​输入​​:任意维度的张量 x,但要求至少是 2D 的(例如形状 [B, C, H, W],常见于图像数据)。
  • ​输出​​:将 x 的​​第 1 维(dim=1)的数据向右平移一位​​,最左侧补零,最右侧的数据被丢弃。
  • ​效果举例​​:
    • 若输入是 [[1, 2, 3], [4, 5, 6]](形状 [2, 3]),输出为 [[0, 1, 2], [0, 4, 5]]

 

2. 代码逐行解析​

​(1) 补零列​
zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
x_padded = torch.cat([x, zero_pad], dim=1)
  • zero_pad​:创建一个全零张量,形状与 x 的第 0 维和其他维度相同,但第 1 维是 1(即一列零)。
    • 例如,若 x 形状为 [B, C, H, W],则 zero_pad 形状为 [B, 1, H, W]
  • torch.cat​:将 zero_pad 拼接到 x 的​​右侧​​(沿 dim=1),此时 x_padded 形状变为 [B, C+1, H, W]

 

(2) 重塑并截断​
x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
x = x_padded[:-1].view_as(x)
  • view 操作​​:将 x_padded 的形状从 [B, C+1, ...] 重塑为 [C+1, B, ...](交换第 0 维和第 1 维)。
    • 目的是通过后续切片 [:-1] 直接丢弃最后一个元素(即原张量的最右侧数据)。
  • view_as​:恢复原始形状,完成右移。

 

假设输入 x 为 2D 张量:

x = torch.tensor([[1, 2, 3], 
                  [4, 5, 6]])  # shape [2, 3]

运行后输出:

[[0, 1, 2],
 [0, 4, 5]]

 3.完整代码

# 数据右移函数
def shift_right(x:torch.Tensor):  # x.shape=(seq_len,batch_size,heads,d_k)
    # 创建一个第一维(从0开始)和x不一样的全为0的tensor
    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])  # shape(seq_len,1,batch_size,heads,d_k)
    # inspect(zero_pad)
    # 将x和零矩阵按第一维结合
    x_padded = torch.cat([x, zero_pad], dim=1)  #shape(seq_len,batch_size+1,heads,d_k)
    # inspect(x_padded)
    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])  # shape(batch_size+1,seq_len,heads,d_k)
    # inspect(x_padded)
    x = x_padded[:-1].view_as(x)
    return x

2.RelativeMultiHeadAttention​ 类

 定义了一个 ​RelativeMultiHeadAttention​ 类,继承自标准的 MultiHeadAttention,并引入了​​相对位置编码​​机制(Relative Positional Encoding)

​1. 核心功能​

  • ​目标​​:在多头注意力机制中,通过​​相对位置编码​​(而非绝对位置)增强模型对序列中元素相对位置的感知能力。
  • ​适用场景​​:机器翻译、语音识别等需要建模序列中元素相对位置的任务。

 

2. 关键组件解析​

​(1) 初始化参数​
super().__init__(heads, d_model, dropout_prob, bias=False)
  • 继承父类 MultiHeadAttention,但​​禁用线性变换的偏置​​(bias=False),因为相对位置编码会显式引入偏置(key_pos_bias 和 query_pos_bias)。
​(2) 相对位置编码参数​
  • self.P = 2 ​**​ 12
    定义最大相对位置距离(4096),即允许模型处理的最远相对位置偏移量。

  • self.key_pos_embeddings

    • 形状:[2P, heads, d_k]d_k = d_model // heads)。
    • 作用:为每个头(heads)和每个可能的相对位置(从 -P 到 P)学习独立的嵌入向量。
    • 例如,若 P=2,则编码范围为 [-2, -1, 0, 1, 2]
  • self.key_pos_bias

    • 形状:[2P, heads]
    • 作用:为每个头和相对位置学习偏置项,增强位置感知。
  • self.query_pos_bias

    • 形状:[heads, d_k]
    • 作用:与查询位置无关的全局偏置,用于调整查询向量的表示。

 

3. 相对位置编码的原理​

​(1) 相对位置的作用​
  • 在计算注意力分数时,不仅考虑键(Key)和查询(Query)的内容相似性,还考虑它们的​​相对位置距离​​。
  • 例如,在句子中,距离较近的词对通常比距离远的词对更相关。
​(2) 实现方式​
  • ​键的相对位置嵌入​​:
    对每个键位置 j 和查询位置 i,使用 key_pos_embeddings[j - i + P] 获取嵌入(+P 将负偏移映射到正索引)。
  • ​偏置项​​:
    key_pos_bias 和 query_pos_bias 直接加到注意力分数上,无需与内容交互。
​(3) 数学形式​

注意力分数计算扩展为:

其中:

  • u:query_pos_bias(与位置无关)。
  • vj−i​:key_pos_embeddings(相对位置嵌入)。
  • bj−i​:key_pos_bias(相对位置偏置)。

 4.完整代码

# 重写多头注意力模块
class RelativeMultiHeadAttention(MultiHeadAttention):
    def __init__(self,heads,d_model,dropout_prob):
        super().__init__(heads,d_model,dropout_prob,bias=False)
        # 定义相对位置范围:4090
        self.P=2**12
        # 为每个头和每个可能的相对位置学习独立的嵌入向量(从-P到P),形状(2*P,heads,d_k)
        self.key_pos_embeddings=nn.Parameter(torch.zeros((self.P*2,heads,self.d_k)),requires_grad=True)
        # 为每个头和相对位置学习偏置项,增强位置感知(2*P,heads)
        self.key_pos_bias=nn.Parameter(torch.zeros((self.P*2,heads)),requires_grad=True)
        # 与查询位置无关的全局偏置,用于调整查询向量的表示(heads,d_k)
        self.query_pos_bias=nn.Parameter(torch.zeros((heads,self.d_k)),requires_grad=True)

   

5.getsorce()函数,计算q和k相对位置编码得分

这段代码实现了 ​RelativeMultiHeadAttention​ 中的注意力分数计算部分,结合了内容相关性和相对位置信息。

​1. 函数功能​
  • ​输入​​:
    • query: 形状 [seq_len_q, batch, heads, d_k]
    • key: 形状 [seq_len_k, batch, heads, d_k]
  • ​输出​​:注意力分数矩阵,形状 [seq_len_q, seq_len_k, batch, heads],包含内容与相对位置的综合得分。

 

2. 关键步骤解析​
​(1) 获取相对位置嵌入和偏置​
key_pos_emb = self.key_pos_embeddings[self.P - key.shape[0]:self.P + query.shape[0]]
key_pos_bias = self.key_pos_bias[self.P - key.shape[0]:self.P + query.shape[0]]
  • key_pos_embeddings​:从预定义的相对位置嵌入矩阵中切片,选取与当前序列长度相关的部分。
    • 切片范围 [P - seq_len_k : P + seq_len_q],覆盖所有可能的相对位置偏移(从 -seq_len_k 到 seq_len_q)。
    • 形状:[seq_len_q + seq_len_k, heads, d_k]
  • key_pos_bias​:同理切片相对位置偏置,形状 [seq_len_q + seq_len_k, heads]
​(2) 查询位置偏置​
query_pos_bias = self.query_pos_bias[None, None, :, :]
  • 扩展全局查询偏置 query_pos_bias 的维度,形状从 [heads, d_k] 变为 [1, 1, heads, d_k],便于广播。
​(3) 计算内容相关分数 (ac)​
ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
  • ​公式​​:
    • query + query_pos_bias:将查询向量与全局偏置结合(v 对应 query_pos_bias)。
    • einsum 计算查询和键的点积,生成基础注意力分数矩阵,形状 [seq_len_q, seq_len_k, batch, heads]
(4) 计算相对位置相关分数 (b 和 d)​
b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
d = key_pos_bias[None, :, None, :]
  • b(公式 $ B'_{i,k} = Q_i^\top R_k \)​​:
    • 计算查询与相对位置嵌入的点积,形状 [seq_len_q, seq_len_q + seq_len_k, batch, heads]
  • d(公式 ( D'_{i,k} = S_k \)​​:
    • 直接扩展相对位置偏置的维度,形状 [1, seq_len_q + seq_len_k, 1, heads]
​(5) 合并相对位置分数 (bd)​
bd = shift_right(b + d)
bd = bd[:, -key.shape[0]:]
  • shift_right​:将 b + d 的矩阵右移一位(参考前文的 shift_right 函数),使得第 k 列对应相对偏移 i-j = k
  • ​切片​​:保留最后 seq_len_k 列,确保输出形状与 ac 一致 [seq_len_q, seq_len_k, batch, heads]
​(6) 最终分数​
return ac + bd
  • ​公式​​:( \text{Score}{i,j} = A{i,j} + B_{i,j} + C_{i,j} + D_{i,j} $
    • ac:内容相关性 + 全局查询偏置。
    • bd:相对位置相关性 + 相对位置偏置。

 3.完整代码

 # 计算相对位置注意力得分矩阵
    def get_scores(self,query,key):
        # 从预定义的相对位置矩阵中切片,获取与当前位置相关的部分,[P - seq_len_k : P + seq_len_q]
        key_pos_emb=self.key_pos_embeddings[self.P-key.shape[0]:self.P+query[0]]
        # 从预定义的相对位置偏置中获取与当前位置相关的部分,[P-seq_len_k, P+seq_len_q]
        key_pos_bias=self.key_pos_bias[self.P-key.shape[0]:self.P+query.shape[0]]
        # 扩展全局偏置的维度,从(heads,d_k)->(1,1,heads,d_k)
        query_pos_bias=self.query_pos_bias[None,None,:,:]
        # 计算内容相关分数,使用查询向量加上全局偏置与键向量点积,生成基础注意力分数矩阵
        ac=torch.einsum('ibhd,jbhd->ijbh',query+query_pos_bias,key)
        # 计算查询与相对位置的点积
        b=torch.einsum('ibhd,jhd->ijbh',query,key_pos_emb)
        # 扩展相对位置维度,(1,seq_len_q+seq_len_k,1,heads)
        d=key_pos_bias[None,:,None,:]
        # 将b和d相加得到相对位置分数,传入shift_right中使数据向右移一位,获得每个词的在规定范围(P)中的前后词的相对位置分数
        bd=shift_right(b+d)
        # 将seq_len_q切去,保留seq_len_k部分,保证与ac一致。(seq_len_q,seq_len_k,heads,d_k)
        bd=bd[:,-key.shape[0]:]
        # 将内容相关分数与位置相关分数相加得到相对位置注意力
        return ac+bd

 三、完整代码带测试

# 导入相关包
import torch
import torch.nn as nn
# inspect:打印数据的一个函数
from labml.logger import  inspect
# 导入多头注意力模块
from labml_nn.transformers import MultiHeadAttention

# 数据右移函数
def shift_right(x:torch.Tensor):  # x.shape=(seq_len,batch_size,heads,d_k)
    # 创建一个第一维(从0开始)和x不一样的全为0的tensor
    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])  # shape(seq_len,1,batch_size,heads,d_k)
    # inspect(zero_pad)
    # 将x和零矩阵按第一维结合
    x_padded = torch.cat([x, zero_pad], dim=1)  #shape(seq_len,batch_size+1,heads,d_k)
    # inspect(x_padded)
    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])  # shape(batch_size+1,seq_len,heads,d_k)
    # inspect(x_padded)
    x = x_padded[:-1].view_as(x)
    return x

# 重写多头注意力模块
class RelativeMultiHeadAttention(MultiHeadAttention):
    def __init__(self,heads,d_model,dropout_prob):
        super().__init__(heads,d_model,dropout_prob,bias=False)
        # 定义相对位置范围:4090
        self.P=2**12
        # 为每个头和每个可能的相对位置学习独立的嵌入向量(从-P到P),形状(2*P,heads,d_k)
        self.key_pos_embeddings=nn.Parameter(torch.zeros((self.P*2,heads,self.d_k)),requires_grad=True)
        # 为每个头和相对位置学习偏置项,增强位置感知(2*P,heads)
        self.key_pos_bias=nn.Parameter(torch.zeros((self.P*2,heads)),requires_grad=True)
        # 与查询位置无关的全局偏置,用于调整查询向量的表示(heads,d_k)
        self.query_pos_bias=nn.Parameter(torch.zeros((heads,self.d_k)),requires_grad=True)

    # 计算相对位置注意力得分矩阵
    def get_scores(self,query,key):
        # 从预定义的相对位置矩阵中切片,获取与当前位置相关的部分,[P - seq_len_k : P + seq_len_q]
        key_pos_emb=self.key_pos_embeddings[self.P-key.shape[0]:self.P+query[0]]
        # 从预定义的相对位置偏置中获取与当前位置相关的部分,[P-seq_len_k, P+seq_len_q]
        key_pos_bias=self.key_pos_bias[self.P-key.shape[0]:self.P+query.shape[0]]
        # 扩展全局偏置的维度,从(heads,d_k)->(1,1,heads,d_k)
        query_pos_bias=self.query_pos_bias[None,None,:,:]
        # 计算内容相关分数,使用查询向量加上全局偏置与键向量点积,生成基础注意力分数矩阵
        ac=torch.einsum('ibhd,jbhd->ijbh',query+query_pos_bias,key)
        # 计算查询与相对位置的点积
        b=torch.einsum('ibhd,jhd->ijbh',query,key_pos_emb)
        # 扩展相对位置维度,(1,seq_len_q+seq_len_k,1,heads)
        d=key_pos_bias[None,:,None,:]
        # 将b和d相加得到相对位置分数,传入shift_right中使数据向右移一位,获得每个词的在规定范围(P)中的前后词的相对位置分数
        bd=shift_right(b+d)
        # 将seq_len_q切去,保留seq_len_k部分,保证与ac一致。(seq_len_q,seq_len_k,heads,d_k)
        bd=bd[:,-key.shape[0]:]
        # 将内容相关分数与位置相关分数相加得到相对位置注意力
        """
        总结:
        1.通过公式,我们要计算先准备好三个参数:一个相对位置学习嵌入向量,一个相对位置学习篇偏置,一个全局偏置
        2.获取每个batch_size中的相对位置、偏置、全局偏置,
        3.通过公式:先将q与全局偏置相加后与k点积,得到内容相关分数
        4.计算位置注意力分数,
        5.将位置注意力分数和内容相关分数相加得到相对位置注意力分数
        """
        return ac+bd



def _test_shift_right():
   x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
   inspect(x)
   inspect(shift_right(x))

   x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
   inspect(x[:, :, 0, 0])
   inspect(shift_right(x)[:, :, 0, 0])

   x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
   inspect(x[:, :, 0, 0])
   inspect(shift_right(x)[:, :, 0, 0])

if __name__ == '__main__':

    _test_shift_right()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值