Llama改进之——RoPE旋转位置编码

引言

旋转位置编码(Rotary Position Embedding, RoPE)将绝对相对位置依赖纳入自注意力机制中,以增强Transformer架构的性能。目前很火的大模型LLaMA、QWen等都应用了旋转位置编码。

之前在[论文笔记]ROFORMER中对旋转位置编码的原始论文进行了解析,重点推导了旋转位置编码的公式,本文侧重实现,同时尽量简化数学上的推理,详细推理可见最后的参考文章。

复数与极坐标

复数由两个部分组成:实部(real part)和虚部(imaginary part)。实部就是一个普通的数字,可以是零、正数或负数。虚部是另一个实数与 i i i

### LLaMA RoPE 实现 在Transformer架构中,RoPE(Rotary Position Embedding)是一种有效的位置编码方法。对于LLaMA模型而言,RoPE被用来赋予模型理解序列中元素相对位置的能力。 下面是基于PyTorch的一个简化版的RoPE实现方式: ```python import torch import math def apply_rotary_pos_emb(q, k, sin, cos): q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def rotate_half(x): x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) class RotaryEmbedding(torch.nn.Module): def __init__(self, d_model, max_position_embeddings=2048, base=10000, device=None): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float().to(device=device) / d_model)) self.register_buffer('inv_freq', inv_freq) # Build sinusoidal/base embeddings. t = torch.arange(max_position_embeddings, device=device).type_as(self.inv_freq) freqs = torch.einsum('i , j -> i j', t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) self.cos_cached = emb.cos()[None, None, :, :] self.sin_cached = emb.sin()[None, None, :, :] def forward(self, query, key): seq_len = query.shape[-2] return ( self.cos_cached[:, :, :seq_len, ...], self.sin_cached[:, :, :seq_len, ...] ) ``` 上述代码定义了一个`RotaryEmbedding`类来生成旋转位置编码,并提供了一个辅助函数`apply_rotary_pos_emb()`用于将这些编码应用于查询(query)和键(key)[^1]。 为了使该模块能够正常工作,在调用时需传入相应的query与key张量作为参数。此过程通常发生在自注意力机制之前,确保每个token都能获得其对应的位置信息增强表示形式[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

愤怒的可乐

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值