基本原理
两句话说明 KV Cache 的原理:
- 由于Attention中的mask机制(前面token的Q,看到不到后面token的K和V),当前生成的token只依赖于 当前输入的token以及之前所有token的KV矩阵
- 因此 可以通过 缓存 注意力机制中的键值对,在自回归生成过程中减少重复计算,加速大模型的推理过程
矩阵运算
要想理解 KV Cache,需要先了解一下矩阵运算的一个基本性质:分块矩阵计算法则,也就是两个矩阵相乘可以拆分成 列向量与行向量运算的和,在张量并行中也有用到:
[ X 1 X 2 ] × [ A 1 A 2 ] = [ X 1 A 1 + X 2 A 2 ] = Y \left[\begin{array}{ll} X_1 & X_2 \end{array}\right] \times\left[\begin{array}{l} A_1 \\ A_2 \end{array}\right]=\left[X_1 A_1+X_2 A_2\right]=Y [X1X2]×[A1A2]=[X1A1+X2A2]=Y
在这个性质的基础上,此时如果矩阵 X X X 是一个下三角矩阵,就有:
[ X 1 , 1 0 ⋯ 0 X 2 , 1 X 2 , 2 ⋯ 0 ⋮ ⋮ ⋱ ⋮ X m , 1 X m , 2 ⋯ X m , n ] ⋅ [ Y 1 , 1 Y 1 , 2 ⋯ Y 1 , p Y 2 , 1 Y 2 , 2 ⋯ Y 2 , p ⋮ ⋮ ⋱ ⋮ Y n , 1 Y n , 2 ⋯ Y n , p ] = [ X 1 , 1 Y ⃗ 1 X 2 , 1 Y ⃗ 1 + X 2 , 2 Y ⃗ 2 ⋮ X m , 1 Y ⃗ 1 + X m , 2 Y ⃗ 2 + ⋯ + X m , n Y ⃗ n ] \begin{aligned} & {\left[\begin{array}{cccc} X_{1,1} & 0 & \cdots & 0 \\ X_{2,1} & X_{2,2} & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ X_{m, 1} & X_{m, 2} & \cdots & X_{m, n} \end{array}\right] \cdot\left[\begin{array}{cccc} Y_{1,1} & Y_{1,2} & \cdots & Y_{1, p} \\ Y_{2,1} & Y_{2,2} & \cdots & Y_{2, p} \\ \vdots & \vdots & \ddots & \vdots \\ Y_{n, 1} & Y_{n, 2} & \cdots & Y_{n, p} \end{array}\right]} \\ & =\left[\begin{array}{l} X_{1,1} \vec{Y}_1 \\ X_{2,1} \vec{Y}_1+X_{2,2} \vec{Y}_2 \\ \vdots \\ X_{m, 1} \vec{Y}_1+X_{m, 2} \vec{Y}_2+\cdots+X_{m, n} \vec{Y}_n \end{array}\right] \end{aligned} X1,1X2,1⋮Xm,10X2,2⋮Xm,2⋯⋯⋱⋯00⋮Xm,n ⋅ Y1,1Y2,1⋮Yn,1Y1,2Y2,2⋮Yn,2⋯⋯⋱⋯Y1,pY2,p⋮Yn,p = X1,1Y1X2,1Y1+X2,2Y2⋮Xm,1Y1+Xm,2Y2+⋯+Xm,nYn
可以看到,结果矩阵的第 k k k 行只用到了矩阵 X X X 的 第 k k k 个行向量。所以 X X X 不需要进行全部的矩阵乘法,每一步只取第 k k k 个行向量即可,这就很大程度上减少了计算量,也就是 KV Cache 的数学原理。
如果把矩阵 X X X 认为是 Attention 计算得到的 QK score,矩阵 Y Y Y 认为是 V V V,上述计算流程基本上就是 KV Cache 的框架了。这里要注意一点,因为 现在大模型基本都是Decoder-only的架构 ,自回归生成的过程中,当前token在做attention计算时是看不到后面的,后面的内容都被 mask 了。所以 KV Cache 只有 Decoder-only架构才有,Encoder-only 比如 BERT模型 是没有的
KV Cache原理
下面用一个具体的例子来说明,KV Cache假设当前文本为 我要
, 要 大模型 输出的内容为学习AI
第一次计算
模型第一次计算,生成学
时的过程如下:
为了方便演示,忽略scale项 d \sqrt{d} d,最终Attention的计算公式如下,(softmaxed 表示已经按行进行了softmax):
A t t step 1 ( Q , K , V ) = softmax ( [ Q 1 K 1 T − ∞ Q 2 K 1 T Q 2 K 2 T ] ) [ V 1 → V 2 → ] = ( [ softmaxed ( Q 1 K 1 T ) 0 softmaxed ( Q 2 K 1 T ) softmaxed ( Q 2 K 2 T ) ] ) [ V 1 → V 2 → ] = ( [ softmaxed ( Q 1 K 1 T ) × V 1 → softmaxed ( Q 2 K 1 T ) × V 1 → + softmaxed ( Q 2 K 2 T ) × V 2 → ] ) \begin{aligned} Att_{\text {step } 1}(Q, K, V)&=\operatorname{softmax}\left(\left[\begin{array}{cc} Q_1 K_1^T & -\infty \\ Q_2 K_1^T & Q_2 K_2^T \end{array}\right]\right)\left[\begin{array}{l} \overrightarrow{V_1} \\ \overrightarrow{V_2} \end{array}\right] \\ & =\left(\left[\begin{array}{cc} \operatorname{softmaxed}\left(Q_1 K_1^T\right) & 0 \\ \operatorname{softmaxed}\left(Q_2 K_1^T\right) & \operatorname{softmaxed}\left(Q_2 K_2^T\right) \end{array}\right]\right)\left[\begin{array}{l} \overrightarrow{V_1} \\ \overrightarrow{V_2} \end{array}\right] \\ & =\left(\left[\begin{array}{c} \operatorname{softmaxed}\left(Q_1 K_1^T\right) \times \overrightarrow{V_1} \\ \operatorname{softmaxed}\left(Q_2 K_1^T\right) \times \overrightarrow{V_1}+\operatorname{softmaxed}\left(Q_2 K_2^T\right) \times \overrightarrow{V_2} \end{array}\right]\right) \end{aligned} Attstep 1(Q,K,V)=softmax([Q1K1TQ2K1T−∞Q2K2T])[V1V2]=([softmaxed(Q1K1T)softmaxed(Q2K1T)0softmaxed(Q2K2T)])[V1V2]=([softmaxed(Q1K1T)×V1softmaxed(Q2K1T)×V1+softmaxed(Q2K2T)×V2])
假设 A t t 1 Att_1 Att1表示Attention结果的第一行, A t t 2 Att_2 Att2表示Attention结果的第二行,那么有下面的表示:
A t t 1 ( Q , K , V ) = softmaxed ( Q 1 K 1 T ) V ⃗ 1 A t t 2 ( Q , K , V ) = softmaxed ( Q 2 K 1 T ) V ⃗ 1 + softmaxed ( Q 2 K 2 T ) V ⃗ 2 \begin{aligned} & Att_1(Q, K, V)=\operatorname{softmaxed}\left(Q_1 K_1^T\right) \vec{V}_1 \\ & Att_2(Q, K, V)=\operatorname{softmaxed}\left(Q_2 K_1^T\right) \vec{V}_1+\operatorname{softmaxed}\left(Q_2 K_2^T\right) \vec{V}_2 \end{aligned} Att1(Q,K,V)=softmaxed(Q1K1T)V1Att2(Q,K,V)=softmaxed(Q2K1T)V1+softmaxed(Q2K2T)V2
可以看到:
- 在计算 A t t 1 Att_1 Att1时, Q 1 K 2 T Q_1K_2^T Q1K2T这个score会被mask掉,也就是当前token计算Attention时看不到后面的token信息
- 在计算 A t t 2 Att_2 Att2时,仅仅依赖于 Q 2 Q_2 Q2以及 K 1 T , K 2 T K_1^T, K_2^T K1T,K2T和 V ⃗ 1 , V ⃗ 2 \vec{V}_1, \vec{V}_2 V1,V2,与 Q 1 Q_1 Q1无关
第二次计算
如果没有KV Cache,模型进行第二次前向推理计算,生成习
时,计算过程如下:
此时,Attention的计算公式为:
Att 1 ( Q , K , V ) = softmaxed ( Q 1 K 1 T ) V ⃗ 1 Att 2 ( Q , K , V ) = softmaxed ( Q 2 K 1 T ) V ⃗ 1 + softmaxed ( Q 2 K 2 T ) V ⃗ 2 Att 3 ( Q , K , V ) = softmaxed ( Q 3 K 1 T ) V ⃗ 1 + softmaxed ( Q 3 K 2 T ) V ⃗ 2 + softmaxed ( Q 3 K 3 T ) V ⃗ 3 \begin{aligned} & \operatorname{Att}_1(Q, K, V)=\operatorname{softmaxed}\left(Q_1 K_1^T\right) \vec{V}_1 \\ & \operatorname{Att}_2(Q, K, V)=\operatorname{softmaxed}\left(Q_2 K_1^T\right) \vec{V}_1+\operatorname{softmaxed}\left(Q_2 K_2^T\right) \vec{V}_2 \\ & \operatorname{Att}_3(Q, K, V)=\operatorname{softmaxed}\left(Q_3 K_1^T\right) \vec{V}_1+\operatorname{softmaxed}\left(Q_3 K_2^T\right) \vec{V}_2+\operatorname{softmaxed}\left(Q_3 K_3^T\right) \vec{V}_3 \end{aligned} Att1(Q,K,V)=softmaxed(Q1K1T)V1Att2(Q,K,V)=softmaxed(Q2K1T)V1+softmaxed(Q2K2T)V2Att3(Q,K,V)=softmaxed(Q3K1T)V1+softmaxed(Q3K2T)V2+softmaxed(Q3K3T)V3
同样的, A t t 3 Att_3 Att3的计算结果,之和 Q 3 Q_3 Q3有关,与 Q 1 , Q 2 Q_1, Q_2 Q1,Q2无关,因此,可以模型的推理可以简化为:
这就是KV Cache的计算过程
第三次计算
同理,生成最后AI
时,使用KV Cache计算过程如下:
A t t 4 Att_4 Att4的计算公式为:
Att 4 ( Q , K , V ) = softmaxed ( Q 4 K 1 T ) V 1 → + softmaxed ( Q 4 K 2 T ) V 2 → + softmaxed ( Q 4 K 3 T ) V 3 → + softmaxed ( Q 4 K 4 T ) V 4 → \begin{aligned} \operatorname{Att}_4(Q, K, V)&=\operatorname{softmaxed}\left(Q_4 K_1^T\right) \overrightarrow{V_1} +\operatorname{softmaxed}\left(Q_4 K_2^T\right) \overrightarrow{V_2} \\ & +\operatorname{softmaxed}\left(Q_4 K_3^T\right) \overrightarrow{V_3} +\operatorname{softmaxed}\left(Q_4 K_4^T\right) \overrightarrow{V_4} \end{aligned} Att4(Q,K,V)=softmaxed(Q4K1T)V1+softmaxed(Q4K2T)V2+softmaxed(Q4K3T)V3+softmaxed(Q4K4T)V4
因此,可以看出:
- 不使用KV Cache的方法,存在大量冗余的计算,也就是要生成 A t t k Att_k Attk时,还需要重复计算 A t t 1 , . . , A t t k − 1 Att_1, .., Att_{k-1} Att1,..,Attk−1
- 计算 A t t k Att_k Attk时,之和 Q k Q_k Qk有关,与之前的 Q 1 , . . . , Q k − 1 Q_1, ... , Q_{k-1} Q1,...,Qk−1都没关系
- 生成第 x k x_k xk个token时,只需要输入上一轮生成的 x k − 1 x_{k-1} xk−1即可
所以每一步其实只需要根据 Q k Q_k Qk 计算 A t t k Att_k Attk 就可以,但是 K K K 和 V V V 是全程参与计算的。从优化推理速度角度来看,只需要把每一步的 K , V K,V K,V 缓存起来就可以, 所以叫 KV Cache。
最后需要注意当 sequence 比较长,或者 batch 特别大的时候,KV Cache 其实还是个Memory刺客,所以如何减少 KV 的内存变得尤为重要。
目前各种框架,针对 KV Cache 做了优化,比如 vLLM 的 Page Attention, Prefix Caching,Token 的稀疏化,KV 共享或者压缩(MQA、GQA 和 MLA),LayerSkip,Mooncake 等等,可以说 KV Cache 目前是推理的基石,各种基于 KV Cache 的优化方法撑起了大模型推理加速的半壁江山。
参考资料
- [1] https://note.mowen.cn/note/detail?noteUuid=Uwudr7Hu_STXSicuxHei2