旋转式位置编码Rotary Position Embedding(RoPE)

这篇堪称苏剑林老师的代表作了,简单来说RoPE就是乘上复数形式即可。也就是说两个二维向量的内积,等于把它们当复数看时,一个复数与另一个复数的共轭的乘积实部。也就是说 ( x 1 + y 1 i ) ( x 2 − y 2 i ) = x 1 x 2 + y 1 y 2 + ( x 2 y 1 − x 1 y 2 ) i (x_1+y_1i)(x_2-y_2i)=x_1x_2+y_1y_2+(x_2y_1-x_1y_2)i (x1+y1i)(x2y2i)=x1x2+y1y2+(x2y1x1y2)i,如果我们把 q m q_m qm k n k_n kn分别乘以 e i m θ e^{im\theta} eimθ e i n θ e^{in\theta} einθ,就会变成 q m e i m θ q_me^{im\theta} qmeimθ k n e i n θ k_ne^{in\theta} kneinθ,那么存在
< q m e i m θ , k n e i n θ > = R e [ q m k n ∗ e i ( m − n ) θ ) ] <q_me^{im\theta},k_ne^{in\theta}>=Re[q_mk_n^*e^{i(m-n)\theta)}] <qmeimθ,kneinθ>=Re[qmknei(mn)θ)]
其中Re[]表示实数部分, k n ∗ k_n^* kn表示共轭部分,相对位置m-n隐含在复数的共轭里,也就是上述表达式的右边;机器学习中的位置编码都是实数运算,也就是上述表达式的左边,所以RoPE实际上就是 q m q_m qm乘以 e i m θ e^{im\theta} eimθ,其中m表示第m个位置,i表示d维embedding中第i维度
在这里插入图片描述

插入一个代码实现:

import torch
from typing import Tuple

def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    # 计算词向量元素两两分组之后,每组元素对应的旋转角度
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

    # 生成 token 序列索引 t = [0, 1,..., seq_len-1]
    t = torch.arange(seq_len, device=freqs.device)
    # freqs.shape = [seq_len, dim // 2] 
    freqs = torch.outer(t, freqs).float()
    # torch.polar的文档, https://pytorch.org/docs/stable/generated/torch.polar.html
    # torch.polar输入参数是abs和angle,abs所有值都一样,abs和angle的shape都一样
    # torch.polar输入参数是abs和angle,则freqs_cis = abs*(cos(angle) + sin(angle)i)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # xq.shape = [batch_size, seq_len, dim]
    # xq_.shape = [batch_size, seq_len, dim // 2, 2]
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)
    
    # 转为复数域,  xq_.shape = [batch_size, seq_len, dim // 2]
    xq_ = torch.view_as_complex(xq_)
    xk_ = torch.view_as_complex(xk_)
    # 应用旋转操作,然后将结果转回实数域
    # xq_out.shape = [batch_size, seq_len, dim]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2) #从dim=2维度开始拍平
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)

    return xq_out.type_as(xq), xk_out.type_as(xk)

if __name__ == '__main__':
    seq_len,dim=3,4
    freqs_cis = precompute_freqs_cis(dim=dim, seq_len=seq_len, theta=10000.0)
    xq = torch.rand(1, seq_len, dim)
    xk = torch.rand(1, seq_len, dim)
    res = apply_rotary_emb(xq, xk, freqs_cis)
    # res的shape是1, seq_len, dim
    
'''
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.wq = Linear(...)
        self.wk = Linear(...)
        self.wv = Linear(...)
        
        self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)

    def forward(self, x: torch.Tensor):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(batch_size, seq_len, dim)
        xk = xk.view(batch_size, seq_len, dim)
        xv = xv.view(batch_size, seq_len, dim)

        # attention 操作之前,应用旋转位置编码
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        
        # scores.shape = (bs, seqlen, seqlen)
        scores = torch.matmul(xq, xk.transpose(1, 2)) / math.sqrt(dim)
        scores = F.softmax(scores.float(), dim=-1)
        output = torch.matmul(scores, xv)  # (batch_size, seq_len, dim)
  # ......
'''

以下转载自https://kexue.fm/archives/8265
在这里插入图片描述
后面转载自https://kexue.fm/archives/8265:
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值