上周面的一个985女生,问了Transformer模型的内存优化

我是丁师兄,专注于智能驾驶方向大模型落地,公众号:丁师兄大模型

大模型1v1学习,已帮助多名同学上岸国内外大厂


Transformer 模型现在很火,内存优化又很重要。上周面试了一个 985 大学的女生,跟她谈到了 Transformer 模型的内存优化问题。

那么这个女生到底给出了哪些关于 Transformer 模型内存优化的独特思路呢?一起来看看。

01什么是Transformer模型中的KV缓存?

Transformer 中文本是逐个 token 生成的,每次新的预测都会基于之前生成的所有 tokens 的上下文信息。

这种对顺序数据的依赖会减慢生成过程,因为每次预测下一个 token 都需要重新处理序列中所有之前的 tokens。

例如,要预测第 100 个 token,模型必须使用前 99 个 token 的信息,需要对这些 token 进行复杂的矩阵运算。预测第 101 个 token 时,也要对前 99 个 token 做类似计算,以及对第 100 个 token 的新计算。

如何简化呢?

答案是使用 KV 缓存。KV 缓存通过保存这些计算结果,使模型可以在生成后续 tokens 时直接访问这些结果,而不需要重新计算。

换句话说,在生成第 101 个 token 时,模型只需从 KV 缓存中检索前 99 个 token 的已存储数据,并只对第 100 个 token 执行必要的计算。 

02如何估算KV缓存消耗的内存大小?

KV 缓存通常使用 float16 或 bfloat16 数据类型以 16 位的精度存储张量。对于一个 token,KV 缓存会为每一层和每个注意力头存储一对张量(键和值)。

这些张量的大小由注意力头的维度决定,这对张量的总内存消耗(以字节为单位)可以通过以下公式计算: 

层数 × KV 注意力头的数量 × 注意力头的维度 × (位宽 / 8) × 2

最后的 "2" 是因为有两组张量,也就是键和值。位宽通常为 16 位,由于 8 位是 1 字节,因此我们将位宽除以 8,这样在 KV 缓存中每 16 位参数占用 2 个字节。

我们以 Llama 3 8B 为例,这个公式就变为:

32 × 8 × 128 × 2 × 2 = 131,072

注意:Llama 3 8B 有 32 个注意力头,不过由于 GQA 的存在,只有 8 个注意力头用于键和值。

从上面可以看到,对于一个 token,KV 缓存占用 131,072 字节,差不多 0.1 MB。这看起来好像不大,但对于许多不同类型的应用,大模型需要生成成千上万的 tokens。

举个例子,如果我们想利用 Llama 3 8B 的全部 context 大小(8192),KV 缓存将为 8191 个 token 存储键值张量,差不多占用 1.1 G 内存。换句话说,对于一块 24G 显存的消费级 GPU,KV 缓存将占用其总内存的 4.5%。

而对于更大的模型,KV 缓存增长得更快。比如对于 Llama 3 70B,它有 80 层,公式变为:

80 × 8 × 128 × 2 × 2 = 327,680  

对于 8191 个 token,Llama 3 70B 的 KV 缓存将占用 2.7 GB。并且注意,这只是单个序列的内存消耗,如果我们进行批量解码,还需要将这个值乘以  batch size。

比如 batch size=32 的 Llama 3 8B 模型,将需要 35.2 GB 的 GPU 显存,一块消费级 GPU 显然搞不定了。

因此虽然在推理阶段用 KV 缓存可以提高处理速度,并且已经是业界标准做法,但是 KV 缓存在深层模型和长序列场景下,也会占据大量 GPU 内存。

而实际开发中,我们可以通过 KV 缓存量化,来降低推理阶段的 LLM 内存需求。后面我们将通过实际的例子(Llama 3 8B 模型),来看看如何对 KV 缓存进行量化的!

END


我是丁师兄,专注于智能驾驶方向大模型落地,公众号:丁师兄大模型

大模型1v1学习,已帮助多名同学上岸国内外大厂

  • 9
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值