对于llama3.1 8B模型,FP32和BF16混合精度训练,用的是AdamW优化器,模型训练时占用显存分析

目录

为什么先不考虑激活值的显存占用

1. 模型参数

含义

计算

2. 梯度参数

含义

3. 优化器参数

含义

4. 较固定总显存占用

计算

详细解释

5. 激活值计算:

计算公式

插入数值

计算步骤

结论


显存主要被用在四个模块上:

  • 模型权重本身

  • 梯度

  • 优化器

  • 激活值

其中,

  • 静态:优化器状态、模型参数

  • 动态:激活值、梯度值

也就是说,我们其实没法特别准确的计算出我们实际运行时候的显存大小,如果在面试的时候,就可以忽略掉激活值的计算,梯度当做静态计算就好。

为什么先不考虑激活值的显存占用

在计算显存占用时,我们通常会区分模型参数、梯度参数和优化器状态的显存占用,以及激活值的显存占用。以下是具体原因:

  1. 模型参数、梯度参数和优化器状态

    • 这些部分的显存占用是相对固定的,取决于模型的大小和优化器的选择。
    • 在混合精度训练中,我们可以明确计算这些部分的显存占用。
  2. 激活值

    • 激活值的显存占用与批量大小(batch size)和序列长度(sequence length)密切相关,且在不同的训练任务和配置下变化较大。
    • 激活值的显存占用往往是动态的,取决于具体的训练过程和数据流动。

因此,在讨论显存占用时,我们通常会先计算固定部分(模型参数、梯度参数和优化器状态)的显存占用,而不考虑激活值的显存占用。这是因为激活值的显存占用是高度可变的,需要根据具体的训练配置进行动态调整。具体变化在最后简单介绍一下

在模型训练中,显存占用主要包括模型参数、梯度参数和优化器状态。对于LLaMA 3.1 8B模型,使用混合精度训练(FP32和BF16)和AdamW优化器时,显存占用的计算如下:

1. 模型参数

含义

模型参数是神经网络的权重和偏置等参数。对于8B参数的模型:

  • BF16(Brain Floating Point 16-bit):每个参数占用16位(2字节)
  • FP32(Floating Point 32-bit):每个参数占用32位(4字节)
计算

假设模型的所有参数都存储为BF16和FP32两种格式:

  • BF16:8B参数 * 2字节 = 16GB
  • FP32:8B参数 * 4字节 = 32GB

总的模型参数显存占用为: 16𝐺𝐵+32𝐺𝐵=48𝐺𝐵

2. 梯度参数

含义

梯度参数是用于反向传播更新模型参数的梯度值。在混合精度训练中,梯度通常以BF16格式存储:

  • BF16:8B参数 * 2字节 = 16GB

总的梯度参数显存占用为: 16𝐺𝐵

3. 优化器参数

含义

AdamW优化器需要存储额外的状态参数,包括一阶动量(momentum)和二阶动量(variance)。这些参数通常以FP32格式存储:

  • 一阶动量(FP32):8B参数 * 4字节 = 32GB
  • 二阶动量(FP32):8B参数 * 4字节 = 32GB

总的优化器参数显存占用为: 32𝐺𝐵+32𝐺𝐵=64𝐺𝐵

4. 较固定总显存占用

计算

不考虑激活值的情况下,总显存占用为: 48𝐺𝐵(模型参数)+16𝐺𝐵(梯度参数)+64𝐺𝐵(优化器参数)=128𝐺𝐵

详细解释
  1. 模型参数(48GB)

    • BF16:模型的所有参数以16位格式存储,占用16GB显存。
    • FP32:模型的所有参数以32位格式存储,占用32GB显存。
  2. 梯度参数(16GB)

    • BF16:用于反向传播的梯度参数以16位格式存储,占用16GB显存。
  3. 优化器参数(64GB)

    • 一阶动量(32GB):AdamW优化器的一阶动量参数以32位格式存储,占用32GB显存。
    • 二阶动量(32GB):AdamW优化器的二阶动量参数以32位格式存储,占用32GB显存。

总结来说,在LLaMA 3.1 8B模型的混合精度训练中,模型参数、梯度参数和优化器参数的显存占用分别为48GB、16GB和64GB,总计128GB,不考虑激活值的情况下。

5. 激活值计算:

要计算LLaMA 3.1 8B模型的激活值显存占用,我们需要知道以下信息:

  1. 批量大小(Batch Size, B)
  2. 序列长度(Sequence Length, S)
  3. 每层的输出维度(Hidden Size, d)
  4. 模型的层数(Number of Layers, L)
  5. 每个激活值元素的大小(Element Size, BF16为2字节,FP32为4字节)

假设以下典型配置(请注意,实际配置可能有所不同):

  • 批量大小 𝐵=32
  • 序列长度 𝑆=512
  • 每层的输出维度 𝑑=4096
  • 模型的层数 𝐿=80(假设LLaMA 3.1 8B有80层)
  • 使用BF16格式(每个元素2字节)
计算公式

激活值的总显存占用可以表示为: 显存占用=𝐵×𝑆×∑𝑖=1𝐿(𝑑𝑖×size_of_element)

对于具有相同输出维度 𝑑 的所有层,这个公式简化为: 显存占用=𝐵×𝑆×𝐿×𝑑×size_of_element

插入数值
  • 批量大小 𝐵=32
  • 序列长度 𝑆=512
  • 层数 𝐿=80
  • 每层输出维度 𝑑=4096
  • 每个元素大小(BF16) size_of_element=2字节

计算显存占用: 显存占用=32×512×80×4096×2字节

计算步骤
  1. 计算批量大小和序列长度的乘积: 32×512=16384

  2. 计算层数和输出维度的乘积: 80×4096=327680

  3. 将上述结果相乘并乘以元素大小: 16384×327680×2=10737418240字节

  4. 转换为GB: 10737418240字节=10GB

结论

在假设批量大小为32,序列长度为512,每层输出维度为4096,使用BF16格式的情况下,LLaMA 3.1 8B模型的激活值显存占用大约为10GB。

请注意,这只是一个估算,实际显存占用可能会因为其他因素(如模型具体架构、额外的缓存和中间结果存储等)有所不同。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

samoyan

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值