简单来说,LLM在decoding阶段的每次推理只会用到当前的Q,这次用的Q下次不会用到,所以不用Cache Q;
但是每次都要用到当前和过去所有的KV,这次用到的KV下次马上就要再用一次,所以Cache KV可以加速推理。
下面说明原因:
观察Attention公式,这个K和Q怎么看都很对称,为什么只Cache K而不Cache Q?
把KQV写成分块的形式,像这样:
然后Q和K转置的矩阵乘就变成了这样:
直到这一步,K和Q看上去都很对称。轮换一下K和Q对结果没有本质影响。
V的引入破坏了这一对称性。忽略 𝑑𝑘 系数,第i行的softmax简写成 𝑆𝑖 ,attention操作的结果变成了这样:
这是没有Causal Mask(因果掩码)的情况。加入Causal Mask会变成这样:
可以写一下结果的通项,没有Causal Mask:
有Causal Mask:
无论有没有Causal Mask,Q和K在结果中都是不对称的。
在序列的t位置,Q只有当前位置的 𝑞𝑡q_t 参与了计算,而K和V多个位置参与了计算,所以需要KV Cache,而不需要Q Cache。
在没有Causal Mask时,计算t位置的Attention需要未来的KV,这在实际进行自回归推理时无法得到;加上Causal Mask之后,只需要1,2,…,t位置的KV就可以进行推理。