【HuggingFace Transformers】LlamaAttention源码解析

1. LlamaAttention 介绍

LlamaAttentionLLaMA 模型中负责实现自注意力机制的核心组件,其使用了多头自注意力(Multi-Head Self-Attention)机制,允许模型在不同的子空间中并行计算注意力,从而提高了对信息的表达能力。

1.1 多头注意力机制

多头注意力是一种在现代神经网络中广泛使用的机制,特别是在Transformer架构中。其结构如下:
在这里插入图片描述
图片参考来源:Attention Is All You Need

1.2 注意力的计算过程

  1. 计算每个头的Q、K、V:
    Q i = X W i Q , K i = X W i K , V i = X W i V Q_i=XW_i^{Q}, K_i=XW_i^{K}, V_i=XW_i^{V} Qi=XWiQ,Ki=XWiK,Vi=XWiV
  2. 计算每个头的注意力得分:
    s c o r e s i = Q i K i T d k scores_i=\frac{Q_iK_i^{T}}{\sqrt{d_k} } scoresi=dk QiKiT
    使用掩码mask(可选):
    s c o r e s i = s c o r e s i + m a s k scores_i =scores_i +mask scoresi=scoresi+mask
  3. 计算每个头的注意力权重并softmax归一化:
    a t t e n t i o n _ w e i g h t s i = s o f t m a x ( s c o r e s i ) attention\_weights_i=softmax(scores_i) attention_weightsi=softmax(scoresi)
  4. 计算每个头的加权和:
    a t t n _ o u t p u t i = a t t e n t i o n _ w e i g h t s i V i attn\_output_i=attention\_weights_iV_i attn_outputi=attention_weightsiVi
  5. 拼接所有注意力头并进行线性变换:
    a t t n _ o u t p u t = c o n c a t ( a t t n _ o u t p u t 1 , a t t n _ o u t p u t 2 , . . . , a t t n _ o u t p u t h ) attn\_output=concat(attn\_output_1, attn\_output_2,...,attn\_output_h) attn_output=concat(attn_output1,attn_output2,...,attn_outputh)
    f i n a l _ o u t p u t = a t t n _ o u t p u t W O final\_output=attn\_outputW^O final_output=attn_outputWO

2. LlamaAttention类 源码解析

源码地址:transformers/src/transformers/models/llama/modeling_llama.py

# -*- coding: utf-8 -*-
# @time: 2024/8/28 15:15
import math
import torch
import torch.nn.functional as F

from typing import Optional, Tuple
from torch import nn
from transformers import LlamaConfig, Cache
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
from transformers.utils import logging

logger = logging.get_logger(__name__)


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config  # 获取配置对象
        self.layer_idx = layer_idx  # 获取当前层的索引
        # 如果没有提供 layer_idx,会警告用户这可能在使用缓存时导致错误
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout  # 从配置中获取注意力dropout的概率
        self.hidden_size = config.hidden_size  # 获取隐藏层的维度大小
        self.num_heads = config.num_attention_heads  # 获取注意力头的数量
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)  # 获取每个注意力头的维度,如果没有指定,则默认等于隐藏层维度除以注意力头的数量
        self.num_key_value_heads = config.num_key_value_heads  # 键和值的头的数量
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads  # 每组键/值头的数量
        self.max_position_embeddings = config.max_position_embeddings  # 最大位置嵌入数量
        self.rope_theta = config.rope_theta  # 旋转位置嵌入的角度参数
        self.is_causal = True  # 标志这个注意力机制是因果的,即只考虑当前位置及其之前的位置

        # 定义线性投影层
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
        # 在 v4.45 版本中移除(RoPE 在模型中计算,而不是在解码器层中)
        # 定义旋转位置嵌入(RoPE)层
        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

    def forward(
        self,
        hidden_states: torch.Tensor,  # 输入的隐藏状态
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码(可选)
        position_ids: Optional[torch.LongTensor] = None,  # 位置id(可选)
        past_key_value: Optional[Cache] = None,  # 缓存键和值(可选)
        output_attentions: bool = False,  # 是否输出注意力权重
        use_cache: bool = False,  # 是否使用缓存
        cache_position: Optional[torch.LongTensor] = None,  # 缓存位置(可选)
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # 位置嵌入,将在v4.45中作为必选项
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()  # 获取批次大小和序列长度
        # --------------------------------1. Q K V的线性计算(多处理器和单处理器)-------------------------------------#
        # 如果配置中启用了多处理器训练
        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp  # 计算每个处理器将处理的键值头和维度的切片大小
            # 将查询投影权重矩阵(self.q_proj.weight)按行拆分成多个切片,以便分配到不同的处理器上。每个切片的大小为 self.num_heads * self.head_dim 除以处理器数量
            query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)  # 将键投影权重矩阵(self.k_proj.weight)按行拆分成多个切片,每个切片的大小为之前计算的 key_value_slicing。
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)  # 将值投影权重矩阵(self.v_proj.weight)按行拆分成多个切片。每个切片的大小也为 key_value_slicing。

            # 多处理器环境下的注意力计算
            # 对每个处理器上的查询切片应用线性变换,将所有处理器的查询输出拼接在一起,形成一个完整的查询张量
            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)
            # 对每个处理器上的键切片应用线性变换,将所有处理器的键输出拼接在一起,形成一个完整的键张量
            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)
            # 对每个处理器上的值切片应用线性变换,将所有处理器的值输出拼接在一起,形成一个完整的值张量
            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:  # 单处理器,正常的注意力计算
            query_states = self.q_proj(hidden_states)  # 计算query
            key_states = self.k_proj(hidden_states)  # 计算key
            value_states = self.v_proj(hidden_states)  # 计算value

        # --------------------------------2. 调整Q K V的size, 适应多头注意力的维度格式-------------------------------------#
        # 调整query、key和value的形状,使它们符合多头注意力的格式
        # 具体维度变化为:[bsz, q_len, num_heads * head_dim] -> [bsz, q_len, num_heads, head_dim] -> [bsz, num_heads, q_len, head_dim]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # --------------------------------3. 为Q K添加位置编码-------------------------------------#
        # 如果没有提供位置嵌入,计算旋转位置嵌入;否则,直接使用位置嵌入
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings

        # apply_rotary_pos_emb 函数通过旋转位置编码对查询和键张量进行增强
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # 如果提供了缓存的键和值,更新缓存中的键和值
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # --------------------------------4. 再次调整Q K 的size, 适应注意力头的数量-------------------------------------#
        # 调整键(key)和值(value)张量的维度,以适应模型的注意力头数量。
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # --------------------------------5. 自注意力的计算-------------------------------------#
        # 5.1 计算查询和键的点积,并进行缩放
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        # 如果提供了注意力掩码,将因果掩码加到注意力权重上
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        # 5.2 计算softmax归一化后的注意力权重
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        # 5.3 将注意力权重与值张量进行矩阵乘法,以生成注意力输出
        attn_output = torch.matmul(attn_weights, value_states)

        # --------------------------------6. 自注意力size的调整和结果输出-------------------------------------#
        # 检查输出的size是否正确
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        # 调整注意力输出的size
        attn_output = attn_output.transpose(1, 2).contiguous()
        # 将多头输出拼接回原始维度
        attn_output = attn_output.reshape(bsz, q_len, -1)

        # 如果是多处理器训练
        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)  # 将 attn_output 张量沿着 dim=2 (特征维度) 拆分成多个片段,以适应分片训练的设置
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)  # 将输出投影权重 (o_proj.weight) 拆分成多个片段,以与 attn_output 对应
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])  # 对拆分的 attn_output 和权重片段进行线性变换并汇总
        else:
            attn_output = self.o_proj(attn_output)  # 通过线性投影层得到最终输出

        # 如果不需要输出注意力权重,将注意力权重设置为 None
        if not output_attentions:
            attn_weights = None

        # 返回最终的注意力输出、注意力权重和缓存的键值
        return attn_output, attn_weights, past_key_value

3. LlamaAttention类 的优化

3.1 LlamaFlashAttention2

LlamaFlashAttention2 是对 LlamaAttention 的一种优化实现,主要用于提高计算效率。它继承了 LlamaAttention 的所有权重和结构,但在前向传播过程中调用了 Flash Attention 的实现。Flash Attention 是一种高效的注意力计算方法,通过使用特定的优化技术(如顶部左对齐掩码或滑动窗口),能显著减少内存占用和计算时间。以下是 LlamaFlashAttention2 的几个关键特点:

  • 高效计算:利用 Flash Attention 提升计算效率,特别是在处理长序列时。
  • 动态掩码:支持变长序列和填充标记的处理,通过对填充标记进行适当的处理来提高精度。
  • 适配性:根据 Flash Attention 的版本调整参数,比如是否使用顶部左对齐掩码(use_top_left_mask)。

代码片段
源码地址:transformers/src/transformers/models/llama/modeling_llama.py

attn_output = _flash_attention_forward(
    query_states,
    key_states,
    value_states,
    attention_mask,
    q_len,
    position_ids=position_ids,
    dropout=dropout_rate,
    sliding_window=getattr(self, "sliding_window", None),
    use_top_left_mask=self._flash_attn_uses_top_left_mask,
    is_causal=self.is_causal,
)

3.2 LlamaSdpaAttention

LlamaSdpaAttention 是使用 torch.nn.functional.scaled_dot_product_attention 实现的注意力机制。它继承了 LlamaAttention,但在前向传播过程中适配了 SDPA(Static Dynamic Position-Aware Attention) API。以下是 LlamaSdpaAttention 的几个关键特点:

  • 使用 SDPA:通过 torch.nn.functional.scaled_dot_product_attention 实现,高效计算注意力分数。
  • 兼容性:适配了 SDPA API 的变化和特性,如处理填充标记和位置编码的不同方式。
  • 优化注意力计算:在处理包含填充标记的序列时,通过预处理和掩码调整来提高效率。

代码片段
源码地址:transformers/src/transformers/models/llama/modeling_llama.py

attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=causal_mask,
    dropout_p=self.attention_dropout if self.training else 0.0,
    is_causal=is_causal,
)

总的来说,LlamaFlashAttention2通过引入 Flash Attention 的优化技术,提高了长序列的计算效率和处理能力,特别适合需要高效处理大规模数据的场景。LlamaSdpaAttention结合 SDPA API 提供了高效的注意力计算,支持在处理填充标记和不同位置编码时的优化,适用于需要精确和高效位置感知的任务。

  • 11
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CS_木成河

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值