一文讲明白大模型显存占用(只考虑单卡)

知乎:然荻

链接:https://zhuanlan.zhihu.com/p/713256008

纯知识分享,侵删

1.告诉你一个模型的参数量,你要怎么估算出训练和推理时的显存占用?

2.Lora相比于全参训练节省的显存是哪一部分?Qlora相比Lora呢?

3.混合精度训练的具体流程是怎么样的?

这是我曾在面试中被问到的问题,为了巩固相关的知识,打算系统的写一篇文章,帮助自己复习备战秋招的同时,希望也能帮到各位小伙伴。这篇文章将围绕大模型在单卡训练或推理时的显存占用进行系统学习分析,其中有的知识点可能不会涉及太过深入点到为止(因为我也不会),但尽量保证整个读下来逻辑通畅,通俗易懂(只有小白最懂小白!)。

1.数据精度

想要计算显存,从“原子”层面来看,就需要知道我们的使用数据的精度,因为精度代表了数据存储的方式,决定了一个数据占多少bit。我们都知道:

1 byte = 8 bits

1 KB = 1,024 bytes

1 MB = 1,024 KB

1 GB = 1,024 MB

由此可以明白,一个含有1G参数的模型,如果每一个参数都是32bit(4byte),那么直接加载模型就会占用4x1G的显存。

1.1常见的几种精度类型

个人认为只需掌握下图几个常见的数据类型就好,对于更多的精度类型都是可以做到触类旁通发,图源英伟达安培架构白皮书:

7c5e7a24da1f114ec3e3b81d54e75bbc.jpeg

各种精度的数据结构

可以非常直观地看到,浮点数主要是由符号位(sign)、指数位(exponent)和小数位(mantissa)三部分组成。符号位都是1位(0表示正,1表示负),指数位影响浮点数范围,小数位影响精度。其中TF32并不是有32bit,只有19bit不要记错了。BF16指的是Brain Float 16,由Google Brain团队提出。

1.2 具体计算例子

我硕士话,讲太多不如一个形象的图片或者例子来得直接,下面我们将通过一个例子来深入理解如何通过这三个部分来得到我们最终的数据:我以BF16,如今业界用的最广泛的精度类型来举个栗子,下面的数完全是我用克劳德大哥随机画的:

  • 题目:

2e50e3a7a285583adf0a7e5bc822094d.jpeg

随机生成的BF16精度数据

- 先给出具体计算公式:

6576e404d6080e414738c15f2a39afc1.png

- 然后step by step地分析(不是,怎么还对自己使用上Cot了)

符号位Sign = 1,代表是负数

指数位Exponent = 17,中间一坨是677bbbc18bca3feda7f2f499168454a0.png 

小数位Mantissa = 3,后面那一坨是 3ba26860fd4efc297ef66f45bbe969cb.png

  • 最终结果

三个部分乘起来就是最终结果 -8.004646331359449e-34

  • 注意事项

中间唯一需要注意的地方就是指数位是的全0和全1状态是特殊情况,不能用公式,如果想要深入了解可以看这个博客: 彻底搞懂float16与float32的计算方式-CSDN博客 如果感兴趣想更加深入了解如何从FP32转换为BF16的,可以看这个博主的讲解: 从一次面试搞懂 FP16、BF16、TF32、FP32

2.全参训练和推理的显存分析

OK了我们知道了数据精度对应存储的方式和大小, 相当于我们了解了工厂里不同规格的机器零件,但我们还需要了解整个生产线的运作流程,我们才能准确估算出整个工厂(也就是我们的模型训练过程)在运行时所需的资源(显存)。
那么就以目前最常见的混合精度训练方法作为参考,来看一看显存都去哪了。

2.1混合精度训练

2.1.1 原理介绍

顾名思义,混合精度训练就是将多种不同的精度数据混合在一起训练,《 MIXED PRECISION TRAINING 》这篇论文里将FP16和FP32混合,优化器用的是Adam,如下图所示:

b42ef7b3adf5a415799c3d22b1b35bb1.jpeg

MIXED PRECISION TRAINING论文里的训练流程图

按照训练运行的逻辑来讲:

Step1:优化器会先备份一份FP32精度的模型权重,初始化好FP32精度的一阶和二阶动量(用于更新权重)。

Step2:开辟一块新的存储空间,将FP32精度的模型权重转换为FP16精度的模型权重。

Step3:运行forward和backward,产生的梯度和激活值都用FP16精度存储。

Step4:优化器利用FP16的梯度和FP32精度的一阶和二阶动量去更新备份的FP32的模型权重。

Step5:重复Step2到Step4训练,直到模型收敛。

我们可以看到训练过程中显存主要被用在四个模块上:

  • 模型权重本身(FP32+FP16)

  • 梯度(FP16)

  • 优化器(FP32)

  • 激活值(FP16)

2.1.2 三个小问题

写到这里,我就有3个小问题,第一个问题,为什么不全都用FP16,那不是计算更快、内存更少?

根据我们第一章的知识,我们可以知道FP16精度的范围比FP32窄了很多,这就会产生数据溢出和舍入误差两个问题(想深入了解的,请看全网最全-混合精度训练原理),这会导致梯度消失无法训练,所以我们不能全都用FP16,还需要FP32来进行精度保证。看到这里你也许会想到可以用BF16代替,是的,这也是为什么如今很多训练都是BF16的原因,至少BF16不会产生数据溢出了,业界的实际使用也反馈出比起精度,大模型更在意范围。

第二个问题,为什么我们只对激活值和梯度进行了半精度优化,却新添加了一个FP32精度的模型副本,这样子显存不会更大吗?

答案是不会,激活值和batch_size以及seq_length相关,实际训练的时候激活值对显存的占用会很大,对于激活值的正向优化大于备份模型参数的负向优化,最终的显存是减少的。(这里还可以考虑梯度检查点的优化方法,能更进一步优化激活值的显存,感兴趣可以看看这个大模型高效训练基础知识:梯度检查点(Gradient Checkpointing))。

第三个问题,我们知道显存和内存一样,有静态和动态之分别,那么上面提到的哪些是静态哪些是动态呢?

应该很多人都能猜到:

  • 静态:优化器状态、模型参数

  • 动态:激活值、梯度值

也就是说,我们其实没法特别准确的计算出我们实际运行时候的显存大小,如果在面试的时候,就可以忽略掉激活值的计算,梯度当做静态计算就好。如果想要深度探索,指路[LLM]大模型显存计算公式与优化

41f3a917c0c0b9e9182e7e72156b9e31.jpeg

动态监控显存图

2.1.3 来个测试吧!

写到这里,我们应该对于分析大模型训练时候的显存问题应该不在话下了(除了动态部分),那么我们就来实测一下,正在阅读的小伙伴也可以先自己尝试计算一下,看看是不是真的懂了。对于llama3.1 8B模型,FP32和BF16混合精度训练,用的是AdamW优化器,请问模型训练时占用显存大概为多少?

解:

模型参数:16(BF16) + 32(PF32)= 48G

梯度参数:16(BF16)= 16G

优化器参数:32(PF32) + 32(PF32)= 64G

不考虑激活值的情况下,总显存大约占用 (48 + 16 + 64) = 128G

2.2 推理与KV Cache

2.2.1 原理理解

推理的时候,显存几乎只考虑模型参数本身,除此之外就是现在广泛使用的KV cache也会占用显存。KV cache与之前讲的如何减少显存不一样,KV cache的目的是减少延迟,也就是为了推理的速度牺牲显存。

具体KV cache是什么我就不展开讲了,我贴一张动图就可以非常清晰地明白了(如果还不明白可以去看大模型推理加速:看图学KV Cache),记住一点,我们推理就是在不断重复地做”生成下一个token“的任务,生成当前token 仅仅与当前的QKV和之前所有KV有关,那么我们就可以去维护这个KV并不断更新。

45dc4a488acca672e2a9e283c0688014.jpeg

KV Cache动态实现

顺便回答一个很多小白经常会问的问题,为什么没有Q Cache呢?

因为生成当前的token只依赖当前的Q,那为什么生成当前的token只依赖当前的Q呢,因为Self-Attention的公式决定的,S代表Softmax激活函数:

4135f166dec9ceb9b6e0308d47352c1b.png

我们可以看到,在序列t的位置,也就是第t行,只跟𝑄𝑡有关系,也就是说,Attention的计算公式就决定了我们不需要保存每一步的Q,再深入地说,矩阵乘法的数学特性决定了我们不需要保存每一步的Q。

2.2.2 计算KV Cache显存

如何计算KV Cache的显存是我这篇文章想要关心的事情,先给出公式:

c0d1fe870b2532bc950436dffb76ce55.png

前面的四个参数相乘应该很好理解,就是KV对应在模型每一层的所有隐藏向量的总和,第一个2指的是KV两部分,第二个2指的是半精度对应的字节数。

举个栗子,对于llama7B,hiddensize = 4096,seqlength = 2048 , batchsize = 64,layers = 32 计算得到

959e7d606424d67691f24699c3b7c93d.png

可以看到,KV Cache在大批量长句子的情况下,显存占用率也是很大的。

68G看着是相对模型本身很大,但这是在batch很大的情况下,在单batch下,KV Cache就仅占有 1G左右的显存了,就仅仅占用模型参数一半的显存。

2.2.3 MQA和GQA

什么,你觉得KV Cache用的显存还是太多了,不错,对于推理落地侧,再怎么严苛要求也是合理的,MQA和GQA就是被用来进一步减少显存的方法,现在的大模型也几乎都用到了这个方法,我们就来讲一讲。

4d6a1d8a92befef46be64ff8ac68b775.jpeg

三种KV处理方式

其实方法不难理解,看这张图一目了然,关键词就是“共享多头KV”,很朴素的删除模型冗余结构的思路。最左侧就是最基础的MHA多头自注意力,中间的GQA就是保留几组KV头,右侧MQA就是只保留1组KV头,目前用的比较多的是GQA,降低显存提速的同时也不会太过于影响性能。如果没看懂的小伙伴可以去大模型推理加速:KV Cache 和 GQA详细看看,这里就不展开讲了,我想讲的是具体显存的变化。

上一小节我们知道MHA的KV Cache占用显存的计算公式是

96fd8178cc57c7b4a019710f5b67ed43.png

有一个小细节,可以重头开始训练MQA 和 GQA的模型,也可以像 GQA 论文里面一样基于开源模型,修改模型结构后继续预训练。目前基本上都是从头开始训练的,因为要保持训练和推理的模型结构一致。

3.Lora和Qlora显存分析

上面两章详细对全参微调训练和推理进行了显存分析,聪明的小伙伴就发现了一个问题,现在都用PEFT(高效参数微调)了,谁有那么多资源全参训练啊推理阶段也是要量化的,这样又该怎么进行显存分析呢。那么我们这一章就来解决这个问题,我相信完全理解前两章的小伙伴理解起来会非常轻松,所谓的显存分析,只要知道了具体的流程和数据精度,那么分析的方法都是类似的。OK,我们将会在这一章里详细分析目前前业界最火的Lora和Qlora方法的显存占用情况,中间也会涉及到相关的原理知识,冲!

3.1 Lora

能看到这里的人,我想对于Lora的原理应该都很了解了,就浅浅提一下,如下图所示,就是在原来的权重矩阵的旁路新建一对低秩的可训练权重,训练的时候只训练旁路,大大降低了训练的权重数量,参数量d*d降为2*d*r

4e5bcc520b5fcc8169bbaef4e068c1cc.jpeg

Lora原理图

有了前面的全参情况下训练的显存分析,现在分析起来就比较通顺了,我们一步一步来,还是以BF16半精度模型Adamw优化器训练为例子,lora部分的参数精度也是BF16,并且设1字节模型参数对应的显存大小φ。

首先是模型权重本身的权重,这个肯定是要加载原始模型和lora旁路模型的,因为lora部分占比小于2个数量级,所以显存分析的时候忽略不计,显存占用2φ。

然后就是优化器部分,优化器也不需要对原模型进行备份了,因为优化器是针对于需要更新参数的模型权重部分进行处理,也就是说优化器只包含Lora模型权重相关的内容,考虑到数量级太小,也忽略不计,故优化器部分占用显存0φ。

其实容易搞错混淆的部分就是梯度的显存了,我看了不少的博客文章,有说原始模型也要参与反向传播,所以是要占用一份梯度显存的,也有的说原始模型都不更新梯度,肯定只需要Lora部分的梯度显存,搞得我头很大。那么究竟正确答案是哪一种呢,这里直接给出答案,不需要计算原始模型部分的梯度,也基本不占用显存。也就是说梯度部分占用显存也可以近似为0φ。想深入探究的可以去大模型高效微调-LoRA原理详解和训练过程深入分析。

总的来说,不考虑激活值的情况下,Lora微调训练的显存占用只有2φ,一个7B的模型Lora训练只需要占用显存大约14G左右。验证一下,我们来看Llama Factory里给出训练任务的显存预估表格:

0aa1a76616e22de3878861a52e3f5a37.jpeg

Llama Factory的表格

可以看到7B模型的Lora训练的显存消耗与我们估计得也差不多,同时也还可以复习一下全参训练、混合精度训练的显存分析,也是基本符合我们之前的分析的。

3.2 QLora

上面Llama Factory的那张表也是稍微剧透了一下我们接下来要讲的内容,也就是QLora,继Lora之后也是在业界落地非常广泛通用的一种大模型PEFT方法。QLora,也叫做量化Lora,顾名思义,也就是进一步压缩模型的精度,然后用Lora训练,他的核心思路很好理解,但实际上涉及的知识点细节却并不少。我同样也不会太过深入地去介绍这个中细节,如想深入了解可以去看论文或者其他博客(指路QLoRA、GPTQ:模型量化概述,QLoRA(Quantized LoRA)详解),我主要是想按照显存占用的思路去分析Qlora,理解思路永远比死的知识点更加重要。

3.2.1 Qlora的整体思路

Qlora来自于《 QLORA: Efficient Finetuning of Quantized LLMs 》这篇论文,实际上这篇论文的核心在于提出了一种新的量化方法,重点在于量化而不是Lora。很多不了解的人看到量化lora这个名字就以为是对Lora部分的参数进行量化,因为他们认为毕竟只有Lora部分的参数参与了训练,但理解了上面一节的小伙伴就明白实际并不是这样,原始模型的本身参数虽然不更新参数,但是仍然需要前向和反向传播,QLora优化的正是Lora里显存占大头的模型参数本身。

那么Qlora就是把原始模型参数从16bit压缩到4bit,然后更新这个4bit参数吗?非也非也,这里需要区分两个概念,一个是计算参数,一个是存储参数,计算参数就是在前向、反向传播参与实际计算的参数,存储参数就是不参与计算一开始加载的原始参数。QLora的方法就是,加载并且量化16bit的模型原始参数为4bit作为存储参数,但是在具体需要计算的时候,将该部分的4bit参数反量化为16bit作为计算参数。也就是说,QLora实际上我们训练计算里用到的所有数据的精度都是和Lora一样的,只是加载的模型是4bit,会进行一个反量化到16bit的方法,用完即释放。前面说到的都是模型原始参数本身,不包括lora部分的参数,Lora部分的参数不需要量化,一直都是16bit。看到这里机智的你应该也想到了,这比Lora多了一个量化反量化的操作,那训练时间是不是会更长,没错一般来讲Qlora训练会比Lora多用30%左右的时间。

3.2.2 Qlora的技术细节

基本的思路讲完了,那么其中包含了哪些具体的实现细节呢?Qlora主要包括三个创新点,这里我只简单提及,应付面试足够的程度,如果想要详细了解可以去看论文:

  1. NF4量化:常见的量化分布都是基于参数是均匀分布的假设,而这个方法基于参数是正态分布的假设,这样使得量化精度大大提升。

  2. 双重量化:对于第一次量化后得到的用于计算反量化时的锚点参数,我们对这个锚点参数进行量化,可以进一步降低显存。

  3. 优化器分页:为了防止OOM,可以在GPU显存紧张的时候利用CPU内存进行加载参数。

3.2.3 显存分析

想必已经理解QLora运行思路的小伙伴,应该可以很轻松的分析出Qlora占用显存的部分了吧,这就是理清楚思路的好处。没错,Qlora占用的显存主要就是4Bit量化后的模型本身也就是0.5φ,这里没有考虑少量的Lora部分的参数和量化计算中可能产生的显存。可以回过头去看看刚才的表格,也是基本符合预期的。

最后我们用一个表格来总结所有之前我们提到的显存分析:

部分显存对应精度(训练)全参微调(全FP16)全参微调(BF16混合精度)LoraQLora
主干模型(模型存储/计算参数)FP16/FP16BF16/BF16BF16/BF16NF4/BF16
主干模型(梯度)FP16BF16NullNull
主干模型(adamw优化器)2 x FP163 x FP32NullNull
LoRA部分(可忽略不计)NullNullBF16BF16
总和(大约)8Byte16Byte2Byte0.5Byte

第一次写文章,不足之处多多见谅,也是参考了很多知乎大佬的文章,主要是做一个思路的整理,有问题或者写得不对的地方欢迎各位小伙伴在评论区讨论指正。


备注:昵称-学校/公司-方向/会议(eg.ACL),进入技术/投稿群

1cbd546e1a07d11a237ee037f68feca2.png

id:DLNLPer,记得备注呦

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值