【大模型推理】KV Cache原理

基本原理

两句话说明 KV Cache 的原理:

  • 由于Attention中的mask机制(前面token的Q,看到不到后面token的K和V),当前生成的token只依赖于 当前输入的token以及之前所有token的KV矩阵
  • 因此 可以通过 缓存 注意力机制中的键值对,在自回归生成过程中减少重复计算,加速大模型的推理过程

矩阵运算

要想理解 KV Cache,需要先了解一下矩阵运算的一个基本性质:分块矩阵计算法则,也就是两个矩阵相乘可以拆分成 列向量与行向量运算的和,在张量并行中也有用到:

[ X 1 X 2 ] × [ A 1 A 2 ] = [ X 1 A 1 + X 2 A 2 ] = Y \left[\begin{array}{ll} X_1 & X_2 \end{array}\right] \times\left[\begin{array}{l} A_1 \\ A_2 \end{array}\right]=\left[X_1 A_1+X_2 A_2\right]=Y [X1X2]×[A1A2]=[X1A1+X2A2]=Y

在这个性质的基础上,此时如果矩阵 X X X 是一个下三角矩阵,就有:

[ X 1 , 1 0 ⋯ 0 X 2 , 1 X 2 , 2 ⋯ 0 ⋮ ⋮ ⋱ ⋮ X m , 1 X m , 2 ⋯ X m , n ] ⋅ [ Y 1 , 1 Y 1 , 2 ⋯ Y 1 , p Y 2 , 1 Y 2 , 2 ⋯ Y 2 , p ⋮ ⋮ ⋱ ⋮ Y n , 1 Y n , 2 ⋯ Y n , p ] = [ X 1 , 1 Y ⃗ 1 X 2 , 1 Y ⃗ 1 + X 2 , 2 Y ⃗ 2 ⋮ X m , 1 Y ⃗ 1 + X m , 2 Y ⃗ 2 + ⋯ + X m , n Y ⃗ n ] \begin{aligned} & {\left[\begin{array}{cccc} X_{1,1} & 0 & \cdots & 0 \\ X_{2,1} & X_{2,2} & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ X_{m, 1} & X_{m, 2} & \cdots & X_{m, n} \end{array}\right] \cdot\left[\begin{array}{cccc} Y_{1,1} & Y_{1,2} & \cdots & Y_{1, p} \\ Y_{2,1} & Y_{2,2} & \cdots & Y_{2, p} \\ \vdots & \vdots & \ddots & \vdots \\ Y_{n, 1} & Y_{n, 2} & \cdots & Y_{n, p} \end{array}\right]} \\ & =\left[\begin{array}{l} X_{1,1} \vec{Y}_1 \\ X_{2,1} \vec{Y}_1+X_{2,2} \vec{Y}_2 \\ \vdots \\ X_{m, 1} \vec{Y}_1+X_{m, 2} \vec{Y}_2+\cdots+X_{m, n} \vec{Y}_n \end{array}\right] \end{aligned} X1,1X2,1Xm,10X2,2Xm,200Xm,n Y1,1Y2,1Yn,1Y1,2Y2,2Yn,2Y1,pY2,pYn,p = X1,1Y 1X2,1Y 1+X2,2Y 2Xm,1Y 1+Xm,2Y 2++Xm,nY n

可以看到,结果矩阵的第 k k k 行只用到了矩阵 X X X 的 第 k k k 个行向量。所以 X X X 不需要进行全部的矩阵乘法,每一步只取第 k k k 个行向量即可,这就很大程度上减少了计算量,也就是 KV Cache 的数学原理。

如果把矩阵 X X X 认为是 Attention 计算得到的 QK score,矩阵 Y Y Y 认为是 V V V,上述计算流程基本上就是 KV Cache 的框架了。这里要注意一点,因为 现在大模型基本都是Decoder-only的架构 ,自回归生成的过程中,当前token在做attention计算时是看不到后面的,后面的内容都被 mask 了。所以 KV Cache 只有 Decoder-only架构才有,Encoder-only 比如 BERT模型 是没有的

KV Cache原理

下面用一个具体的例子来说明,KV Cache假设当前文本为 我要, 要 大模型 输出的内容为学习AI

第一次计算

模型第一次计算,生成时的过程如下:
在这里插入图片描述

为了方便演示,忽略scale项 d \sqrt{d} d ,最终Attention的计算公式如下,(softmaxed 表示已经按行进行了softmax):

A t t step  1 ( Q , K , V ) = softmax ⁡ ( [ Q 1 K 1 T − ∞ Q 2 K 1 T Q 2 K 2 T ] ) [ V 1 → V 2 → ] = ( [ softmaxed ⁡ ( Q 1 K 1 T ) 0 softmaxed ⁡ ( Q 2 K 1 T ) softmaxed ⁡ ( Q 2 K 2 T ) ] ) [ V 1 → V 2 → ] = ( [ softmaxed ⁡ ( Q 1 K 1 T ) × V 1 → softmaxed ⁡ ( Q 2 K 1 T ) × V 1 → + softmaxed ⁡ ( Q 2 K 2 T ) × V 2 → ] ) \begin{aligned} Att_{\text {step } 1}(Q, K, V)&=\operatorname{softmax}\left(\left[\begin{array}{cc} Q_1 K_1^T & -\infty \\ Q_2 K_1^T & Q_2 K_2^T \end{array}\right]\right)\left[\begin{array}{l} \overrightarrow{V_1} \\ \overrightarrow{V_2} \end{array}\right] \\ & =\left(\left[\begin{array}{cc} \operatorname{softmaxed}\left(Q_1 K_1^T\right) & 0 \\ \operatorname{softmaxed}\left(Q_2 K_1^T\right) & \operatorname{softmaxed}\left(Q_2 K_2^T\right) \end{array}\right]\right)\left[\begin{array}{l} \overrightarrow{V_1} \\ \overrightarrow{V_2} \end{array}\right] \\ & =\left(\left[\begin{array}{c} \operatorname{softmaxed}\left(Q_1 K_1^T\right) \times \overrightarrow{V_1} \\ \operatorname{softmaxed}\left(Q_2 K_1^T\right) \times \overrightarrow{V_1}+\operatorname{softmaxed}\left(Q_2 K_2^T\right) \times \overrightarrow{V_2} \end{array}\right]\right) \end{aligned} Attstep 1(Q,K,V)=softmax([Q1K1TQ2K1TQ2K2T])[V1 V2 ]=([softmaxed(Q1K1T)softmaxed(Q2K1T)0softmaxed(Q2K2T)])[V1 V2 ]=([softmaxed(Q1K1T)×V1 softmaxed(Q2K1T)×V1 +softmaxed(Q2K2T)×V2 ])

假设 A t t 1 Att_1 Att1表示Attention结果的第一行, A t t 2 Att_2 Att2表示Attention结果的第二行,那么有下面的表示:

A t t 1 ( Q , K , V ) = softmaxed ⁡ ( Q 1 K 1 T ) V ⃗ 1 A t t 2 ( Q , K , V ) = softmaxed ⁡ ( Q 2 K 1 T ) V ⃗ 1 + softmaxed ⁡ ( Q 2 K 2 T ) V ⃗ 2 \begin{aligned} & Att_1(Q, K, V)=\operatorname{softmaxed}\left(Q_1 K_1^T\right) \vec{V}_1 \\ & Att_2(Q, K, V)=\operatorname{softmaxed}\left(Q_2 K_1^T\right) \vec{V}_1+\operatorname{softmaxed}\left(Q_2 K_2^T\right) \vec{V}_2 \end{aligned} Att1(Q,K,V)=softmaxed(Q1K1T)V 1Att2(Q,K,V)=softmaxed(Q2K1T)V 1+softmaxed(Q2K2T)V 2

可以看到:

  • 在计算 A t t 1 Att_1 Att1时, Q 1 K 2 T Q_1K_2^T Q1K2T这个score会被mask掉,也就是当前token计算Attention时看不到后面的token信息
  • 在计算 A t t 2 Att_2 Att2时,仅仅依赖于 Q 2 Q_2 Q2以及 K 1 T , K 2 T K_1^T, K_2^T K1T,K2T V ⃗ 1 , V ⃗ 2 \vec{V}_1, \vec{V}_2 V 1,V 2,与 Q 1 Q_1 Q1无关

第二次计算

如果没有KV Cache,模型进行第二次前向推理计算,生成时,计算过程如下:

在这里插入图片描述

此时,Attention的计算公式为:

Att ⁡ 1 ( Q , K , V ) = softmaxed ⁡ ( Q 1 K 1 T ) V ⃗ 1 Att ⁡ 2 ( Q , K , V ) = softmaxed ⁡ ( Q 2 K 1 T ) V ⃗ 1 + softmaxed ⁡ ( Q 2 K 2 T ) V ⃗ 2 Att ⁡ 3 ( Q , K , V ) = softmaxed ⁡ ( Q 3 K 1 T ) V ⃗ 1 + softmaxed ⁡ ( Q 3 K 2 T ) V ⃗ 2 + softmaxed ⁡ ( Q 3 K 3 T ) V ⃗ 3 \begin{aligned} & \operatorname{Att}_1(Q, K, V)=\operatorname{softmaxed}\left(Q_1 K_1^T\right) \vec{V}_1 \\ & \operatorname{Att}_2(Q, K, V)=\operatorname{softmaxed}\left(Q_2 K_1^T\right) \vec{V}_1+\operatorname{softmaxed}\left(Q_2 K_2^T\right) \vec{V}_2 \\ & \operatorname{Att}_3(Q, K, V)=\operatorname{softmaxed}\left(Q_3 K_1^T\right) \vec{V}_1+\operatorname{softmaxed}\left(Q_3 K_2^T\right) \vec{V}_2+\operatorname{softmaxed}\left(Q_3 K_3^T\right) \vec{V}_3 \end{aligned} Att1(Q,K,V)=softmaxed(Q1K1T)V 1Att2(Q,K,V)=softmaxed(Q2K1T)V 1+softmaxed(Q2K2T)V 2Att3(Q,K,V)=softmaxed(Q3K1T)V 1+softmaxed(Q3K2T)V 2+softmaxed(Q3K3T)V 3

同样的, A t t 3 Att_3 Att3的计算结果,之和 Q 3 Q_3 Q3有关,与 Q 1 , Q 2 Q_1, Q_2 Q1,Q2无关,因此,可以模型的推理可以简化为:

在这里插入图片描述

这就是KV Cache的计算过程

第三次计算

同理,生成最后AI时,使用KV Cache计算过程如下:

在这里插入图片描述

A t t 4 Att_4 Att4的计算公式为:

Att ⁡ 4 ( Q , K , V ) = softmaxed ⁡ ( Q 4 K 1 T ) V 1 → + softmaxed ⁡ ( Q 4 K 2 T ) V 2 → + softmaxed ⁡ ( Q 4 K 3 T ) V 3 → + softmaxed ⁡ ( Q 4 K 4 T ) V 4 → \begin{aligned} \operatorname{Att}_4(Q, K, V)&=\operatorname{softmaxed}\left(Q_4 K_1^T\right) \overrightarrow{V_1} +\operatorname{softmaxed}\left(Q_4 K_2^T\right) \overrightarrow{V_2} \\ & +\operatorname{softmaxed}\left(Q_4 K_3^T\right) \overrightarrow{V_3} +\operatorname{softmaxed}\left(Q_4 K_4^T\right) \overrightarrow{V_4} \end{aligned} Att4(Q,K,V)=softmaxed(Q4K1T)V1 +softmaxed(Q4K2T)V2 +softmaxed(Q4K3T)V3 +softmaxed(Q4K4T)V4

因此,可以看出:

  • 不使用KV Cache的方法,存在大量冗余的计算,也就是要生成 A t t k Att_k Attk时,还需要重复计算 A t t 1 , . . , A t t k − 1 Att_1, .., Att_{k-1} Att1,..,Attk1
  • 计算 A t t k Att_k Attk时,之和 Q k Q_k Qk有关,与之前的 Q 1 , . . . , Q k − 1 Q_1, ... , Q_{k-1} Q1,...,Qk1都没关系
  • 生成第 x k x_k xk个token时,只需要输入上一轮生成的 x k − 1 x_{k-1} xk1即可

所以每一步其实只需要根据 Q k Q_k Qk 计算 A t t k Att_k Attk 就可以,但是 K K K V V V 是全程参与计算的。从优化推理速度角度来看,只需要把每一步的 K , V K,V K,V 缓存起来就可以, 所以叫 KV Cache。

最后需要注意当 sequence 比较长,或者 batch 特别大的时候,KV Cache 其实还是个Memory刺客,所以如何减少 KV 的内存变得尤为重要。

目前各种框架,针对 KV Cache 做了优化,比如 vLLM 的 Page Attention, Prefix Caching,Token 的稀疏化,KV 共享或者压缩(MQA、GQA 和 MLA),LayerSkip,Mooncake 等等,可以说 KV Cache 目前是推理的基石,各种基于 KV Cache 的优化方法撑起了大模型推理加速的半壁江山。

参考资料

  • [1] https://note.mowen.cn/note/detail?noteUuid=Uwudr7Hu_STXSicuxHei2
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

嗜睡的篠龙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值