【HuggingFace Transformers】LlamaRotaryEmbedding源码解析

1. LlamaRotaryEmbedding类 介绍

LLaMa模型中,LlamaRotaryEmbedding类实现了Rotary Position Embedding(RoPE)的方法。RoPE的核心思想是对位置编码应用旋转变换,使得不同位置之间的相对位置关系在编码过程中得到保留。这种旋转变换不仅能捕捉到序列的绝对位置,还能捕捉到相对位置,从而更好地处理长距离依赖。以下是旋转变换的过程:
在这里插入图片描述
图片来源:RoFormer: Enhanced Transformer with Rotary Position Embedding

2. 逆频率向量

频率向量在旋转位置编码 (Rotary Position Embedding, RoPE) 中是用于表示不同位置的频率信息的向量。这些向量帮助模型在处理序列数据时能够区分不同位置的相对关系。频率向量的计算和应用可以增强模型对位置的感知,从而改进模型的性能。频率向量通常由基数、向量的维度和位置索引生成。例如,频率向量可以表示为:
f r e q u e n c y = b a s e 2 i d frequency=base^{\frac{2i}{d}} frequency=based2i
则逆频率向量表示为:
i n v _ f r e q = 1 b a s e 2 i d inv\_freq=\frac{1}{base^{\frac{2i}{d}}} inv_freq=based2i1
其中:

  • i 是频率向量中的索引(例如第 i 个频率分量)。
  • d 是嵌入维度的大小。
  • base 是一个预定义的基数,通常为 10000

下面举个例子,说明一下逆频率向量的构建过程:

假设维度大小 dim = 8,基数 base = 10000
步骤如下:

  • 1.生成频率索引: 对于维度为 8,步长为 2,生成索引 [0, 2, 4, 6]。
  • 2.计算比例: 将索引除以维度大小得到比例 [0/8, 2/8, 4/8, 6/8] = [0, 0.25, 0.5, 0.75]。
  • 3.计算频率: 使用基数 10000,对比例应用幂运算,生成频率向量:
    f r e q u e n c y = [ 1000 0 0 , 1000 0 0.25 , 1000 0 0.5 , 1000 0 0.75 ] frequency=[10000^0,10000^{0.25} ,10000^{0.5} ,10000 ^{0.75} ] frequency=[100000,100000.25,100000.5,100000.75]
    计算结果约为:
    f r e q u e n c y ≈ [ 1 , 17.78 , 316.23 , 5623.41 ] frequency≈[1,17.78,316.23,5623.41] frequency[1,17.78,316.23,5623.41]
  • 4.取倒数生成逆频率向量。最后,对这些频率值取倒数,生成逆频率向量:
    i n v _ f r e q ≈ [ 1.0 , 0.056 , 0.00316 , 0.000178 ] {inv\_freq}≈[1.0,0.056,0.00316,0.000178] inv_freq[1.0,0.056,0.00316,0.000178]

参考代码:_compute_default_rope_parameters

def _compute_default_rope_parameters(
    config: Optional[PretrainedConfig] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple["torch.Tensor", float]:
    """
    计算原始 RoPE 实现中的逆频率参数

    参数:
        config ([`~transformers.PretrainedConfig`]):
            模型配置,用于从中获取 RoPE 参数(如 base 和 dim)。
        device (`torch.device`):
            初始化逆频率参数时使用的设备(如 GPU 或 CPU)。
        seq_len (`int`, *optional*):
            当前序列长度。对于此类型的 RoPE 实现,该参数未被使用。
        rope_kwargs (`Dict`, *optional*):
            向后兼容以前的 RoPE 类实例化方式,该参数将在 v4.45 中移除。

    返回:
        包含 (`torch.Tensor`, `float`) 元组,其中包括 RoPE 嵌入的逆频率和应用于计算出的 cos/sin 的后处理缩放因子
        (此类型的 RoPE 中未使用)。
    """
    # 如果 config 参数不为 None 且 rope_kwargs 参数有值,抛出异常,二者是互斥的
    if config is not None and len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
        )

    # 使用传入的 rope_kwargs 来初始化 base 和 dim 参数
    if len(rope_kwargs) > 0:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
    # 使用 config 配置中的参数来初始化 base 和 dim 参数
    elif config is not None:
        base = config.rope_theta
        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        dim = int(head_dim * partial_rotary_factor)

    # RoPE 后处理的缩放因子,默认为 1.0(未在此类型 RoPE 中使用)
    attention_factor = 1.0  # Unused in this type of RoPE

    # 计算逆频率参数
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
    
    # 返回计算出的逆频率和缩放因子
    return inv_freq, attention_factor

3. LlamaRotaryEmbedding类 源码解析

算法核心为:

  • 计算inv_freq
  • 扩展inv_freq和position_ids
  • inv_freq_expanded与position_ids_expanded相乘并转置得到freqs
  • 拼接freqs
  • 计算cos值 和 sin值

3.1 transformers v4.44.2版

源码地址:transformers/src/transformers/models/llama/modeling_llama.py

# -*- coding: utf-8 -*-
# @time: 2024/8/28 14:52
# @transformers.version: v4.44.2
import torch

from typing import Optional
from torch import nn
from transformers import LlamaConfig
from transformers.utils import logging
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

logger = logging.get_logger(__name__)


class LlamaRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim=None,  # 嵌入维度
        max_position_embeddings=2048,  # 最大位置嵌入数
        base=10000,  # 基数,用于计算逆频率
        device=None,
        scaling_factor=1.0,  # 缩放因子
        rope_type="default",  # RoPE 类型
        config: Optional[LlamaConfig] = None,  # 可选的 Llama 配置
    ):
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        # 移除下面的 if,此代码仅用于向后兼容(BC)
        self.rope_kwargs = {}  # 初始化存储 RoPE 参数的字典
        if config is None:
            logger.warning_once(
                "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.45"
            )
            # 如果没有传入配置对象,使用传入的参数进行初始化
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings  # 初始化缓存的最大序列长度
            self.original_max_seq_len = max_position_embeddings  # 保存原始最大序列长度
        else:
            # BC: "rope_type" was originally "type"
            # 如果传入了配置对象,向后兼容: "rope_type" 原来被称为 "type"
            if config.rope_scaling is not None:  # 如果配置中有 rope_scaling 参数
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))  # 获取 RoPE 类型
            else:
                self.rope_type = "default"  # 如果没有指定,使用默认类型
            self.max_seq_len_cached = config.max_position_embeddings  # 从配置中读取最大位置嵌入数
            self.original_max_seq_len = config.max_position_embeddings  # 保存原始最大序列长度

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]  # 根据 RoPE 类型选择初始化函数

        # 使用初始化函数计算逆频率 (inv_freq) 和注意力缩放 (attention_scaling)
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)

        self.register_buffer("inv_freq", inv_freq, persistent=False)  # 注册逆频率缓冲区(不会被持久化)
        self.original_inv_freq = self.inv_freq  # 保存原始逆频率

    def _dynamic_frequency_update(self, position_ids, device):
        """
        对于动态 RoPE 层,应在以下情况下重新计算 `inv_freq`:
        1 - 超过缓存的序列长度(允许缩放)
        2 - 当前序列长度在原始尺度内(避免对小序列失去精度)
        """
        seq_len = torch.max(position_ids) + 1  # 计算当前序列长度
        if seq_len > self.max_seq_len_cached:  # 序列长度增长时更新逆频率
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # 重新注册逆频率缓冲区  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len  # 更新缓存的最大序列长度

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # 如果序列长度变小,重置逆频率
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)  # 恢复原始逆频率
            self.max_seq_len_cached = self.original_max_seq_len  # 恢复缓存的最大序列长度

    @torch.no_grad()  # 禁用梯度计算以提高性能
    def forward(self, x, position_ids):
        """
        :param x: [bs, num_attention_heads, seq_len, head_size]
        :param position_ids: [bs, seq_len]
        """
        # 如果是动态 RoPE 类型,更新逆频率
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # 核心 RoPE 计算块
        # inv_freq: [dim/2] -> inv_freq_expanded: [batch_size, dim/2, 1]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)  # 扩展逆频率
        # position_ids: [batch_size, seq_len] -> position_ids_expanded: [batch_size, 1, seq_len]
        position_ids_expanded = position_ids[:, None, :].float()  # 扩展位置 ID

        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"

        with torch.autocast(device_type=device_type, enabled=False):  # 关闭自动混合精度
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)  # 计算频率。这里的 @ 符号是矩阵乘法的符号,表示对两个张量进行矩阵乘法运算。
            # freqs: [batch_size, seq_len, dim/2]
            emb = torch.cat((freqs, freqs), dim=-1)  # 拼接频率
            # emb: [batch_size, seq_len, dim]
            cos = emb.cos()  # 计算 cos 嵌入
            sin = emb.sin()  # 计算 sin 嵌入

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        # 高级 RoPE 类型应用后处理缩放因子,等价于缩放注意力
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        # 返回 cos 和 sin 嵌入
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

3.2 transformers v4.41.1版

如果v4.44.2看着有些复杂,可以参考v4.41.1

# -*- coding: utf-8 -*-
# @time: 2024/8/28 14:52
# @transformers.version: v4.41.1
import torch

from torch import nn


class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor  # 缩放因子
        self.dim = dim  # 嵌入维度
        self.max_position_embeddings = max_position_embeddings  # 最大位置嵌入
        self.base = base  # 基数
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) # 计算逆频率向量
        self.register_buffer("inv_freq", inv_freq, persistent=False)  # 将逆频率向量注册到缓存
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings  # 最大序列长度缓存

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # -----------------------核心 RoPE 计算块----------------------- #
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  • 25
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CS_木成河

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值