大模型中的 KV Cache

是否大家在部署大模型的时候,总会遇到显存不足的问题呢?明明我的设备能存下模型参数啊!凭什么超内存了了呢!?

其实,这是 KV Cache 在作祟

KV Cache 是一种大模型推理加速的方法,该方法通过缓存 Attention 中的 K, V 来实现推理优化

1.背景:大模型推理的冗余计算

观察 Only-Decoder 架构的大模型生成的过程。

假设模型只是一层 Self-Attention,用户输入“中国的首都”,模型续写得到”是北京”。

生成过程如下:

  1. 将“中国的首都”输入模型,得到每个 token 的注意力得分(绿色)。使用 “首都” 的注意力表示,预测得到下一个 token 为 ”是”(省略计算 logits 的过程)。

  2. 将 “是” 拼接到原来的输入,得到 “中国的首都是” ,将其输入模型,得到注意力表示,使用 “是” 的注意力表示,预测得到下一个 token 是 “北”。

  3. 将“北”拼接到原来的输入,依此类推,预测得到“京”,最终得到“中国的首都是北京”

    在这里插入图片描述

在每一步生成中,仅使用输入序列中的最后一个 token 的注意力表示 即可预测出下一个 token。但是模型还是并行计算了所有 token 的注意力表示,其中产生了大量的冗余计算(包含 QKV 映射,attention 计算等),并且输入长度越长,产生的冗余计算量越大。

例如:

  1. 在第一步中,仅需要使用“首都”的注意力表示,即可预测到 “是”,但模型仍然会并行计算出 “中国”、“的” 这两个 token 的注意力表示。
  2. 在第二步仅需要“是”就可以预测到“北”,但模型会并行计算之前所有的。

2.KV Cache

在这里插入图片描述

在推理阶段,当输入长度为 n,我们仅需要使用 b n b^n bn 就可以预测下一个 token,但模型会并行计算出 b 1 , b 2 , … , b n − 1 b^1, b^2,…,b^{n-1} b1,b2,,bn1 ,这里产生了大量的冗余计算。

实际上 b n b^n bn 可以直接通过公式 b n = ∑ i = 1 n s o f t m a x ( q n ⋅ k i ) v i b^n=\sum^n_{i=1}softmax(q^n·k^i)v^i bn=i=1nsoftmax(qnki)vi 计算得出,即 b n b^n bn 只与 q n q^n qn、所有的 k , v k, v k,v 有关。

KV Cache 的本质是空间换时间。他将历史输入的 token 中的 k 和 v 缓存下来,避免每一步生成都重新计算历史 token 的 k 和 v 以及 注意力表示 b 1 , b 2 , … , b n − 1 b^1, b^2,…,b^{n-1} b1,b2,,bn1 ,而是直接通过 b n = ∑ i = 1 n s o f t m a x ( q n ⋅ k i ) v i b^n=\sum^n_{i=1}softmax(q^n·k^i)v^i bn=i=1nsoftmax(qnki)vi 的方式计算得到 b n b^n bn,然后预测下一个 token。

举例,用户输入“中国的首都”,模型续写得到的输出为“是北京”,KV Cache每一步的计算过程如下。

第一步生成时,缓存 K, V 均为空,输入为 “中国的首都”,模型将按照常规的方法进行并行计算:

  1. 并行计算得到每个 token 对应的 k, v,以及注意力表示 b 1 , b 2 , b 3 b^1,b^2,b^3 b1,b2,b3
  2. 使用 b 3 b^3 b3 预测下一个 token,得到 “是”
  3. 更新缓存,令 K = [ k 1 , k 2 , k 3 ] K=[k^1,k^2,k^3] K=[k1,k2,k3] V = [ v 1 , v 2 , v 3 ] V=[v^1,v^2,v^3] V=[v1,v2,v3]

在这里插入图片描述

第二步生成时,计算流程如下:

  1. 仅将“是”输入模型,对其词向量进行映射,得到 q 4 , k 4 , v 4 q^4,k^4,v^4 q4,k4,v4
  2. 更新缓存,令 K = [ k 1 , k 2 , k 3 , k 4 ] K=[k^1,k^2,k^3,k^4] K=[k1,k2,k3,k4] V = [ v 1 , v 2 , v 3 , v 4 ] V=[v^1,v^2,v^3,v^4] V=[v1,v2,v3,v4]
  3. 计算 b 4 = ∑ i = 1 4 s o f t m a x ( q 4 ⋅ k i ) v i b^4=\sum^4_{i=1}softmax(q^4·k^i)v^i b4=i=14softmax(q4ki)vi,预测下一个 token,得到 “京”

在这里插入图片描述

上述过程中,只有第一步生成的时候,模型需要计算所有的 token 的 k , v k, v k,v,并缓存下来。此后的每一步,仅需要计算当前 token 的 q , k , v q, k, v q,k,v ,更新缓存 K , V K,V KV ,然后使用 q , K , V q, K, V q,K,V 算出当前 token 的注意力表示,预测下一个 token


KV Cache 是以空间换时间,当输入序列非常长的时候,需要缓存非常多 k 和 v,显存占用非常大。

为了缓解该问题,可以使用 MQA、GQA、Page Attention 等技术。

<think>嗯,用户问的是大模型KV Cache的实现方式,需要详细解释。首先,我得回忆一下自己对KV Cache的理解。KV Cache主要是在Transformer的自注意力机制中用来缓存键(Key)和值(Value)的,这样可以避免在生成每个token时重新计算之前的KV,节省计算资源,提高推理速度。不过具体实现细节可能比较复杂,需要分步骤说明。 首先,用户可能需要知道KV Cache的基本原理,比如预填充阶段和解码阶段的区别。预填充阶段处理整个prompt序列,生成对应的KV缓存,而解码阶段逐步生成新token,并更新缓存。这部分需要结合引用里的内容,比如引用[1]提到的两个阶段,还有引用[3]提到的更新KV Cache的过程。 然后,实现细节可能包括数据结构设计,比如如何存储KV缓存。通常是用张量来保存,每个层都有自己的KV缓存。引用[3]提到形状的变化,比如batch_size、num_heads、seq_len和维度,这部分需要详细说明。 另外,内存管理也很重要,用户可能关心内存优化方法,比如引用[2]提到的量化技术,或者动态缓存和分块存储。还有计算流程,比如在预填充阶段并行计算,解码阶段增量更新,引用[4]里的预填充阶段解释可能需要提到。 还要注意实现中的关键技术点,比如注意力掩码调整,防止越界访问,以及内存预分配策略。可能用户还会关心不同框架的具体实现差异,比如Hugging Face和vLLM的不同处理方式,这可以引用相关例子说明。 最后,优化措施也是重点,比如量化、分块、压缩,这些在引用里都有提到,需要整合到回答中。同时,生成相关问题的时候,要围绕实现方式、优化、应用场景等展开,确保问题有针对性。 总结一下,回答的结构应该是先介绍KV Cache的作用和原理,再分阶段详细说明预填充和解码,接着讲具体实现中的数据结构、内存管理、计算流程,关键技术点,不同框架的实现,优化措施,最后给出相关问题。需要确保引用正确的文献,比如引用[1][3][4]来支持各个部分的内容,同时保持语言的中文和格式要求。</think>### 大模型 KV Cache 实现方法详解 KV Cache 是大型语言模型(LLM)推理优化的核心技术,通过缓存历史 token 的键(Key)和值(Value)张量,避免重复计算,提升生成效率。以下是其实现方式的详细拆解: --- #### 一、KV Cache 的核心原理 1. **自注意力机制中的关键角色** Transformer 的自注意力计算公式为: $$ \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$ 其中 $Q,K,V$ 分别表示查询(Query)、键(Key)、值(Value)矩阵。KV Cache 的作用是缓存历史 token 的 $K$ 和 $V$ 矩阵,避免每次生成新 token 时重新计算整个序列的注意力权重[^1][^4]。 2. **空间换时间的本质** 通过牺牲存储空间(缓存历史 KV 值)换取计算时间(减少重复计算),尤其对长序列生成任务(如文本续写)效果显著[^1]。 --- #### 二、KV Cache 的实现步骤 ##### 1. **预填充阶段(Prefill Phase)** - **输入处理**:将完整的输入序列(如 prompt)一次性输入模型。 - **并行计算**:为每个 Transformer 层生成初始的 KV 缓存: ```python # 示例:PyTorch 实现逻辑 k_cache = torch.zeros(batch_size, num_heads, max_seq_len, head_dim) v_cache = torch.zeros_like(k_cache) for i in range(input_seq_len): k_cache[:, :, i, :] = current_layer.key(input_tokens[i]) v_cache[:, :, i, :] = current_layer.value(input_tokens[i]) ``` - **缓存形状**:形状为 $(batch\_size, num\_heads, seq\_len, head\_dim)$,其中 $head\_dim$ 是注意力头的维度[^3][^4]。 ##### 2. **解码阶段(Decode Phase)** - **增量更新**:每生成一个新 token,仅计算当前 token 的 $K_t$ 和 $V_t$,并追加到缓存末尾: ```python # 当前 token 的 KV 计算 k_t = current_layer.key(new_token) v_t = current_layer.value(new_token) # 更新缓存 k_cache = torch.cat([k_cache, k_t], dim=2) v_cache = torch.cat([v_cache, v_t], dim=2) ``` - **注意力计算优化**:仅需计算新 token 的 $Q$ 与历史 $K$、$V$ 的注意力权重,避免全序列重复计算[^3][^4]。 --- #### 三、关键技术实现细节 1. **内存管理策略** - **动态缓存**:根据序列长度动态扩展缓存空间(如使用链表或可扩展张量)。 - **分块存储**:将长序列拆分为固定大小的块(如 512 token/块),减少内存碎片[^2]。 2. **计算优化** - **掩码调整**:在自注意力中应用因果掩码(Causal Mask),确保新 token 仅关注历史信息。 - **内存预分配**:预先分配最大序列长度的缓存空间(如 4096 tokens),避免频繁内存操作[^3]。 3. **框架级实现差异** | 框架/库 | 实现特点 | |---------------|--------------------------------------------------------------------------| | Hugging Face | 使用 `past_key_values` 参数管理缓存,支持动态序列扩展 | | vLLM | 采用 PagedAttention 技术,将 KV Cache 分页存储以支持超长序列生成[^2] | --- #### 四、优化措施与挑战 1. **内存压缩技术** - **量化(Quantization)**:将 KV Cache 从 FP16 压缩至 INT8,内存占用减少 50%[^2]。 - **稀疏化**:通过剪枝移除冗余的 KV 条目。 2. **性能瓶颈** - **显存带宽限制**:KV Cache 频繁读写可能导致显存带宽成为瓶颈。 - **长序列稳定性**:缓存累积可能引发数值溢出或注意力权重衰减[^4]。 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值