This is arguably Su Jianlin's (苏剑林) signature work. Simply put, RoPE amounts to multiplying by a complex exponential. The key fact is that the inner product of two 2D vectors equals the real part of the product of one of them, viewed as a complex number, with the conjugate of the other:

$$(x_1+y_1 i)(x_2-y_2 i)=x_1x_2+y_1y_2+(x_2y_1-x_1y_2)i$$

so $\langle(x_1,y_1),(x_2,y_2)\rangle=\mathrm{Re}[(x_1+y_1 i)(x_2-y_2 i)]$. If we multiply $q_m$ and $k_n$ by $e^{im\theta}$ and $e^{in\theta}$ respectively, they become $q_m e^{im\theta}$ and $k_n e^{in\theta}$, and we have

$$\langle q_m e^{im\theta},\,k_n e^{in\theta}\rangle=\mathrm{Re}[q_m k_n^{*}e^{i(m-n)\theta}]$$

where $\mathrm{Re}[\cdot]$ denotes the real part and $k_n^{*}$ the complex conjugate of $k_n$. The relative position $m-n$ is carried implicitly by the conjugate product on the right-hand side, while position encodings in machine learning are computed with real arithmetic, i.e. the left-hand side. So RoPE is simply $q_m$ multiplied by $e^{im\theta_i}$, where $m$ is the position index and $\theta_i$ is the rotation frequency assigned to the $i$-th pair of dimensions of the $d$-dimensional embedding.
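As a quick numerical check of this identity (a hypothetical snippet, not from the original post), plain Python complex numbers confirm that the rotated inner product depends only on $m-n$:

import cmath

# hypothetical example values for one 2D query/key pair, viewed as complex numbers
q, k = complex(0.3, 0.7), complex(-0.5, 0.2)
theta, m, n = 0.1, 9, 4

# left-hand side: <q e^{im*theta}, k e^{in*theta}> = Re[a * conj(b)]
lhs = (q * cmath.exp(1j * m * theta) * (k * cmath.exp(1j * n * theta)).conjugate()).real
# right-hand side: Re[q * conj(k) * e^{i(m-n)*theta}]
rhs = (q * k.conjugate() * cmath.exp(1j * (m - n) * theta)).real

assert abs(lhs - rhs) < 1e-12  # identical up to floating-point error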
Here is a code implementation:
import torch
from typing import Tuple

def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    # Pair up the embedding dimensions and compute the rotation frequency
    # theta_i = theta^{-2i/dim} for each pair
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Token position indices t = [0, 1, ..., seq_len-1]
    t = torch.arange(seq_len, device=freqs.device)
    # Outer product of positions and frequencies: freqs.shape = [seq_len, dim // 2]
    freqs = torch.outer(t, freqs).float()
    # torch.polar docs: https://pytorch.org/docs/stable/generated/torch.polar.html
    # torch.polar(abs, angle) returns abs * (cos(angle) + sin(angle) * i); abs and
    # angle must have the same shape, and here abs is all ones,
    # so freqs_cis[m, i] = e^{i * m * theta_i}
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis
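# Worked example with hypothetical small values (not in the original post):
# for dim=4 the per-pair frequencies are theta_i = [1.0, 0.01], so
#   precompute_freqs_cis(dim=4, seq_len=3)[2, 0] == cos(2.0) + sin(2.0) * 1j,
# i.e. the unit complex number e^{i * m * theta_i} at position m=2, pair i=0.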
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # xq.shape  = [batch_size, seq_len, dim]
    # xq_.shape = [batch_size, seq_len, dim // 2, 2]
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)
    # View as complex numbers: xq_.shape = [batch_size, seq_len, dim // 2]
    xq_ = torch.view_as_complex(xq_)
    xk_ = torch.view_as_complex(xk_)
    # Apply the rotation, then convert the result back to the real domain
    # xq_out.shape = [batch_size, seq_len, dim]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)  # flatten from dim=2 onward
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
if __name__ == '__main__':
    seq_len, dim = 3, 4
    freqs_cis = precompute_freqs_cis(dim=dim, seq_len=seq_len, theta=10000.0)
    xq = torch.rand(1, seq_len, dim)
    xk = torch.rand(1, seq_len, dim)
    res = apply_rotary_emb(xq, xk, freqs_cis)
    # each tensor in res has shape [1, seq_len, dim]
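    # Hypothetical sanity check (not in the original post): place the SAME q and k
    # at different absolute positions; the rotated inner product depends only on
    # the relative offset m - n.
    q = torch.rand(1, 1, dim)
    k = torch.rand(1, 1, dim)

    def score(m: int, n: int) -> torch.Tensor:
        qm, _ = apply_rotary_emb(q, q, freqs_cis[m:m + 1])
        kn, _ = apply_rotary_emb(k, k, freqs_cis[n:n + 1])
        return (qm * kn).sum()

    assert torch.allclose(score(0, 1), score(1, 2), atol=1e-5)  # both have m - n = -1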
'''
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.wq = Linear(...)
        self.wk = Linear(...)
        self.wv = Linear(...)
        self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)

    def forward(self, x: torch.Tensor):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, dim)
        xk = xk.view(bsz, seqlen, dim)
        xv = xv.view(bsz, seqlen, dim)
        # apply the rotary position embedding before the attention operation
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[:seqlen])
        # scores.shape = (bsz, seqlen, seqlen)
        scores = torch.matmul(xq, xk.transpose(1, 2)) / math.sqrt(dim)
        scores = F.softmax(scores.float(), dim=-1)
        output = torch.matmul(scores, xv)  # (bsz, seqlen, dim)
        # ......
'''
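The code above takes the complex-number route of the Llama-style reference implementation. Equivalently, the rotation can be written in purely real arithmetic (the left-hand side of the identity at the top). Below is a minimal sketch of my own, assuming the same interleaved pairing of dimensions as apply_rotary_emb:

def apply_rotary_emb_real(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: [batch_size, seq_len, dim]; freqs_cis: [seq_len, dim // 2], complex
    cos, sin = freqs_cis.real, freqs_cis.imag  # cos(m*theta_i), sin(m*theta_i)
    x1, x2 = x[..., 0::2], x[..., 1::2]        # the interleaved pairs (x1, x2)
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin       # Re[(x1 + x2*i) * e^{i*m*theta_i}]
    out[..., 1::2] = x1 * sin + x2 * cos       # Im[(x1 + x2*i) * e^{i*m*theta_i}]
    return out

# This should match the complex version, e.g.:
#   xq_out, _ = apply_rotary_emb(xq, xq, freqs_cis)
#   assert torch.allclose(apply_rotary_emb_real(xq, freqs_cis), xq_out, atol=1e-6)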
What follows is reprinted from https://kexue.fm/archives/8265: