本文记录大模型推理阶段 KV Cache 的原理及显存占用情况。
Self-Attention 与 KV Cache
如图,当新生成的 token x 进到模型计算 Attention 时,先分别乘上参数矩阵 W q W_q Wq、 W k W_k Wk、 W v W_v Wv 得到向量 q,以及矩阵 K、V。然后根据下面公式计算当前 token 跟前面 tokens 的注意力权重(本文为了简化,不考虑多头 MHA)。
自回归生成过程中,K和V矩阵并没有太大变化,比如下图中 cold 这个词对应了 K 的某一列和 V 的某一行,算完就放那里不再变了。
轮到生成 chill 这个词时,其实只需要在原始 K 矩阵追加一列,原始 V 矩阵追加一行,而没必要每生成一个 token 都重新计算一遍 K、V 矩阵,这便是 KV Cache
的意义。
因此在推理的时候,不用每次传入前面全部 token 序列的 embedding,而只需传入 KV Cache 以及当前 token x 的 embedding。Transformer 在算完当前 token x 的 Attention 之后,会把新的 K’ 和 V’ 更新到 GPU 显存中。 左图中 Masked Multi Self Attention
这块也是唯一和前面序列有交互的模块,其他模块(比如 Layer Norm、FFN、位置编码等)都不涉及跟已生成 token 的交互。
KV Cache 显存占用分析
KV Cache 显存计算方式如下: