KV-Cache详解

KV-Cache

KV-Cache是一种优化技术,用于 Transformer 模型的自注意力层,以提高计算效率。本文探讨了 KV 缓存的概念、应用以及它在自注意力机制中优化的特定计算。并回答下面3个问题:

  • 什么是KV Cache?
  • KV Cache在哪里使用?
  • KV Cache节省了Self-Attention层中哪部分的计算?

在深入了解 KV-Caching 之前,首先需要理解 Transformer 中如何使用KV。具体的,大家可以参考这篇文章:Scaled Dot-Product Attention详解


什么是KV Cache

在自回归模型中(autoregressive models),为了优化文本生成过程,我们需要使用 KV cache。这种模型会逐个生成文本的每个token,这个过程可能比较慢,因为模型一次只能生成一个token,而且每次新的预测都依赖于之前的上下文。这意味着,要预测第1000个token,你需要用到前999个token的信息,这通常涉及到对这些token的表示进行一系列矩阵乘法运算。而要预测第1001个token,你不仅需要前999个token的信息,还要加上第1000个token的信息。KV cache就是在这里发挥作用,通过存储之前 K , V K,V K,V的计算结果,并在后续的token生成时复用这些结果,从而避免重复计算。

更具体地说,KV cache在自回归生成模型中充当一个内存库的角色,模型会存储来自自注意力层(self-attention layers)的之前处理过的token的键值对。在Transformer架构中,Self Attention层通过将Queries(Q)与Keys(K)相乘来计算注意力分数,然后产生Values(V)向量的加权和作为输出。通过存储这些信息,模型可以避免重复的计算,而是直接从缓存中检索之前token的 K K K V V V

需要注意的是,KV cache只在多个token生成步骤中发生,并且仅在decoder进行(例如,在decoder only的模型如GPT,或在encoder-decoder模型如T5的解码部分)。像BERT这样的encoder only模型,不是生成型的,因此不涉及KV cache。

由于decoder是causal的(即,一个token的注意力attention只依赖于它前面的token),在每一步生成过程中,我们实际上是在重复计算相同的前一个token的注意力,而我们真正需要做的是仅计算新token的注意力。这就是KV cache发挥作用的地方。通过缓存之前的K和V,我们可以专注于只计算新token的注意力。
在这里插入图片描述

图片来源:Transformers KV Caching Explained

  1. 为什么只存储K和V,而不存储Q?

    因为decoder模型是causal的(即,一个token的注意力只依赖于它前面的token)。Transformers是一种自回归模型。它们会参照序列中的所有先前输入来预测下一个输出(token)。但是由于它们不像RNN那样一个接一个地接收序列,而是一次性接受整个序列,因此需要一种方法来限制任何位置的注意力范围,使之只关注它之前的那部分序列。这是因为希望模型在推断下一个位置时,不能提前看到答案。如果它能够看到下一个位置,那么它只会复制它。

    look-ahead mask实现了这一点。在序列中,出现在查询词右侧的所有单词的注意力得分都被遮蔽了。这限制了查询词只关注自身以及在序列中位于其左侧的所有单词。

    look-ahead mask只应用于每个解码器层的第一个注意力子层。这是因为解码器是推断发生的地方。所以,正是在这里模型不应该提前窥视。

    更多细节可以参考: Understanding Attention In Transformers Models

    举个例子🌰:

    假设我们的 Q , K , V Q,K,V Q,K,V分别如下:
    Q = [ 0.212 0.04 0.63 0.36 0.1 0.14 0.86 0.77 0.31 0.36 0.19 0.72 ] ,   K = [ 0.31 0.84 0.963 0.57 0.45 0.94 0.73 0.58 0.36 0.83 0.1 0.38 ] ,   V = [ 0.36 0.83 0.1 0.38 0.31 0.36 0.19 0.72 0.31 0.84 0.963 0.57 ] Q=\begin{bmatrix} 0.212 & 0.04 & 0.63 & 0.36\\ 0.1 & 0.14 & 0.86 & 0.77\\ 0.31 & 0.36 & 0.19 & 0.72 \end{bmatrix}, \ K = \begin{bmatrix} 0.31 & 0.84 & 0.963 & 0.57\\ 0.45 & 0.94 & 0.73 & 0.58\\ 0.36 & 0.83 & 0.1 & 0.38 \end{bmatrix}, \ V = \begin{bmatrix} 0.36 & 0.83 & 0.1 & 0.38\\ 0.31 & 0.36 & 0.19 & 0.72\\ 0.31 & 0.84 & 0.963 & 0.57 \end{bmatrix} Q= 0.2120.10.310.040.140.360.630.860.190.360.770.72 , K= 0.310.450.360.840.940.830.9630.730.10.570.580.38 , V= 0.360.310.310.830.360.840.10.190.9630.380.720.57
    在计算完 Q K T d k \frac{QK^T}{\sqrt{d_k}} dk QKT得到attention矩阵后,我们创建一个masking矩阵,将其与attetnion矩阵想加:
    M = [ 0 1 1 0 0 1 0 0 0 ] × − 1 e 9 = [ 0 − 1 e 9 − 1 e 9 0 0 − 1 e 9 0 0 0 ] M=\begin{bmatrix} 0 & 1 & 1\\ 0 & 0 & 1\\ 0 & 0 & 0 \end{bmatrix} \times -1e9 = \begin{bmatrix} 0 & -1e9 & -1e9\\ 0 & 0 & -1e9\\ 0 & 0 & 0 \end{bmatrix} M= 000100110 ×1e9= 0001e9001e91e90

    − 1 e 9 -1e9 1e9: 一个极小值

    Q K T d k + M = [ 0.455605 0.40085 0.15466 0.70784 0.6255 0.2654 0.495935 0.5171 0.3515 ] + [ 0 − 1 e 9 − 1 e 9 0 0 − 1 e 9 0 0 0 ] = [ 0.455605 − 1 e 9 − 1 e 9 0.70784 0.6255 − 1 e 9 0.495935 0.5171 0.3515 ] \frac{QK^T}{\sqrt{d_k}} + M = \begin{bmatrix} 0.455605 & 0.40085 & 0.15466\\ 0.70784 & 0.6255 & 0.2654\\ 0.495935 & 0.5171 & 0.3515 \end{bmatrix} + \begin{bmatrix} 0 & -1e9 & -1e9\\ 0 & 0 & -1e9\\ 0 & 0 & 0 \end{bmatrix} = \begin{bmatrix} 0.455605 & -1e9 & -1e9\\ 0.70784 & 0.6255 & -1e9\\ 0.495935 & 0.5171 & 0.3515 \end{bmatrix} dk QKT+M= 0.4556050.707840.4959350.400850.62550.51710.154660.26540.3515 + 0001e9001e91e90 = 0.4556050.707840.4959351e90.62550.51711e91e90.3515
    接下来,沿行应用 softmax,将这些值转换为概率分布。将 s o f t m a x softmax softmax应用于注意力矩阵后,所有这些极小的值( − 1 e 9 -1e9 1e9)都将变为零:
    s o f t m a x ( [ 0.455605 − 1 e 9 − 1 e 9 0.70784 0.6255 − 1 e 9 0.495935 0.5171 0.3515 ] ) = [ 1.0 0 0 0.520573 0.479427 0 0.346392 0.353802 0.299806 ] softmax(\begin{bmatrix} 0.455605 & -1e9 & -1e9\\ 0.70784 & 0.6255 & -1e9\\ 0.495935 & 0.5171 & 0.3515 \end{bmatrix})=\begin{bmatrix} 1.0 & 0 & 0\\ 0.520573 & 0.479427 & 0\\ 0.346392 & 0.353802 & 0.299806 \end{bmatrix} softmax( 0.4556050.707840.4959351e90.62550.51711e91e90.3515 )= 1.00.5205730.34639200.4794270.353802000.299806

    再来看一下不存储 Q Q Q的情况,仅存储 K K K V V V的情况(为了简洁,此处忽略了与 V V V矩阵计算的相关内容):
    Q 1 = [ 0.212 0.04 0.63 0.36 — — — — — — — — ] ,   Q 2 = [ — — — — 0.1 0.14 0.86 0.77 — — — — ] ,   Q 3 = [ — — — — — — — — 0.31 0.36 0.19 0.72 ] K 1 = [ 0.31 0.84 0.963 0.57 — — — — — — — — ] ,   K 2 = [ 0.31 0.84 0.963 0.57 0.45 0.94 0.73 0.58 — — — — ] ,   K 3 = [ 0.31 0.84 0.963 0.57 0.45 0.94 0.73 0.58 0.36 0.83 0.1 0.38 ] Q_1=\begin{bmatrix} 0.212 & 0.04 & 0.63 & 0.36\\ — & — & — & —\\ — & — & — & — \end{bmatrix}, \ Q_2=\begin{bmatrix} — & — & — & —\\ 0.1 & 0.14 & 0.86 & 0.77\\ — & — & — & — \end{bmatrix}, \ Q_3=\begin{bmatrix} — & — & — & —\\ — & — & — & —\\ 0.31 & 0.36 & 0.19 & 0.72 \end{bmatrix} \\ \\ K_1 = \begin{bmatrix} 0.31 & 0.84 & 0.963 & 0.57\\ — & — & — & —\\ — & — & — & — \end{bmatrix}, \ K_2 = \begin{bmatrix} 0.31 & 0.84 & 0.963 & 0.57\\ 0.45 & 0.94 & 0.73 & 0.58\\ — & — & — & — \end{bmatrix}, \ K_3 = \begin{bmatrix} 0.31 & 0.84 & 0.963 & 0.57\\ 0.45 & 0.94 & 0.73 & 0.58\\ 0.36 & 0.83 & 0.1 & 0.38 \end{bmatrix} Q1= 0.2120.040.630.36 , Q2= 0.10.140.860.77 , Q3= 0.310.360.190.72 K1= 0.310.840.9630.57 , K2= 0.310.450.840.940.9630.730.570.58 , K3= 0.310.450.360.840.940.830.9630.730.10.570.580.38

    Q 1 K 1 T d k = [ 0.455605 — — — — — — — — ] ,   Q 2 K 2 T d k = [ — — — 0.70784 0.6255 — — — — ] ,   Q 3 K 3 T d k = [ — — — — — — 0.495935 0.5171 0.3515 ] \frac{Q_{1}K_{1}^{T}}{\sqrt{d_k}}=\begin{bmatrix} 0.455605 & — & — \\ — & — & — \\ — & — & — \end{bmatrix}, \ \frac{Q_{2}K_{2}^{T}}{\sqrt{d_k}}=\begin{bmatrix} — & — & — \\ 0.70784 & 0.6255 & — \\ — & — & — \end{bmatrix}, \ \frac{Q_{3}K_{3}^{T}}{\sqrt{d_k}}=\begin{bmatrix} — & — & — \\ — & — & — \\ 0.495935 & 0.5171 & 0.3515 \end{bmatrix} dk Q1K1T= 0.455605 , dk Q2K2T= 0.707840.6255 , dk Q3K3T= 0.4959350.51710.3515

    相加后的结果与存储 Q Q Qmasking的结果相同:
    Q 1 K 1 T d k + Q 2 K 2 T d k + Q 3 K 3 T d k = [ 0.455605 − − 0.70784 0.6255 − 0.495935 0.5171 0.3515 ] \frac{Q_{1}K_{1}^{T}}{\sqrt{d_k}} + \frac{Q_{2}K_{2}^{T}}{\sqrt{d_k}} + \frac{Q_{3}K_{3}^{T}}{\sqrt{d_k}}=\begin{bmatrix} 0.455605 & - & -\\ 0.70784 & 0.6255 & -\\ 0.495935 & 0.5171 & 0.3515 \end{bmatrix} dk Q1K1T+dk Q2K2T+dk Q3K3T= 0.4556050.707840.4959350.62550.51710.3515
    应用 s o f t m a x softmax softmax:
    s o f t m a x ( [ 0.455605 − − 0.70784 0.6255 − 0.495935 0.5171 0.3515 ] ) = [ 1.0 − − 0.520573 0.479427 − 0.346392 0.353802 0.299806 ] softmax(\begin{bmatrix} 0.455605 & - & -\\ 0.70784 & 0.6255 & -\\ 0.495935 & 0.5171 & 0.3515 \end{bmatrix})=\begin{bmatrix} 1.0 & - & -\\ 0.520573 & 0.479427 & -\\ 0.346392 & 0.353802 & 0.299806 \end{bmatrix} softmax( 0.4556050.707840.4959350.62550.51710.3515 )= 1.00.5205730.3463920.4794270.3538020.299806

    可以观察到,与存储 Q Q Q时的结果是一致的,这也代表在接下来与 V V V矩阵计算得到的Attention的结果也将一样,这也就是为什么我们在KV Cahce时不需要存储 Q Q Q的原因。

KV Cache在哪里使用?

我们每生成一个新的token就会把这个新的token append进之前的序列中,在将这个序列当作新的输入进行新的token生成,直到eos_token结束。这使得每次新序列输入时都需要取重复计算前面的 n − 1 n-1 n1个token的 q , k , v q,k,v q,k,v,浪费了很多资源,KV Cache就是在这里使用的,我们在每次处理新的序列时,可以同时将之前计算的key, value一同缓存,并传入下一次计算,这样就节省了很多计算的时间,避免了冗余计算。

KV Cache节省了Self-Attention层中哪部分的计算?

首先,我们要知道,Self-Attention通过将输入序列变换成三个向量来操作:查询向量(Query),键向量(Key)和值向量(Value)。这些向量是通过对输入进行线性变换得到的。注意力机制基于 Q Q Q向量和 K K K向量之间的相似度来计算 V V V向量的加权求和。然后,将这个加权求和的结果连同原始输入一起传递给前馈神经网络,以产生最终输出。这个过程允许模型专注于相关信息并捕捉长距离依赖关系。

那么回到问题,它节省了哪部分计算呢?它节省了对于键(Key)和值(Value)的重复计算,不需要对之前已经计算过的Token的 K K K V V V重新进行计算。因为对于之前的Token我们可以使用上一轮计算的结果,避免了重复计算,只需要计算当前Token的 Q   K   V Q \ K \ V Q K V

注意,它并没有节省Scaled Dot-Product Attention环节的计算


如果觉得这篇文章有用,就给个👍和收藏⭐️吧!也欢迎在评论区分享你的看法!


参考

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值