【大模型】MHA,MQA,GQA及KV Cache详解

为什么需要KV Cache?

在自回归模型中(Autoregressive Models),模型会逐个生成文本的每个token,这个过程可能比较慢,因为模型一次只能生成一个token,而且每次新的预测都依赖于之前的上下文。这意味着,要预测第1000个token,你需要用到前999个token的信息,这通常涉及到对这些token的表示进行一系列矩阵乘法运算。而要预测第1001个token,你不仅需要前999个token的信息,还要加上第1000个token的信息。

KV Cache是在推理阶段 减少重复计算,优化推理效率的方法。在计算Attention的时候,KV Cache把所有的K和V缓存下来,然后生成next token时就可以复用这些KV。

一、Multi-Head Attention (MHA)

在进入KV Cache的正题前,我们需要对Multi-Head Attention (MHA)即标准的多头注意力机制有所了解。这里,我们首先回顾最标准的 Transformer 中的 Self-Attention 的计算及MHA部分的介绍:Transformer 模型详解

二、KV Cache

2.1 产生原因

参考博客: 用数学推导的方式理解KV Cache

对于Self-Attention 的计算有了一定了解后,下面我们一步一步用公式推导来说明KV Cache到底是如何产生的。

这里Q和K转置的矩阵乘实际上是在做向量点乘,把QKV都写成分块形式会比较清晰:

Q和K转置的矩阵乘可以写成如下形式:

注意这里小写的 k k k q q q 都是长度为 d k d_k dk 的向量. 这里不需要赋予 Q Q Q K K K 任何含义,把它们当成 T T T 个长度为 d k d_k dk 的向量即可。可以看出, Q K T QK^T QKT q i . k j q_i.k_j qi.kj 的点乘运算等价。

在序列的第t个位置,attention运算结果为:

可以看到,如果没有KV Cache的话,每次都得重新计算出所有的 k j ∗ v j k_j*v_j kjvj ,这就存在大量的重复计算。因此我们需要缓存所有的KV,这样每次新的token进来,只需要算最后一个 k t , j ∗ v t , j k_{t,j}*v_{t, j} kt,jvt,j ,而不需要算之前的KV,避免大量的重复计算,增加推理效率。

2.2 因果掩码(Causal Mask)

LLM的推理过程是自回归的,输入前 t个token,输出第 t+1个token. 然后输出的token接到输入序列最后,输入t+1个token,输出第t+2个token,如此反复。

在进行自回归推理的过程中,模型无法得到未来token的信息,也就无法知道整个序列所有位置的KV.。因此需要引入因果掩码(Causal Mask),把未来的信息掩盖掉。

具体来说,所谓的“未来”信息指的是所有的 q t ∗ k j ( j > t ) q_t * k_j (j>t) qtkj(j>t)

加入因果掩码后,序列第t个位置的attention结果正如我们所设想的变为:

2.3 推理流程

LLM的推理流程分为两部分,分别是prefill(预填充)和decode(解码)。

(1)prefill(预填充)

prefill过程是LLM解析整个prompt,并生成第一个token的过程。假设prompt的长度为T个token,对应T个query。接着进行attention运算:
在这里插入图片描述
得到的attention结果经过layernorm和feed forward层,再经过一系列处理输出第一个token.

在prefill过程中,LLM会把计算得到的 k 1 , . . . , k t ; v 1 , . . . , v t {k_1, ...,k_t; v_1,...,v_t} k1,...,kt;v1,...,vt 存入KV cache中。

(2)decode

decode过程是LLM根据上一个token输出下一个token的过程。prefill结束后,LLM输出了一个token. 将这个token作为LLM的输入,对应1个query。进行attention运算时,Q退化为向量,得到的attention结果也是向量:

在这里插入图片描述

为什么没有Q Cache?

只要理解了decode阶段的计算过程,这个问题很好解释。

  • 在推理解码阶段,输出token是一个一个生成的,transformer只需要计算出 O i O_i Oi
  • Attention机制中的第 i i i 个输出 O i O_i Oi 只和第 i i i 个query有关,和其他query无关,所以query没有必要缓存。
  • Attention的输出 O i O_i Oi 的计算和完整的K和V有关,而K、V的历史值只和历史的 O O O 有关,和当前的 O O O 无关。那么就可以通过缓存历史的K、V,避免重复计算。

三、缓解KV Cache

首先给出变量的含义:

  • d d d 代表输入维度(input dim)
  • n h n_h nh 代表头数(head数)
  • d h d_h dh 代表每个头的维度
  • h t h_t ht 代表输入的第 t 个向量
  • l l l 代表 transformer 的层数

对于标准的MHA而言,对于每一个token,KV Cache占用的缓存的大小为 2 n h d h l 2n_hd_hl 2nhdhl

由于层数 l l l 是固定的,所以KV Cache的缓解主要就在于如何减小 n h d h n_hd_h nhdh。因此,Grouped-Query Attention (GQA)、Multi-Query Attention (MQA) 、MLA (Multi-head Latent Attention) 等技术被提出,下面我们看看这几个工作是如何减少缓存占用的。

3.1 Multi-Query Attention (MQA) 、Grouped-Query Attention (GQA)

在这里插入图片描述

  • Multi-Query Attention(MQA):所有的查询头(Query Heads)共享相同的 Key 和 Value。即削减了 FLOPs,也降低了 Cache,并且压缩了频繁矩阵拼接的 I/O 耗时。

  • Grouped-Query Attention(GQA) :是 Multi-Query Attention(MQA) 的改进版,它通过在 多个查询头(Query Heads)之间共享 Key 和 Value,在 Multi-Head Attention(MHA) 和 MQA 之间找到了一种折中方案。GQA 旨在在 推理速度模型质量 之间取得更好的平衡,减少 MQA 带来的模型质量下降问题,同时仍然保留比 MHA 更快的推理速度。DeepSeekV1 67B、LLaMa2 70B和LLaMa3全系列都用了GQA。

我们再来计算下对于MQA和GQA而言,每个token需要占用的缓存量:

  • MQA:缓存大小为 2 d h l 2d_hl 2dhl ,相比于MHA的 2 n h d h l 2n_hd_hl 2nhdhl 而言大大减少,但是性能也会差一些。
  • GQA: 缓存大小为 2 n g d h l 2n_gd_hl 2ngdhl,其中 n g n_g ng 代表head的分组数。

3.2 MLA (Multi-head Latent Attention)

MLA的提出也是为了解决在推理时KV Cache占据空间过大的问题,首先给出原论文中对于MLA的示意图:

在这里插入图片描述

MLA的完整公式如下:

这里关于MLA的具体技术细节不再详细介绍了,可以参考博客:【大模型】DeepSeek核心技术之MLA (Multi-head Latent Attention)

参考资料

### MHAGQA 和 MLA 的区别及应用场合 #### 多头注意力机制(Multi-Head Attention, MHA) 多头注意力机制允许模型在不同的表示子空间中并行关注不同位置的信息。每个头独立操作,最终结果通过拼接各头的结果来获得更丰富的特征表达[^1]。 ```python import torch.nn as nn class MultiHeadAttention(nn.Module): def __init__(self, embed_size, num_heads): super(MultiHeadAttention, self).__init__() self.embed_size = embed_size self.num_heads = num_heads # 定义线性变换层 self.query_linear = nn.Linear(embed_size, embed_size) self.key_linear = nn.Linear(embed_size, embed_size) self.value_linear = nn.Linear(embed_size, embed_size) def forward(self, query, key, value): batch_size = query.size(0) # 对输入进行线性变换 Q = self.query_linear(query) K = self.key_linear(key) V = self.value_linear(value) # 将嵌入维度分割成多个头 Q = Q.view(batch_size, -1, self.num_heads, self.embed_size // self.num_heads).transpose(1, 2) K = K.view(batch_size, -1, self.num_heads, self.embed_size // self.num_heads).transpose(1, 2) V = V.view(batch_size, -1, self.num_heads, self.embed_size // self.num_heads).transpose(1, 2) # 计算注意力分数并加权求和 scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embed_size ** 0.5) attention_weights = F.softmax(scores, dim=-1) output = torch.matmul(attention_weights, V) # 合并头部并将结果传递给下一个线性层 output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_size) return output ``` #### 组化查询注意力机制(Grouped Query Attention, GQA) 为了减少计算量,GQA引入了查询分组的概念,即某些查询可以共享相同的键和值矩阵。这减少了重复计算的数量,在大规模数据集上尤其有效[^3]。 ```python class GroupedQueryAttention(nn.Module): def __init__(self, embed_size, num_groups, heads_per_group): super(GroupedQueryAttention, self).__init__() self.embed_size = embed_size self.num_groups = num_groups self.heads_per_group = heads_per_group # 初始化参数... def forward(self, queries, keys, values): # 实现GQA的具体逻辑... pass ``` #### 压缩键值注意力机制(Compressed Key/Value Attention, MLA) MLA进一步优化了资源利用效率,通过对键和值向量实施低秩近似压缩处理,从而显著降低了存储开销以及前向传播过程中的运算复杂度。 $$K_{\text{compressed}} = U_K \cdot S_K \cdot V_K^T$$ $$V_{\text{compressed}} = U_V \cdot S_V \cdot V_V^T$$ ```python from scipy.linalg import svd def compress_matrix(matrix, rank): u, s, vh = svd(matrix) compressed = np.dot(u[:, :rank], np.dot(np.diag(s[:rank]), vh[:rank, :])) return compressed class CompressedKeyValueAttention(nn.Module): def __init__(self, embed_size, compression_rank): super(CompressedKeyValueAttention, self).__init__() self.compress_key = lambda k: compress_matrix(k, compression_rank) self.compress_value = lambda v: compress_matrix(v, compression_rank) # 其他初始化... def forward(self, queries, keys, values): compressed_keys = self.compress_key(keys) compressed_values = self.compress_value(values) # 使用压缩后的keys和values继续执行标准的注意力机制... pass ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值