offer捷报
恭喜训练营的同学,成功斩获字节的 offer。
告诉你一个模型的参数量,你要怎么估算出训练和推理时的显存占用? Lora 相比于全参训练节省的显存是哪一部分?Qlora 相比 Lora 呢? 混合精度训练的具体流程是怎么样的?
这是我曾在面试中被问到的问题,为了巩固相关的知识,写下这篇文章,帮助自己复习备战面试的同时,希望也能帮到各位小伙伴。
这篇文章将围绕大模型在单卡训练或推理时的显存占用进行系统学习分析,其中有的知识点可能不会涉及太过深入点到为止(因为我也不会),但尽量保证整个读下来逻辑通畅,通俗易懂(只有小白最懂小白!)。
01
数据精度
想要计算显存,从“原子”层面来看,就需要知道我们的使用数据的精度,因为精度代表了数据存储的方式,决定了一个数据占多少 bit。
我们都知道:
-
1 byte = 8 bits
-
1 KB = 1,024 bytes
-
1 MB = 1,024 KB
-
1 GB = 1,024 MB
由此可以明白,一个含有 1G 参数的模型,如果每一个参数都是 32bit(4byte),那么直接加载模型就会占用 4x1G 的显存。
(1)常见的几种精度类型
个人认为只需掌握下图几个常见的数据类型就好,对于更多的精度类型都是可以做到触类旁通发,图源英伟达安培架构白皮书:
各种精度的数据结构
可以非常直观地看到,浮点数主要是由符号位(sign)、指数位(exponent)和小数位(mantissa)三部分组成。
符号位都是 1 位(0 表示正,1 表示负),指数位影响浮点数范围,小数位影响精度。
其中 TF32 并不是有 32bit,只有 19bit 不要记错了。BF16 指的是 Brain Float 16,由 Google Brain 团队提出。
(2)具体计算例子
讲太多不如一个形象的图片或者例子来得直接,下面我们将通过一个例子来深入理解如何通过这三个部分来得到我们最终的数据。
我以 BF16,如今业界用的最广泛的精度类型来举个栗子,下面的数完全是我用克劳德大哥随机画的:
题目:
随机生成的 BF16 精度数据
先给出具体计算公式:
然后 step by step 地分析(不是,怎么还对自己使用上 Cot 了)。
符号位 Sign = 1,代表是负数。
最终结果:三个部分乘起来就是最终结果 -8.004646331359449e-34。
注意事项:中间唯一需要注意的地方就是指数位是的全 0 和全 1 状态是特殊情况,不能用公式。
02
全参训练和推理的显存分析
我们知道了数据精度对应存储的方式和大小, 相当于我们了解了工厂里不同规格的机器零件,但我们还需要了解整个生产线的运作流程,我们才能准确估算出整个工厂(也就是我们的模型训练过程)在运行时所需的资源(显存)。
那么就以目前最常见的混合精度训练方法作为参考,来看一看显存都去哪了。
(1)混合精度训练
原理介绍
顾名思义,混合精度训练就是将多种不同的精度数据混合在一起训练,《 MIXED PRECISION TRAINING 》这篇论文里将 FP16 和 FP32 混合,优化器用的是 Adam,如下图所示:
MIXED PRECISION TRAINING论文里的训练流程图
按照训练运行的逻辑来讲:
-
**Step1:**优化器会先备份一份 FP32 精度的模型权重,初始化好 FP32 精度的一阶和二阶动量(用于更新权重)。
-
**Step2:**开辟一块新的存储空间,将 FP32 精度的模型权重转换为 FP16 精度的模型权重。
-
**Step3:**运行 forward 和 backward,产生的梯度和激活值都用 FP16 精度存储。
-
**Step4:**优化器利用 FP16 的梯度和 FP32 精度的一阶和二阶动量去更新备份的 FP32 的模型权重。
-
**Step5:**重复 Step2 到 Step4 训练,直到模型收敛。
我们可以看到训练过程中显存主要被用在四个模块上:
-
模型权重本身(FP32+FP16)
-
梯度(FP16)
-
优化器(FP32)
-
激活值(FP16)
三个小问题
写到这里,我就有 3 个小问题,第一个问题,为什么不全都用 FP16,那不是计算更快、内存更少?
根据我们第一章的知识,我们可以知道FP16精度的范围比FP32窄了很多,这就会产生数据溢出和舍入误差两个问题,这会导致梯度消失无法训练,所以我们不能全都用 FP16,还需要 FP32 来进行精度保证。
看到这里你也许会想到可以用 BF16 代替,是的,这也是为什么如今很多训练都是 BF16 的原因,至少 BF16 不会产生数据溢出了,业界的实际使用也反馈出比起精度,大模型更在意范围。
第二个问题,为什么我们只对激活值和梯度进行了半精度优化,却新添加了一个 FP32 精度的模型副本,这样子显存不会更大吗?
答案是不会,激活值和 batch_size 以及 seq_length 相关,实际训练的时候激活值对显存的占用会很大,对于激活值的正向优化大于备份模型参数的负向优化,最终的显存是减少的。
第三个问题,我们知道显存和内存一样,有静态和动态之分别,那么上面提到的哪些是静态哪些是动态呢?
应该很多人都能猜到:
-
**静态:**优化器状态、模型参数
-
**动态:**激活值、梯度值
也就是说,我们其实没法特别准确的计算出我们实际运行时候的显存大小,如果在面试的时候,就可以忽略掉激活值的计算,梯度当做静态计算就好。
动态监控显存图
来个测试吧!
写到这里,我们应该对于分析大模型训练时候的显存问题应该不在话下了(除了动态部分),那么我们就来实测一下,正在阅读的小伙伴也可以先自己尝试计算一下,看看是不是真的懂了。
对于 llama3.1 8B 模型,FP32 和 BF16 混合精度训练,用的是 AdamW 优化器,请问模型训练时占用显存大概为多少?
解:
-
模型参数:16(BF16) + 32(PF32)= 48G
-
梯度参数:16(BF16)= 16G
-
优化器参数:32(PF32) + 32(PF32)= 64G
不考虑激活值的情况下,总显存大约占用 (48 + 16 + 64) = 128G。
(2)推理与 KV Cache
原理理解
推理的时候,显存几乎只考虑模型参数本身,除此之外就是现在广泛使用的 KV cache 也会占用显存。
KV cache 与之前讲的如何减少显存不一样,KV cache 的目的是减少延迟,也就是为了推理的速度牺牲显存。
具体 KV cache 是什么我就不展开讲了,我贴一张动图就可以非常清晰地明白了。
记住一点,我们推理就是在不断重复地做”生成下一个token“的任务,生成当前 token 仅仅与当前的 QKV 和之前所有 KV 有关,那么我们就可以去维护这个 KV 并不断更新。
KV Cache 动态实现
顺便回答一个很多小白经常会问的问题,为什么没有 Q Cache 呢?
因为生成当前的 token 只依赖当前的 Q,那为什么生成当前的 token 只依赖当前的 Q 呢?
因为 Self-Attention 的公式决定的,S 代表 Softmax 激活函数:
我们可以看到,在序列t的位置,也就是第 t 行,只跟 Qt 有关系。
也就是说,Attention 的计算公式就决定了我们不需要保存每一步的 Q,再深入地说,矩阵乘法的数学特性决定了我们不需要保存每一步的 Q。
计算 KV Cache 显存
如何计算 KV Cache 的显存是我这篇文章想要关心的事情。
先给出公式:
前面的四个参数相乘应该很好理解,就是 KV 对应在模型每一层的所有隐藏向量的总和,第一个 2 指的是 KV 两部分,第二个 2 指的是半精度对应的字节数。
举个栗子,对于 llama7B,hiddensize = 4096,seqlength = 2048 , batchsize = 64,layers = 32。
计算得到:
可以看到,KV Cache 在大批量长句子的情况下,显存占用率也是很大的。
68G 看着是相对模型本身很大,但这是在 batch 很大的情况下,在单 batch 下,KV Cache 就仅占有 1G 左右的显存了,就仅仅占用模型参数一半的显存。
MQA 和 GQA
什么,你觉得 KV Cache 用的显存还是太多了,不错,对于推理落地侧,再怎么严苛要求也是合理的,MQA 和 GQA 就是被用来进一步减少显存的方法,现在的大模型也几乎都用到了这个方法,我们就来讲一讲。
三种 KV 处理方式
其实方法不难理解,看这张图一目了然,关键词就是“共享多头 KV”,很朴素的删除模型冗余结构的思路。
最左侧就是最基础的 MHA 多头自注意力,中间的 GQA 就是保留几组 KV 头,右侧 MQA 就是只保留 1 组 KV 头,目前用的比较多的是 GQA,降低显存提速的同时也不会太过于影响性能。
这里就不展开讲了,我想讲的是具体显存的变化。
上一小节我们知道 MHA 的 KV Cache 占用显存的计算公式是:
那么不难理解,MQA 占用显存公式则是:
GQA 占用显存公式则是:
公式中改变的就是共享的头数。
有一个小细节,可以重头开始训练 MQA 和 GQA 的模型,也可以像 GQA 论文里面一样基于开源模型,修改模型结构后继续预训练。目前基本上都是从头开始训练的,因为要保持训练和推理的模型结构一致。
03
Lora和Qlora显存分析
上面两章详细对全参微调训练和推理进行了显存分析,聪明的小伙伴就发现了一个问题,现在都用 PEFT(高效参数微调)了,谁有那么多资源全参训练啊推理阶段也是要量化的,这样又该怎么进行显存分析呢。
那么我们这一章就来解决这个问题,我相信完全理解前两章的小伙伴理解起来会非常轻松,所谓的显存分析,只要知道了具体的流程和数据精度,那么分析的方法都是类似的。
OK,我们将会在这一章里详细分析目前前业界最火的 Lora 和 Qlora 方法的显存占用情况,中间也会涉及到相关的原理知识,冲!
(1)Lora
能看到这里的人,我想对于 Lora 的原理应该都很了解了,就浅浅提一下,如下图所示:
就是在原来的权重矩阵的旁路新建一对低秩的可训练权重,训练的时候只训练旁路,大大降低了训练的权重数量,参数量从 d*d 降为 2*d*r。
有了前面的全参情况下训练的显存分析,现在分析起来就比较通顺了,我们一步一步来,还是以 BF16 半精度模型 Adamw 优化器训练为例子,lora 部分的参数精度也是 BF16,并且设 1 字节模型参数对应的显存大小 Φ。
首先是模型权重本身的权重,这个肯定是要加载原始模型和 lora 旁路模型的,因为 lora 部分占比小于 2 个数量级,所以显存分析的时候忽略不计,显存占用 2Φ。
然后就是优化器部分,优化器也不需要对原模型进行备份了,因为优化器是针对于需要更新参数的模型权重部分进行处理。
也就是说优化器只包含 Lora 模型权重相关的内容,考虑到数量级太小,也忽略不计,故优化器部分占用显存 0Φ。
其实容易搞错混淆的部分就是梯度的显存了,我看了不少的博客文章,有说原始模型也要参与反向传播,所以是要占用一份梯度显存的,也有的说原始模型都不更新梯度,肯定只需要 Lora 部分的梯度显存,搞得我头很大。
那么究竟正确答案是哪一种呢,这里直接给出答案,不需要计算原始模型部分的梯度,也基本不占用显存。也就是说梯度部分占用显存也可以近似为 0Φ。
总的来说,不考虑激活值的情况下,Lora 微调训练的显存占用只有 2Φ,一个 7B 的模型 Lora 训练只需要占用显存大约 14G 左右。
验证一下,我们来看 Llama Factory 里给出训练任务的显存预估表格:
Llama Factory 的表格
可以看到 7B 模型的 Lora 训练的显存消耗与我们估计得也差不多,同时也还可以复习一下全参训练、混合精度训练的显存分析,也是基本符合我们之前的分析的。
(2)QLora
上面 Llama Factory 的那张表也是稍微剧透了一下我们接下来要讲的内容,也就是 QLora,继 Lora 之后也是在业界落地非常广泛通用的一种大模型 PEFT 方法。
QLora,也叫做量化 Lora,顾名思义,也就是进一步压缩模型的精度,然后用 Lora 训练,他的核心思路很好理解,但实际上涉及的知识点细节却并不少。
我同样也不会太过深入地去介绍这个中细节,我主要是想按照显存占用的思路去分析 Qlora,理解思路永远比死的知识点更加重要。
Qlora 的整体思路
Qlora 来自于《 QLORA: Efficient Finetuning of Quantized LLMs 》这篇论文,实际上这篇论文的核心在于提出了一种新的量化方法,重点在于量化而不是 Lora。
很多不了解的人看到量化 lora 这个名字就以为是对 Lora 部分的参数进行量化,因为他们认为毕竟只有 Lora 部分的参数参与了训练。
但理解了上面一节的小伙伴就明白实际并不是这样,原始模型的本身参数虽然不更新参数,但是仍然需要前向和反向传播,QLora 优化的正是 Lora 里显存占大头的模型参数本身。
那么 Qlora 就是把原始模型参数从 16bit 压缩到 4bit,然后更新这个 4bit 参数吗?
非也非也,这里需要区分两个概念,一个是计算参数,一个是存储参数,计算参数就是在前向、反向传播参与实际计算的参数,存储参数就是不参与计算一开始加载的原始参数。
QLora 的方法就是,加载并且量化 16bit 的模型原始参数为 4bit 作为存储参数,但是在具体需要计算的时候,将该部分的 4bit 参数反量化为 16bit 作为计算参数。
也就是说,QLora 实际上我们训练计算里用到的所有数据的精度都是和 Lora 一样的,只是加载的模型是 4bit,会进行一个反量化到 16bit 的方法,用完即释放。
前面说到的都是模型原始参数本身,不包括 lora 部分的参数,Lora 部分的参数不需要量化,一直都是 16bit。
看到这里机智的你应该也想到了,这比 Lora 多了一个量化反量化的操作,那训练时间是不是会更长,没错一般来讲 Qlora 训练会比 Lora 多用 30% 左右的时间。
Qlora 的技术细节
基本的思路讲完了,那么其中包含了哪些具体的实现细节呢?
Qlora 主要包括三个创新点,这里我只简单提及,应付面试足够的程度,如果想要详细了解可以去看论文。
**NF4 量化:**常见的量化分布都是基于参数是均匀分布的假设,而这个方法基于参数是正态分布的假设,这样使得量化精度大大提升。
**双重量化:**对于第一次量化后得到的用于计算反量化时的锚点参数,我们对这个锚点参数进行量化,可以进一步降低显存。
**优化器分页:**为了防止 OOM,可以在 GPU 显存紧张的时候利用 CPU 内存进行加载参数。
显存分析
想必已经理解 QLora 运行思路的小伙伴,应该可以很轻松的分析出 Qlora 占用显存的部分了吧,这就是理清楚思路的好处。
没错,Qlora 占用的显存主要就是 4Bit 量化后的模型本身也就是 0.5Φ,这里没有考虑少量的 Lora 部分的参数和量化计算中可能产生的显存。可以回过头去看看刚才的表格,也是基本符合预期的。
最后我们用一个表格来总结所有之前我们提到的显存分析:
第一次写文章,不足之处多多见谅,也是参考了很多大佬的文章,主要是做一个思路的整理,有问题或者写得不对的地方欢迎各位小伙伴在评论区讨论指正。
如何学习大模型 AI ?
由于新岗位的生产效率,要优于被取代岗位的生产效率,所以实际上整个社会的生产效率是提升的。
但是具体到个人,只能说是:
“最先掌握AI的人,将会比较晚掌握AI的人有竞争优势”。
这句话,放在计算机、互联网、移动互联网的开局时期,都是一样的道理。
我在一线互联网企业工作十余年里,指导过不少同行后辈。帮助很多人得到了学习和成长。
我意识到有很多经验和知识值得分享给大家,也可以通过我们的能力和经验解答大家在人工智能学习中的很多困惑,所以在工作繁忙的情况下还是坚持各种整理和分享。但苦于知识传播途径有限,很多互联网行业朋友无法获得正确的资料得到学习提升,故此将并将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。
第一阶段(10天):初阶应用
该阶段让大家对大模型 AI有一个最前沿的认识,对大模型 AI 的理解超过 95% 的人,可以在相关讨论时发表高级、不跟风、又接地气的见解,别人只会和 AI 聊天,而你能调教 AI,并能用代码将大模型和业务衔接。
- 大模型 AI 能干什么?
- 大模型是怎样获得「智能」的?
- 用好 AI 的核心心法
- 大模型应用业务架构
- 大模型应用技术架构
- 代码示例:向 GPT-3.5 灌入新知识
- 提示工程的意义和核心思想
- Prompt 典型构成
- 指令调优方法论
- 思维链和思维树
- Prompt 攻击和防范
- …
第二阶段(30天):高阶应用
该阶段我们正式进入大模型 AI 进阶实战学习,学会构造私有知识库,扩展 AI 的能力。快速开发一个完整的基于 agent 对话机器人。掌握功能最强的大模型开发框架,抓住最新的技术进展,适合 Python 和 JavaScript 程序员。
- 为什么要做 RAG
- 搭建一个简单的 ChatPDF
- 检索的基础概念
- 什么是向量表示(Embeddings)
- 向量数据库与向量检索
- 基于向量检索的 RAG
- 搭建 RAG 系统的扩展知识
- 混合检索与 RAG-Fusion 简介
- 向量模型本地部署
- …
第三阶段(30天):模型训练
恭喜你,如果学到这里,你基本可以找到一份大模型 AI相关的工作,自己也能训练 GPT 了!通过微调,训练自己的垂直大模型,能独立训练开源多模态大模型,掌握更多技术方案。
到此为止,大概2个月的时间。你已经成为了一名“AI小子”。那么你还想往下探索吗?
- 为什么要做 RAG
- 什么是模型
- 什么是模型训练
- 求解器 & 损失函数简介
- 小实验2:手写一个简单的神经网络并训练它
- 什么是训练/预训练/微调/轻量化微调
- Transformer结构简介
- 轻量化微调
- 实验数据集的构建
- …
第四阶段(20天):商业闭环
对全球大模型从性能、吞吐量、成本等方面有一定的认知,可以在云端和本地等多种环境下部署大模型,找到适合自己的项目/创业方向,做一名被 AI 武装的产品经理。
- 硬件选型
- 带你了解全球大模型
- 使用国产大模型服务
- 搭建 OpenAI 代理
- 热身:基于阿里云 PAI 部署 Stable Diffusion
- 在本地计算机运行大模型
- 大模型的私有化部署
- 基于 vLLM 部署大模型
- 案例:如何优雅地在阿里云私有部署开源大模型
- 部署一套开源 LLM 项目
- 内容安全
- 互联网信息服务算法备案
- …
学习是一个过程,只要学习就会有挑战。天道酬勤,你越努力,就会成为越优秀的自己。
如果你能在15天内完成所有的任务,那你堪称天才。然而,如果你能完成 60-70% 的内容,你就已经开始具备成为一名大模型 AI 的正确特征了。