Transformer 中的 KV Cache:原理、作用与应用
在大规模 Transformer 模型(如 GPT-3、LLaMA、ChatGPT)中,KV Cache(Key-Value 缓存) 是优化推理效率的重要机制,特别是在自回归生成任务(如文本生成、对话系统)中起到了至关重要的作用。它能有效减少重复计算,提高推理速度,同时降低计算成本。那么,KV Cache 到底是什么?为什么需要 KV Cache?它的工作原理是什么? 本文将深入剖析 Transformer 中的 KV Cache。
动图解释的很好的文章:Transformers KV Caching Explained
1. 什么是 KV Cache?
在 Transformer 模型中,KV Cache(Key-Value Cache,键值缓存) 指的是 自回归推理过程中缓存的 Key 和 Value,以避免重复计算。它主要用于 自注意力机制(Self-Attention),在 解码(Decoder) 过程中缓存之前计算过的 Key(键)和 Value(值),从而加速生成。
在没有 KV Cache 的情况下,每次生成一个新的 token 时,模型都需要重新计算所有 token 的 Query(查询)、Key(键)、Value(值),导致计算冗余。而使用 KV Cache 后,只需要计算新 token 的 Query,并与已缓存的 KV 进行注意力计算,大幅提高推理速度。
2. 为什么需要 KV Cache?
在 Transformer 解码(文本生成) 过程中,每生成一个新的 token,模型都会重新计算整个输入序列的注意力。这种 O(N²) 的计算复杂度 在序列较长时会导致推理速度大幅下降,主要有以下问题:
-
计算冗余:
- 在每个时间步,模型都会重新计算所有已生成 token 的 Key 和 Value。
- 但这些 Key 和 Value 在过去的计算中已经得到,理论上可以复用,避免重复计算。
-
推理速度瓶颈:
- 对于长文本生成,每个 token 都要计算完整的
QK^T
矩阵,计算复杂度为 O(N²),导致推理速度下降。 - 随着上下文长度增长,计算量呈二次增长,使得大模型推理成本昂贵。
- 对于长文本生成,每个 token 都要计算完整的
-
加速解码:
- 通过缓存 Key 和 Value,解码器在生成新 token 时只需要计算当前 token 的 Query 并与缓存的 KV 交互,计算复杂度从 O(N²) 降到 O(N),显著加速推理。
3. Transformer 中 KV Cache 的工作原理
3.1 自注意力机制回顾
在标准 Transformer 自注意力(Self-Attention)中,每个 token 通过 QKV 机制 进行计算:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax} \left( \frac{QK^T}{\sqrt{d_k}} \right) V Attention(Q,K,V)=softmax(dkQKT)V
其中:
Q
(Query):查询向量K
(Key):键向量V
(Value):值向量- d k d_k dk:缩放因子
计算步骤:
- 计算
Q
,K
,V
矩阵 - 计算
QK^T
相关性矩阵 - 对
QK^T
进行 softmax 归一化 - 加权求和得到输出
在训练阶段,所有 token 可以并行计算 QK^T
,但在 自回归推理(Auto-Regressive Decoding)时,每个 token 只能依赖前面已经生成的 token,这就是 KV Cache 需要优化的地方。
3.2 KV Cache 的优化方式
在自回归解码(生成文本)过程中,我们的目标是 避免重复计算 K
和 V
。KV Cache 采用增量缓存(Incremental Cache)的方式存储 Key 和 Value:
- 第一步(初始计算):
- 计算第一个 token 的
K_1, V_1
,存入缓存。
- 计算第一个 token 的
- 第二步(生成新 token 时):
- 计算新 token 的 Query
Q_n
,但 不计算之前 token 的 Key 和 Value,直接从缓存读取K_1, K_2, ..., K_{n-1}
和V_1, V_2, ..., V_{n-1}
。 - 只计算
Q_nK^T
,然后进行注意力计算。
- 计算新 token 的 Query
最终,存储 K
和 V
,仅计算 Q
,可以大幅减少计算量:
Attention ( Q n , [ K 1 , K 2 , . . . , K n − 1 ] , [ V 1 , V 2 , . . . , V n − 1 ] ) \text{Attention}(Q_n, [K_1, K_2, ..., K_{n-1}], [V_1, V_2, ..., V_{n-1}]) Attention(Qn,[K1,K2,...,Kn−1],[V1,V2,...,Vn−1])
这样,计算复杂度由 O(N²) 降至 O(N),随着序列长度增长,计算成本大幅降低。
4. KV Cache 的实际应用
4.1 大规模语言模型(LLMs)
在 GPT-3、LLaMA、ChatGPT 等 自回归语言模型中,KV Cache 是必备优化策略,用于减少推理计算量。例如:
- OpenAI GPT 系列:使用 KV Cache 加速生成,提高 token 生成速度。
- Meta LLaMA:优化 KV Cache 以支持更长上下文的高效推理。
- Google PaLM / Gemini:在多模态生成中采用 KV Cache 来优化长文本处理。
4.2 推理引擎优化
在大规模推理框架(如 TensorRT、ONNX Runtime、vLLM)中,KV Cache 已成为默认优化技术:
- TensorRT-LLM:专门优化 KV Cache 存储和访问,减少显存占用。
- vLLM:基于 Paged KV Cache 技术,支持超长文本高效生成。
5. KV Cache 的挑战与未来优化
5.1 显存占用
- KV Cache 需要存储所有已生成 token 的 Key 和 Value,对长序列推理时 显存占用大,尤其是在多头注意力(Multi-Head Attention)中,每一层都需要缓存
K, V
。 - 解决方案:
- Paged KV Cache:采用分块存储,减少 GPU 显存压力。
- FlashAttention:优化 GPU 访问 KV Cache 的方式,降低显存消耗。
5.2 动态 KV Cache
- 传统 KV Cache 需要线性存储
K, V
,但在多轮对话、长文本生成中,需要删除无用缓存,避免显存爆炸。 - 解决方案:
- Sliding Window KV Cache:仅保留最近的
N
个 token 进行 KV 计算。 - 精细化 KV 复用:减少长序列存储需求,提高缓存利用率。
- Sliding Window KV Cache:仅保留最近的
6. 总结
KV Cache 是 Transformer 推理中的关键优化技术,它通过 缓存 Key 和 Value,避免重复计算,显著提升了文本生成任务的推理效率。它的核心优势包括:
✅ 计算复杂度从 O(N²) 降至 O(N),大幅提升推理速度
✅ 减少 GPU 计算负担,支持更长的上下文处理
✅ 广泛应用于 GPT、LLaMA、PaLM 等大型语言模型
如何通过实际例子模拟 KV Cache,并详细说明如何降低计算复杂度到 O(N)
在 Transformer 解码过程中,自回归(Auto-Regressive)生成是常见的文本生成方式,如 GPT-3、LLaMA 等大语言模型(LLMs)在推理时的工作模式。KV Cache(Key-Value Cache) 的核心思想是 缓存 Key 和 Value,避免重复计算,从而将计算复杂度 从 O(N²) 降低到 O(N)。
本篇文章将使用 Python + NumPy 模拟 KV Cache 机制,并详细解释如何优化计算复杂度。
1. KV Cache 在 Transformer 解码中的作用
1.1 Transformer 解码过程
在 Transformer 生成模型(如 GPT)中,每一步生成新的 token 时,都需要计算自注意力(Self-Attention):
Attention
(
Q
,
K
,
V
)
=
softmax
(
Q
K
T
d
k
)
V
\text{Attention}(Q, K, V) = \text{softmax} \left( \frac{QK^T}{\sqrt{d_k}} \right) V
Attention(Q,K,V)=softmax(dkQKT)V
如果没有 KV Cache,每个新 token 生成时,必须重新计算整个序列的 K
和 V
,导致:
- 计算复杂度:每次计算
QK^T
,复杂度为 O(N²) - 存储开销:所有历史
K, V
重新计算,造成计算冗余
1.2 使用 KV Cache
- 计算第一个 token 时,存储
K_1, V_1
- 计算第二个 token 时,只计算
Q_2
,复用K_1, V_1
- 计算第三个 token 时,只计算
Q_3
,复用K_1, K_2, V_1, V_2
- 以此类推……
这样,每次生成新 token 只需计算当前 token 的 Q,并与已有的 KV Cache 进行计算,避免重复计算 K
和 V
,从而将计算复杂度 从 O(N²) 降低到 O(N)。
2. 代码实现 KV Cache
接下来,我们使用 Python + NumPy 来模拟 Transformer 中 KV Cache 的工作流程。
2.1 不使用 KV Cache 的方法(O(N²))
import numpy as np
# 假设 d_k = 4(注意力头维度)
d_k = 4
N = 5 # 生成 5 个 token
# 随机初始化 Q, K, V
np.random.seed(42)
Q = np.random.rand(N, d_k) # (N, d_k)
K = np.random.rand(N, d_k) # (N, d_k)
V = np.random.rand(N, d_k) # (N, d_k)
# 计算 Attention (O(N²) 复杂度)
attention_scores = np.dot(Q, K.T) / np.sqrt(d_k)
attention_weights = np.exp(attention_scores) / np.sum(np.exp(attention_scores), axis=1, keepdims=True)
output = np.dot(attention_weights, V)
print("Attention Output:\n", output)
分析
- 这里
np.dot(Q, K.T)
计算QK^T
,复杂度 O(N²) np.dot(attention_weights, V)
计算注意力加权和,复杂度 O(N²)- 计算量随序列长度增长呈二次增长,导致推理变慢
2.2 使用 KV Cache(O(N) 复杂度)
优化方案:
- 存储
K
和V
,避免重复计算 - 每次新生成 token 只计算
Q_nK^T
,而不重新计算K
- 计算量随 N 线性增长,即 O(N) 复杂度
class KVCacheTransformer:
def __init__(self, d_k):
self.d_k = d_k
self.K_cache = [] # 存储 Key
self.V_cache = [] # 存储 Value
def add_to_cache(self, K_new, V_new):
"""将新的 Key, Value 添加到缓存"""
self.K_cache.append(K_new)
self.V_cache.append(V_new)
def compute_attention(self, Q_new):
"""计算注意力(仅使用缓存的 KV)"""
if not self.K_cache:
return np.zeros(self.d_k) # 空缓存返回零向量
K_cache = np.vstack(self.K_cache) # (N, d_k)
V_cache = np.vstack(self.V_cache) # (N, d_k)
# 计算 Q_newK^T / sqrt(d_k)
attention_scores = np.dot(Q_new, K_cache.T) / np.sqrt(self.d_k)
attention_weights = np.exp(attention_scores) / np.sum(np.exp(attention_scores), keepdims=True)
# 计算 Attention Output
return np.dot(attention_weights, V_cache)
# 初始化 Transformer
transformer = KVCacheTransformer(d_k)
# 逐步生成 5 个 token
for i in range(5):
# 生成新的 Q, K, V
Q_new = np.random.rand(d_k)
K_new = np.random.rand(d_k)
V_new = np.random.rand(d_k)
# 添加 K, V 到缓存
transformer.add_to_cache(K_new, V_new)
# 计算 Attention
output = transformer.compute_attention(Q_new)
print(f"Token {i+1} - Attention Output: {output}")
3. 如何降低计算复杂度到 O(N)?
3.1 对比复杂度
方法 | 计算步骤 | 计算复杂度 |
---|---|---|
无 KV Cache | 计算 QK^T 和 softmax(QK^T)V | O(N²) |
使用 KV Cache | 计算 Q_nK^T (缓存 K, V ) | O(N) |
3.2 关键优化点
-
缓存
K, V
避免重复计算:- 在
KVCacheTransformer
类中,我们只存储K, V
,然后每次新 token 仅计算Q_nK^T
。
- 在
-
改进
QK^T
计算方式:- 传统方法:
np.dot(Q, K.T)
,每次计算所有历史K
- KV Cache 方法:
np.dot(Q_new, K_cache.T)
,只计算当前 Q 和缓存的 K(计算量从 O(N²) 降到 O(N))
- 传统方法:
4. 结论
✅ KV Cache 的核心作用
- 通过 存储历史 Key-Value,减少不必要的计算,使得自回归 Transformer 推理速度更快
- 计算复杂度从 O(N²) 降低到 O(N),适用于 LLM 推理加速
✅ 适用场景
- GPT-3/LLaMA/ChatGPT 推理
- 长文本生成
- 语音识别、机器翻译
✅ 未来优化方向
- Paged KV Cache(分块存储,减少显存占用)
- Sliding Window KV Cache(仅保留最近 N 个 token)
- FlashAttention(GPU 优化,加速
QK^T
计算)
在 Transformer 领域,KV Cache 是不可或缺的技术,随着深度学习的发展,它将继续在大规模推理优化中发挥核心作用!🚀
训练时需要 KV Cache 吗?还是只有推理时需要?
简短回答:
- 训练时 不需要 KV Cache,通常不会使用它。
- 推理时 需要 KV Cache,以加速生成,减少计算冗余。
但更深入地理解 KV Cache 的作用,需要从 Transformer 训练和推理的计算模式入手。
1. 训练与推理的计算方式
1.1 训练(Training)
在 Transformer 训练阶段,所有 token 是并行计算的,因此不需要 KV Cache:
- 输入是完整的序列,可以一次性计算所有
Q, K, V
,自注意力计算QK^T
也可以一次性完成。 - 计算复杂度 O(N²):
- 对于序列长度
N
,计算QK^T
需要O(N²)
。 - 但 GPU 可以并行计算所有 token,因此在训练阶段这个计算量是可接受的。
- 对于序列长度
为什么不需要 KV Cache?
- 训练时,模型会同时计算整个序列的
Q, K, V
,不会像推理时一样一个一个地处理 token,因此没有计算冗余,不需要缓存 Key-Value。
1.2 推理(Inference)
推理(文本生成)采用 自回归(Auto-Regressive) 方式:
- 每次生成一个新的 token,只能依赖之前生成的 token 计算新的
QK^T
。 - 计算复杂度 O(N²) → O(N):
- 无 KV Cache:每个 token 生成时都重新计算
K, V
,计算复杂度 O(N²)。 - 有 KV Cache:只计算当前 token 的
Q
,并复用之前缓存的K, V
,计算复杂度降为 O(N)。
- 无 KV Cache:每个 token 生成时都重新计算
为什么需要 KV Cache?
- 避免重复计算:每个 token 只需计算一次
K, V
,减少计算冗余。- 提高推理速度:计算量从 O(N²) 降至 O(N),在长序列推理时效果尤为显著。
2. 训练与推理对比
阶段 | 计算方式 | 是否需要 KV Cache | 计算复杂度 |
---|---|---|---|
训练(Training) | 所有 token 并行计算 | ❌ 不需要 | O(N²)(但可并行化) |
推理(Inference) | 自回归生成,每次计算一个 token | ✅ 需要 | O(N)(优化后) |
3. 为什么训练不使用 KV Cache?
3.1 并行计算减少冗余
在训练过程中,所有 token 计算 Q, K, V
是一次性完成的:
- 无须逐步缓存 Key-Value,因为所有 token 同时可用。
- GPU 并行计算
QK^T
,不需要缓存来减少计算量。
3.2 GPU 计算优势
现代 GPU 可以高效执行矩阵乘法,比如:
QK_T = torch.matmul(Q, K.T) / math.sqrt(d_k)
- 并行计算:训练时,
QK^T
计算的复杂度是 O(N²),但可以使用 GPU 并行加速。 - 批量处理(Batching):一次处理多个序列,减少显存开销。
3.3 训练时使用 Mask 确保自回归
在 Transformer 解码器的训练过程中,会使用 注意力 Mask(Causal Mask):
- 目的是确保 第 i 个 token 只能看到前 i 个 token,模拟自回归推理过程。
- 但这个 Mask 不影响计算方式,仍然是一次性计算
Q, K, V
,不会缓存 Key-Value。
4. 为什么推理时必须使用 KV Cache?
4.1 推理是自回归生成
推理时,每个 token 依赖于前面生成的 token,因此无法并行计算:
- 如果不缓存 KV:
- 计算
QK^T
需要重新计算所有K, V
,计算复杂度 O(N²)。
- 计算
- 如果使用 KV Cache:
- 仅计算当前
Q_n
,复用之前缓存的K, V
,计算复杂度 O(N)。
- 仅计算当前
4.2 KV Cache 让计算量降低到 O(N)
无 KV Cache(O(N²))
for i in range(N): # 每个生成的 token
K_all = compute_K(all_previous_tokens)
V_all = compute_V(all_previous_tokens)
attention = softmax(Q_n @ K_all.T) @ V_all
有 KV Cache(O(N))
for i in range(N): # 每个生成的 token
K_cache.append(compute_K(Q_n))
V_cache.append(compute_V(Q_n))
attention = softmax(Q_n @ K_cache.T) @ V_cache
只计算当前
Q_n
,直接复用K_cache, V_cache
,大幅减少计算量。
5. 结论
✅ 训练时不需要 KV Cache
- 训练时,所有 token 计算是并行完成的,不需要缓存
K, V
。 - 计算复杂度 O(N²),但由于 GPU 并行计算,仍然可以高效执行。
✅ 推理时必须使用 KV Cache
- 推理采用 自回归生成,每个 token 依赖前面已生成的 token。
- 通过 KV Cache 复用历史
K, V
,将计算复杂度从 O(N²) 降低到 O(N),大幅提高推理效率。
希望这篇文章能帮助你更好地理解 为什么 KV Cache 只在推理时使用,而在训练时不需要!🚀
后记
2025年2月23日05点56分于上海,在GPT4o大模型辅助下完成。