Hugging face KV Cache代码解读

KV Cache的原理可以参考这篇文章,非常简洁明了:https://zhuanlan.zhihu.com/p/679249229

Hungging Face对于KV Cache的实现代码在transformers/models/gpt2/modeling_gpt2.py 文件中的GPT2Attention中实现,这里我们给出①仅仅 KV Cache的实现代码以及②GPT2Attention整体代码的解读。

KV Cache的实现代码:

query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

# 以下使用了 kv cache
# 将 key、query、value向量分割成多头
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

# 如果存在先前的层状态,将其与当前层的 key、value 拼接
if layer_past is not None:
    past_key, past_value = layer_past
    key = torch.cat((past_key, key), dim=-2)
    value = torch.cat((past_value, value), dim=-2)

# 如果使用缓存,保存当前的 key、value
if use_cache:
    present = (key, value)
else:
    present = None

# 根据配置,选择标准或优化的注意力计算方法
if self.reorder_and_upcast_attn:
    attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
else:
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

# 合并多头注意力的输出
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
# 通过投影层处理注意力输出
attn_output = self.c_proj(attn_output)
# 应用残差连接的dropout
attn_output = self.resid_dropout(attn_output)

# 准备输出,包括注意力的输出和可选的缓存
output = (attn_output, present)

GPT2Attention整体代码的解读:

# kv cache 在 hugging face 的 transformer库 的 transformers/models/gpt2/modeling_gpt2.py 文件中实现
"""
GPT2Attention实现包含以下几个部分:
    1D卷积层: 在OpenAI GPT和GPT-2中被用于处理序列数据
    动态剪枝功能,允许在训练过程中或之后移除不重要的注意力头。
    标准和优化的注意力计算方法 attention 和 `torch.baddbmm`进行更高效的批量矩阵乘法和加法。
    KV Cache: https://zhuanlan.zhihu.com/p/679249229
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

import math
from typing import Optional, Tuple, Union, List, Set


class Conv1D(nn.Module):
    """
    实现了一个1D卷积层,该层的工作原理类似于线性层,但其权重是转置的。这种层在OpenAI GPT和GPT-2中被用于处理序列数据。
    
    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.

    例子: 
    实例化Conv1D类:
        conv1d = Conv1D(nx=128, nf=64):创建了一个Conv1D的实例,输入特征数为128,输出特征数为64。
        此时,self.weight和self.bias参数被随机初始化。self.weight的形状为(128, 64),self.bias的形状为(64,)。
    输入张量:
        x = torch.randn(10, 20, 128):生成了一个形状为(10, 20, 128)的随机张量,代表有10个样本,每个样本是一个
        长度为20的序列,每个序列元素具有128个特征。
    通过Conv1D层传递输入:
        output = conv1d(x) 将输入x传递给Conv1D实例conv1d。在Conv1D的forward方法中,执行以下操作:
        size_out计算得到的新形状为(10, 20, 64),意味着输出将保持批量和序列长度不变,但特征维度变为64。
        torch.addmm利用self.bias和self.weight对输入x执行线性变换。首先,将x重塑为(200, 128)以匹配权重矩阵,然
        后执行矩阵乘法和加法操作,得到形状为(200, 64)的中间结果。
            x.view(-1, x.size(-1): (200, 128) 
            weight: (128, 64)
            x.view(-1, x.size(-1) * weight : (200, 64)
            然后将self.bias加到每一行上,最终输出也是(200, 64)的形状。
        最后,中间结果被重塑回(10, 20, 64),符合size_out的预期形状。
    """

    def __init__(self, nf, nx):
        super().__init__()
        # 初始化输出特征数和权重、偏置参数
        self.nf = nf    # 输出特征的数量
        self.weight = nn.Parameter(torch.empty(nx, nf)) # 权重参数,形状为 (输入特征数, 输出特征数)
        self.bias = nn.Parameter(torch.zeros(nf))       # 偏置参数,形状为 (输出特征数,)
        nn.init.normal_(self.weight, std=0.02)          # 权重初始化,标准差为0.02

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)           # 计算输出的尺寸,保持除最后一个维度外的其他维度不变,最后一个维度设置为输出特征数
        # 使用 addmm 进行矩阵乘法和加法运算,相当于执行了线性变换
        # x.view(-1, x.size(-1)) 将输入张量x重塑为二维张量,以便进行矩阵乘法
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        # 将结果重塑回原来的尺寸(加上输出特征数的维度)
        x = x.view(size_out)
        return x


def find_pruneable_heads_and_indices(
    heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
) -> Tuple[Set[int], torch.LongTensor]:
    """
    根据已经剪枝的头,找出可剪枝的头及其索引。

    参数:
    - heads (`List[int]`): 想要剪枝的头的索引列表。
    - n_heads (`int`): 模型中的头总数。
    - head_size (`int`): 每个头的大小。
    - already_pruned_heads (`Set[int]`): 已经被剪枝的头的集合。

    返回:
    - `Tuple[Set[int], torch.LongTensor]`: 考虑到`already_pruned_heads`后,需要被剪枝的头的索引集合,
      以及在层权重中需要保留的行/列的索引。

    例子:
        n_heads = 6                        模型中的头总数
        head_size = 2                      每个头的大小
        heads = [1, 2, 4]                  想要剪枝的头的索引列表
        already_pruned_heads = {0, 2, 3}   已经被剪枝的头的集合
    创建掩码: mask = [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]
    更新要剪枝的头集合: heads = {1, 2, 4} - {0, 2, 3} = {1, 4}
    遍历想要剪枝的头的索引集合,更新掩码:
    由于头0、头2和头3已经被剪枝,我们只需要考虑其余的头。我们发现头1和头4需要被剪枝。
    注意,此时的"头2"和"头3"在already_pruned_heads中意味着它们已经不在考虑之列,所以我们不需要为它们调整索引。
    更新掩码,标记头1和头4将被剪枝:mask = [[1, 1], [0, 0], [1, 1], [1, 1], [0, 0], [1, 1]]
    mask_flat = [1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1]
    保留的索引是值为1的位置,对应于未被剪枝的头(包括已经因为already_pruned_heads被间接考虑的头)。

    """
    # 创建一个形状为(n_heads, head_size)的全1掩码张量
    mask = torch.ones(n_heads, head_size)
    # 将输入的头索引列表转换为集合,并从中移除已经剪枝的头
    heads = set(heads) - already_pruned_heads
    # 遍历想要剪枝的头的索引集合
    for head in heads:
        # 计算在指定头之前已经剪枝的头的数量,并据此调整头的索引
        head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
        # 在掩码中将对应的头的位置设为0,表示该头将被剪枝
        mask[head] = 0
    
    # 将掩码张量展平并获取等于1的元素的索引,表示这些位置对应的头不被剪枝
    mask = mask.view(-1).contiguous().eq(1)
    index: torch.LongTensor = torch.arange(len(mask))[mask].long()
    # 返回需要剪枝的头的索引集合和在层权重中保留的行/列的索引
    return heads, index


def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
    """
    prune_conv1d_layer函数的目的是对Conv1D层进行剪枝,使其只保留指定索引index中的条目。这个过程类似于在全
    连接层(线性层)中移除一些神经元或头。具体到Conv1D层,这通常用于在Transformer模型中移除不需要的注意力头。
    
    Prune a Conv1D layer to keep only entries in index. A Conv1D work as a Linear layer (see e.g. BERT) but the weights
    are transposed.

    Used to remove heads.

    Args:
        layer ([`~pytorch_utils.Conv1D`]): The layer to prune.          需要被剪枝的Conv1D层。
        index (`torch.LongTensor`): The indices to keep in the layer.   在层中需要保留的索引。
        dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices. 指定保留索引的维度,默认为1,对应于权重矩阵的列。
    Returns:
        [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
        经过剪枝后的新Conv1D层,保留了指定索引的条目,且requires_grad=True,允许梯度更新。

    例子:
    权重矩阵 (形状为 [nx, nf],假设 nx=3, nf=4,即3个输入特征和4个输出特征):
        [[0.2, 0.3, 0.4, 0.5],
         [0.1, 0.2, 0.3, 0.4],
         [0.5, 0.6, 0.7, 0.8]]
    偏置向量 (长度为 nf=4):
        [0.1, 0.2, 0.3, 0.4]
    剪枝索引 (index,假设我们只想保留第0和第2个输出特征,即索引为 [0, 2])
    
    剪枝过程
        定剪枝维度和索引:我们想要在输出特征上进行剪枝,因此dim=1。保留的索引是[0, 2]。
    选择并克隆权重和偏置:
        选择权重矩阵的子集(沿着输出特征维度),得到新的权重矩阵W:
            [[0.2, 0.4],
            [0.1, 0.3],
            [0.5, 0.7]]
    由于我们在dim=1维度上进行剪枝,对应的偏置向量b也需要更新:
        [0.1, 0.3]
    创建新的Conv1D层:
        新的层将有3个输入特征和2个输出特征(根据剪枝后保留的索引数量)。
    更新新层的权重和偏置:
        新层的权重矩阵和偏置向量分别被更新为剪枝后的W和b。

    """
    # 确保索引index和层的权重在相同的设备上(例如CPU或GPU)。
    index = index.to(layer.weight.device)
    # 根据dim和index选择权重的子集,并进行克隆和分离操作,以便创建权重的副本而不影响原始层的权重。
    # 对于偏置,如果dim是0,保持偏置不变;如果是1,则只保留索引对应的偏置条目。
    W = layer.weight.index_select(dim, index).clone().detach()
    if dim == 0:
        b = layer.bias.clone().detach()
    else:
        b = layer.bias[index].clone().detach()
    # 根据剪枝后的尺寸创建一个新的Conv1D层,并确保新层在与原始层相同的设备上。
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
    # 更新新层的权重和偏置:首先禁用新层权重和偏置的梯度更新,
    # 然后将剪枝后的权重和偏置复制到新层,
    # 最后重新启用梯度更新。这确保了新层可以在后续的训练中更新。
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    new_layer.bias.requires_grad = False
    new_layer.bias.copy_(b.contiguous())
    new_layer.bias.requires_grad = True
    return new_layer


class GPT2Attention(nn.Module):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()

        # 注册一个因果掩码,用于自注意力中, 模型只能关注当前位置及之前的位置, 以避免信息泄露和未来信息访问
        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bools)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False
        )

        # 注册一个掩码偏置值,用于在计算softmax之前将某些位置的注意力分数设置为非常小的值(接近于负无穷),
        # 这样在应用softmax时,这些位置的权重接近于0,实现了“掩码”的效果。
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size                 # 模型的嵌入维度
        self.num_heads = config.num_attention_heads         # 注意力头的数量
        self.head_dim = self.embed_dim // self.num_heads    # 每个注意力头的维度
        self.split_size = self.embed_dim
        # 检查嵌入维度是否能被注意力头数量整除,这是多头注意力分割的前提
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights # 是否对注意力分数进行缩放,以避免过大的分数导致梯度消失或爆炸
        self.is_cross_attention = is_cross_attention        # 标记是否为交叉注意力模式

        # 根据层索引逆比例缩放注意力权重,这个设置允许模型在更深的层使用更小的缩放因子,有助于控制深层网络中的梯度流动。
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        # 存储当前层的索引,这在模型中是静态的,但对于实现逐层缩放的功能是必要的
        self.layer_idx = layer_idx
        # 控制是否对注意力机制的内部计算进行重排序和上转型,以提高计算效率或减少内存消耗
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        # 根据是否为交叉注意力, 初始化不同的卷积层来生成查询Q、键K、值V
        if self.is_cross_attention:
            # 交叉注意力模型下, 查询来自一个序列(如解码器的隐藏状态),键和值来自另一个序列(如编码器的隐藏状态)
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) # 生成 K 和 V
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)     # 生成 Q
        else:
            # 自注意力下,查询、键、值都来自于同一个序列
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        # 注意力计算后,对输出进行投影的卷积层
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        # Dropout层,用于减少过拟合
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        # 记录被剪枝的头部,用于调整模型结构或减少计算量
        self.pruned_heads = set()  

    def prune_heads(self, heads):
        """动态剪枝功能,允许在训练过程中或之后移除不重要的注意力头。"""
        if len(heads) == 0:
            # 这里首先检查传入的heads列表是否为空。如果是空的,说明没有指定任何头进行剪枝,因此直接返回,不做任何处理。
            return
        
        # 调用find_pruneable_heads_and_indices函数,根据指定要剪枝的头的索引,找到这些头在模型中对应的位置。
        # 这个函数返回两个值:经过处理的heads列表(可能排除了一些已经被剪枝的头),以及这些头在卷积层权重中对应的索引。
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
        # 生成一个新的索引列表index_attn,这个列表包含了要剪枝的头在卷积层权重中对应的所有位置。由于c_attn卷积层
        # 同时负责生成查询(Q)、键(K)和值(V),所以需要将索引扩展到这三部分。
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # 调用 prune_conv1d_layer 函数,实际上进行剪枝操作。对于c_attn卷积层,根据index_attn移除对应的权重;
        # 对于c_proj卷积层,只需要根据原始索引移除权重,因为它只负责输出的投影
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # 这一步实质上是重新计算在剪枝后,如何按照剩余的头数分配输入向量的维度,确保每个头处理的维度大小保持一致,并且总和等于嵌入维度self.embed_dim(假设剪枝操作不会改变总的嵌入维度)。
        # self.split_size // self.num_heads: 这部分计算在剪枝之前,每个头处理的维度大小。由于self.split_size原本等于嵌入维度,这里通过除以头的总数self.num_heads得到每个头负责的维度大小。
        # self.num_heads - len(heads): 这部分计算剪枝后剩余的头的数量。从总的头数中减去被剪掉的头数len(heads),得到剪枝后剩余的头数。
        # self.split_size // self.num_heads) * (self.num_heads - len(heads)): 将每个头负责的维度大小乘以剪枝后剩余的头数,得到剪枝后的split_size。
        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
        self.num_heads = self.num_heads - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        # 可能会用 kv cache, 所以 query 和 key 的长度不一定一样
        # 计算注意力分数,查询和键的点积
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            # 如果启用了注意力权重缩放(通常为了稳定训练过程),则通过除以值(value)维度的平方根来缩放注意力权重。这种缩放有助于控制梯度的大小,防止训练过程中出现梯度爆炸。
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )
        
        if self.scale_attn_by_inverse_layer_idx:
            # 如果启用了层级注意力缩放,则进一步通过层索引来调整注意力权重的缩放。对于模型较深的层,这可以提供额外的缩放,有助于调节不同层之间的贡献度。
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            """
            对于自注意力层,应用因果掩码(causal_mask)以防止位置i注意到i之后的位置,这对于保持自回归特性至关重要。掩码通过将未来位置的权重设置为一个非常小的值(接近负无穷)来实现
            通过一个具体的矩阵示例来说明这个过程
            假设我们有一个最大位置(max_positions)为4的模型
                [[1, 0, 0, 0],
                 [1, 1, 0, 0],
                 [1, 1, 1, 0],
                 [1, 1, 1, 1]]
            假设我们正在处理一个序列,其query_length为3,key_length也为3。
            key_length - query_length : key_length,结果是0:3
            causal_mask [[1, 0, 0],
                         [1, 1, 0],
                         [1, 1, 1]]
            注意力权重矩阵 (attn_weights),这个矩阵反映了序列中每个元素对其他元素的原始注意力得分。
            attn_weights (示例) =
                [[0.8, 0.1, 0.1],
                 [0.1, 0.7, 0.2],
                 [0.1, 0.2, 0.7]]

            mask_value 设为 -∞,应用掩码后的 attn_weights 可能看起来像这样
                [[0.8,  -∞,  -∞],
                 [0.1, 0.7,  -∞],
                 [0.1, 0.2, 0.7]]

            经过softmax操作后,attn_weights 可能会变为
                [[1.0,  0,    0],
                 [0.3, 0.7,  0],
                 [0.2, 0.3, 0.5]]

                 
            可能会用 kv cache, 所以 query 和 key 的长度不一定一样
            如果query_length和key_length不相等,会动态调整因果掩码的大小,以确保掩码适用于当前的序列长度,保持模型的自回归特性。
            从self.bias选择的部分 -> [:, :, 1:3, :3]
                得到的掩码 ->
                [[1, 0, 0],
                 [1, 1, 0]]
            然后,类似于之前的解释,这个掩码会被用来调整attn_weights,使得在注意力计算中,被掩码的位置(即掩码值为0的位置)的影响被忽略
            """
            # 如果是自注意力,它计算查询(query)和键(key)的序列长度。这里使用.size(-2)获取序列长度,因为注意力相关的张量通常具有形状[batch_size, seq_length, feature_dim]。
            query_length, key_length = query.size(-2), key.size(-2)
            # 使用预先注册的bias张量来获取因果掩码。bias是一个下三角矩阵, 通过切片操作调整掩码的形状,选取了一个与当前 query 和 key 大小相匹配的子矩阵, 使其与当前的查询和键的序列长度匹配。
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            # mask_value被设置为一个非常小的数,几乎等同于负无穷大,以便在应用softmax之前有效地“忽略”被掩码的位置。
            mask_value = torch.finfo(attn_weights.dtype).min
            # torch.full([], mask_value, dtype=attn_weights.dtype) 创建了一个填充了mask_value的张量,这里的[]表示张量的形状,空列表意味着创建一个标量张量(即只包含单个值的张量)
            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            # 使用torch.where,将掩码为0的位置的attn_weights替换为 mask_value
            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
        
        if attention_mask is not None:
            # 如果提供了额外的注意力掩码(如用于屏蔽填充位置),将其加到注意力权重上。这允许模型在计算注意力时忽略特定的位置。
            # attention_mask中每个需要被屏蔽的位置在掩码中设置一个非常大的负值(例如,-10000或负无穷),所以是 +
            attn_weights = attn_weights + attention_mask

        # 应用softmax函数对注意力权重进行归一化,确保所有权重的和为1,这样每个位置的贡献度被限制在0到1之间。
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # 确保注意力权重与值(value)的数据类型一致,然后应用dropout以减少过拟合。
        # 因为不同数据类型的张量直接进行运算可能会导致类型不匹配的错误。例如,一个float32类型的张量和一个float16类型的张量直接进行运算,可能会引发错误或警告,除非显式地进行类型转换
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights
    
    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        """
        优化的注意力计算方法,特别适用于处理大规模数据或在混合精度训练中提高性能。与`_attn`函数相似,但使用`torch.baddbmm`进行更高效的批量矩阵乘法和加法。
        torch.baddbmm是PyTorch中的一个函数,用于执行批量的矩阵乘法和加法。它的全名是batch matrix-matrix product of matrices with addition(带加法的批量矩阵乘法)。这个函数的基本操作可以表示为:
        output[i] = beta * input[i] + alpha * (batch1[i] @ batch2[i])
        output= β x batch1+ a x (batch2 x batch3)
            batch1:加法的第一个矩阵,它的形状是 (b,n,m)。
            batch2:乘法中的第一个批量矩阵,形状为 (b,n,p)。
            batch3:乘法中的第二个批量矩阵,形状为 (b,p,m)。
            a (alpha)和 β(beta)是缩放因子,分别应用于乘法和加法结果。
            output:操作的结果,形状也是 (b,n,m)。
        在这个方法中,torch.baddbmm被用于计算批量的点积(乘法)操作,并可选地将结果与另一个张量(在这种情况下是初始化为零的张量)相加。这里主要是利用它进行点积计算,以获得注意力机制中的原始注意力权重。

        """
        # 首先,获取批次大小(bsz)、注意力头数(num_heads)、查询序列长度(q_seq_len)和键维度(dk)。这些尺寸用于后续的矩阵重塑和计算。
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # 预分配注意力权重矩阵,准备进行baddbmm操作。这里创建了一个空的张量,大小为(bsz * num_heads, q_seq_len, k_seq_len),即每个头对每个查询与所有键的注意力权重
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32)

        # 计算缩放因子,首先是基于值向量维度的根号倒数,然后根据层索引进行进一步的调整。这两个操作旨在提高模型的稳定性和性能。
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5
        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)
        
        with autocast(enable=False):
            # 这里,query和key首先被重塑或转置以满足矩阵乘法的需求(b,n),(n,m)。然后,torch.baddbmm执行以下操作:
            # batch2 (q.float()) 和 batch3 (k.float()) 进行矩阵乘法。这部分对应于计算查询和键之间的点积,得到注意力分数。
            # beta 设置为0,意味着attn_weights(加法的第一个矩阵)在这个操作中实际上不起作用,因为任何数乘以0都是0。因此,加法部分可以忽略。
            # alpha 是缩放因子,用于调整点积的结果,以保持数值稳定性。在这个上下文中,它可能用于缩放点积结果以防止梯度消失或爆炸,特别是当值(value)的维度很大时。
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
        
        # 下面部分 和 _attn部分一样
        if not self.is_cross_attention:
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights
    
    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        将模型的输入张量重新组织,以支持多头注意力计算。

        参数:
        - tensor: 输入张量,形状为 (batch_size, seq_length, hidden_size)。
        - num_heads: 注意力头的数量。
        - attn_head_size: 每个注意力头的特征维度。

        返回:
        - 重新组织后的张量,形状为 (batch_size, num_heads, seq_length, attn_head_size)。
        """
        # 计算新的形状,将 hidden_size 维度拆分为 num_heads 和 attn_head_size 两个维度。
        # 原始张量的形状是(batch_size, seq_length, hidden_size),tensor.size()[:-1]将会得到(batch_size, seq_length)。
        # + (num_heads, attn_head_size):
        # 这里,+操作符用于元组的连接。它将(batch_size, seq_length)与(num_heads, attn_head_size)连接起来,形成新的形状。
        # torch.size()[:-1] + (num_heads, attn_head_size) 得到的新形状是(batch_size, seq_length, num_heads, attn_head_size)
        new_shape = torch.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        # 重新排列张量的维度,使其符合多头计算的需求。
        return tensor.permute(0, 2, 1, 3)   # (batch, head, seq_length, head_features)
    
    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        将多头注意力计算的输出合并为单个张量,用于后续处理。

        参数:
        - tensor: 输入张量,形状为 (batch_size, num_heads, seq_length, attn_head_size)。
        - num_heads: 注意力头的数量。
        - attn_head_size: 每个注意力头的特征维度。

        返回:
        - 合并后的张量,形状为 (batch_size, seq_length, hidden_size)。
        """
        # 重新排列张量的维度,准备合并注意力头。
        tensor = tensor.permute(0, 2, 1, 3).contiguous()  # 调整后的形状为 (batch_size, seq_length, num_heads, attn_head_size)
        # 计算合并后的新形状,将 num_heads 和 attn_head_size 两个维度合并为 hidden_size 维度。
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        # 重塑张量的形状,完成注意力头的合并。
        return tensor.view(new_shape)
    
    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],              # 当前层的输入张量。
        layer_past: Optional[Tuple[torch.Tensor]] = None,               # 之前层的输出,用于键值对缓存。kv cache
        attention_mask: Optional[torch.FloatTensor] = None,             # 可选的掩码张量,用于屏蔽某些输入的注意力
        head_mask: Optional[torch.FloatTensor] = None,                  # 可选的掩码张量,用于屏蔽某些注意力头
        encoder_hidden_states: Optional[torch.Tensor] = None,           # 编码器的隐藏状态,仅在交叉注意力中使用
        encoder_attention_mask: Optional[torch.FloatTensor] = None,     # 编码器的注意力掩码,仅在交叉注意力中使用
        use_cache: Optional[bool] = False,                              # 是否使用键值对缓存
        output_attentions: Optional[bool] = False,                      # 是否输出注意力权重
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        # 判断是否为交叉注意力模式
        if encoder_hidden_states is not None:
            # 如果是交叉注意力,确保定义了`q_attn`权重
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            # 生成查询向量
            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            # 使用编码器的注意力掩码
            attention_mask = encoder_attention_mask
        else:
            # 对于自注意力,直接从隐藏状态生成key、query、value
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        # 以下使用了 kv cache
        # 将 key、query、value向量分割成多头
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)
        
        # 如果存在先前的层状态,将其与当前层的 key、value 拼接
        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        # 如果使用缓存,保存当前的 key、value
        if use_cache:
            present = (key, value)
        else:
            present = None

        # 根据配置,选择标准或优化的注意力计算方法
        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        # 合并多头注意力的输出
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        # 通过投影层处理注意力输出
        attn_output = self.c_proj(attn_output)
        # 应用残差连接的dropout
        attn_output = self.resid_dropout(attn_output)

        # 准备输出,包括注意力的输出和可选的缓存
        output = (attn_output, present)
        # 如果请求输出注意力权重,加入到输出中
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # 返回输出,可能包括注意力输出、缓存的键值对和注意力权重

代码注释较长,请耐心查看。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值