为什么Transformer推理要做KV缓存?

我是丁师兄,专注于智能驾驶方向大模型落地,公众号:丁师兄大模型

大模型1v1学习,已帮助多名同学上岸国内外大厂


前几天有个网友偷偷给我发私信,那叫一个信心满满地冲向阿里面试大模型岗位,结果悲催了,铩羽而归。到底是哪里出了岔子呢?

这篇文章,我们就从面试官的视角,来拆解这个 Transformer 的面试连环炮。如果是你在面试现场被这些问题 “轰炸”,到底该咋回答呢?嘿嘿,一起看看吧!

01为什么Transformer推理要做KV缓存?

面试官心理分析

第一个问题一般都是先热热身,面试官问这个呢,其实是想看看,你知不知道 Transformer 的一个最基本的工作原理。

面试题解析

因为在 Transformer 中文本是逐个 token 生成的,每次新的预测,会基于之前生成的所有 token 的上下文信息。

这种对顺序数据的依赖会减慢生成过程,因为每次预测下一个 token,都需要重新处理序列中所有之前的 token。

举个例子,如果我们要预测第 100 个 token,那模型必须使用前面 99 个 token 的信息,这就需要对这些 token 进行矩阵运算,而这个矩阵运算是非常复杂和耗时的。

因此 KV 缓存就是为了减少这种耗时的矩阵运算,在 Transformer 推理过程中,会将这些键和值存储在缓存中,这样模型就可以在生成后续 token 的时候,直接访问缓存,而不需要重新计算。

02KV 缓存的工作机制的什么样的?

OK,那既然你知道 KV 缓存是为了减少生成 token 时的矩阵运算,那它具体是怎么减小的?讲讲 KV 缓存的工作机制。

这个时候,你最好能结合画图讲出来,会很清晰。干讲很多东西讲不清楚,几张图就能说明白。

前面说了,由于解码器是因果的,也就是一个 token 的注意力只依赖于其前面的 token,在每个生成步骤中,我们希望只需要计算最新那个 token 的注意力,不需要重新计算相同的先前 token 注意力。

来看下面这两张图,分别是有缓存和没有缓存的情况。对于第一个新到的词,假设之前没有缓存过,每次 query 的 token 会和 K 做矩阵相乘,然后做 softmax 计算注意力得分,然后再和V相乘再输出。

我们对比一下上下两张图,其实有缓存和没有缓存的计算量是一样的,不过在下图中,我们会把这一步计算过的 K,V 缓存起来。

然后是第二个 token,这时候可以看出,紫色的是缓存的 K V,在没有缓存的时候需要重新计算一次 K V,而如果做了 K V 缓存,那么只需要把历史的 K, V 拿出来,同时只计算最新那个 token 的 K, V,并拼接成一个大矩阵就可以了。

对比一下,有缓存和没有缓存的计算量明显减少了一半。

第三个第四个 token 也一样,每次历史计算过的 K V 就不用重新计算了,从而极大的减少了 self attention 的计算量,从序列长度的二次方变成了线性。

03带有 KV 缓存优化的大模型推理过程包含几个阶段?

OK,那这个时候一些面试官会继续追问:那实际在解码的时候,一个典型的带有 KV 缓存优化的大模型推理过程包含几个阶段?

这个时候你首先要答出包含几个阶段实际在解码的时候,一个典型的带有 KV cache 优化的生成大模型的推理过程会包含两个阶段:prefill 和 decoding

然后再回答每个阶段干的事情:

  • prefill 阶段:输入一个 prompt 序列,为每个 transformer 层生成 KV cache,同时输出第一个 token。

  • decoding 阶段:发生在计算第二个输出 token 至最后一个 token 过程中,这时 Cache 是有值的,每轮推理只需读取 Cache,同时将当前轮计算出的新的 Key、Value 追加写入至 Cache;FLOPs 降低,gemm 变为 gemv 操作,推理速度相对第一阶段变得更快。

04如何估算KV缓存消耗的显存大小呢?

好,那下一个问题:如何估算 KV 缓存消耗的显存大小呢?

我们首先要知道,KV 缓存通常使用 float16 或者 bfloat16 数据类型,也就是以 16 位的精度存储张量。对于一个 token,KV 缓存会为每一层和每个注意力头存储一对 KV 张量。

KV 张量的总显存消耗,可以通过下面这个公式计算,这里我们以字节为单位。

层数 × KV注意力头的数量 × 注意力头的维度 × (位宽/8) × 2

解释一下这个公式,最后的 "2" 是因为有两组张量,也就是键和值。位宽通常为 16 位,由于 8 位 1 字节,因此将位宽除以 8,这样在 KV 缓存中,每 16 位参数占用 2 个字节。

05KV 缓存为何在长文本和复杂模型结构下成瓶颈?

既然你知道怎么估算 KV 缓存,那面试官可能继续问你:为什么 KV 缓存在长文本和复杂模型结构场景下会成为瓶颈呢?

这个时候我们可以用实际的例子跟面试官说明,以 Llama 3 8B 为例,上面公式就变为:

32 × 8 × 128 × 2 × 2 = 131,072字节

这里注意,虽然 Llama 3 8B 有 32 个注意力头。不过由于 GQA 的存在,只有 8 个注意力头用于键和值,对于一个 token,KV 缓存占用 131,072 字节,差不多 0.1 MB。

这看起来好像不大,但对于许多不同类型的应用,大模型需要生成成千上万的 tokens。

举个例子,如果我们想利用 Llama 3 8B 的全部上下文大小是 8192,KV 缓存将为 8191 个 token 存储键值张量,差不多会占用 1.1 G 显存。

换句话说,对于一块 24 G 显存的消费级 GPU 来说,KV 缓存将占用其总显存的 4.5%。

对于更大的模型,KV 缓存增长得更快。比如对于 Llama 3 70B,它有 80 层,公式变为:

80 × 8 × 128 × 2 × 2 = 327,680 字节

对于 8191 个 token,Llama 3 70B 的 KV 缓存将占用 2.7 GB。注意这只是单个序列的显存消耗,如果我们进行批量解码,还需要将这个值乘以 batch size。

举个例子,使用 batch size 为 32 的 Llama 3 8B 模型将需要 35.2 GB 的 GPU 显存,一块消费级 GPU 就显然搞不定了。

所以 KV 缓存会随着文本长度增加和模型结构的复杂性增强大幅膨胀,对于需要处理超长上下文的应用可能会成为瓶颈。

06KV 缓存缺陷的解决方案有哪些?

那下一个问题自然就来了:针对 KV 缓存这个缺陷,有什么解决方案?

答案是:量化 KV 缓存,量化可以降低参数精度,比如从 16 位降到 4 位。换句话说,一个 16 位张量可以缩减到原来的 1/4,而通过两位量化甚至可以缩减到 1/8。

所以上面的例子,理论上 4 bit 量化可以将 KV 缓存的大小从 35.2 G 缩减到 8.8 G,使用 2bit 量化甚至可以缩减到 4.4 GB。当然实际效果取决于所使用的算法及其超参数,需要结合具体情况分析。

那量化有没有什么缺点呢?

当然有,没有哪个方案是完美的。量化可能会减慢解码过程,并且可能显著影响 LLM 的准确性。

实际中需要通过调整量化超参,减轻这些影响,另外可以通过重写量化算子,对量化操作和反量化操作与其他算子做一些融合,来提高降低解码速度的影响。

07是否了解有关 KV 缓存的最新技术?

那最后面试官可能还会问你一些开放性的问题,比如有没有了解过关于 KV 缓存的一些最新工作?

这个呢,就是看你的技术热情,有没有对业界最新的技术/paper,保持敏感度。

你可以这么来回答:对于 KV 缓存这块,今年 KV-cache 复用从层内复用跨越到层间服用。微软最近新发了一篇论文,提出了 YOCO(You Only Cache Once),这是一个 KV Cache 层间共享的新思路。

同期 MIT-IBM Watson AI Lab 也发了一篇类似的论文,提出了 CLA( Cross-Layer Attention),即 KV Cache 跨层推理。这和 YOCO 不谋而合。

YOCO 的名称借鉴了单阶段目标检测始祖 YOLO 的风格。整篇论文提出的最核心的思想 KV Cache 共享方式,即层间共享。并且基于此,将上下文长度扩展到了 1 百万。

我们知道,目前最常见的 KV Cache 共享策略是 MQA/GQA。从 Layer 的视角来看,MQA/GQA 可以认为是层内 KV Cache 共享,而 YOCO 提出的想法,则可以认为是 Inter-Layer 层间 KV Cache 共享。

这个算法理论上的可以最多把 KV Cache 的 Memory 需求降低到 1/N(N 为 Transformer 层数)。

并且,这和层内 KV Cache 共享的技术,比如 MQA 和 GQA 是不冲突的,两者可以一起使用,从而极大地降低 KV Cache 的显存开销。

END


我是丁师兄,专注于智能驾驶方向大模型落地,公众号:丁师兄大模型

大模型1v1学习,已帮助多名同学上岸国内外大厂

  • 9
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值