大模型训练和推理所需的显存计算

以14B参数量的llama为例,在训练阶段,模型所需要的显存由下列几个部分组成:

  1. 模型参数存储计算
    • 计算公式:参数总量 × 每个参数的存储大小
    • 对于14B模型,约需28 GB显存
  2. 激活值显存计算
    • 考虑批次大小、序列长度和模型结构
    • 主要包括注意力层和前馈神经网络层
    • 最少约需20-30 GB显存
  3. 优化器状态计算
    • Adam优化器通常需要4倍参数显存
    • 对于14B模型,约需112 GB显存
  4. 额外开销
    • 预留20-30%额外显存
    • 约需30-50 GB显存

总体来说,一个14B参数的模型在训练时可能需要150-200 GB显存。

激活值显存计算略复杂,展开如下:

模型基本参数
  • 批次大小:4
  • 序列长度:2048
  • 隐藏层维度:5120
  • 注意力头数:40
  • 模型层数:40
  • 数据精度:float16(每个数值占2字节)
注意力层激活值计算
  1. 查询(Q)、键(K)、值(V)矩阵显存
显存 = 批次大小 × 序列长度 × 隐藏层维度 × 3 × 2字节
     = 4 × 2048 × 5120 × 3 × 2
     = 4 × 2048 × 5120 × 6
     = 252,182,528 字节
     ≈ 240 MB
  1. 注意力分数矩阵显存
显存 = 批次大小 × 注意力头数 × 序列长度² × 2字节
     = 4 × 40 × 2048 × 2048 × 2
     = 4 × 40 × 4,194,304 × 2
     = 1,342,177,280 字节
     ≈ 1.25 GB
  1. 注意力输出矩阵显存
显存 = 批次大小 × 序列长度 × 隐藏层维度 × 2字节
     = 4 × 2048 × 5120 × 2
     = 4 × 2048 × 10,240
     = 84,082,688 字节
     ≈ 80 MB
前馈神经网络层激活值计算
  1. 第一个线性变换(隐藏层扩展到中间层)
中间层维度 = 隐藏层维度 × 4 = 5120 × 4 = 20,480

显存 = 批次大小 × 序列长度 × 中间层维度 × 2字节
     = 4 × 2048 × 20,480 × 2
     = 4 × 2048 × 40,960
     = 336,330,752 字节
     ≈ 320 MB
  1. 激活函数(通常是GELU)
显存 = 批次大小 × 序列长度 × 中间层维度 × 2字节
     = 4 × 2048 × 20,480 × 2
     = 336,330,752 字节
     ≈ 320 MB
  1. 第二个线性变换(中间层压缩回隐藏层)
显存 = 批次大小 × 序列长度 × 隐藏层维度 × 2字节
     = 4 × 2048 × 5120 × 2
     = 84,082,688 字节
     ≈ 80 MB
总体激活值显存计算
  1. 单层激活值显存
单层显存 = 注意力层显存 + 前馈神经网络层显存
         ≈ (240 + 1250 + 80) + (320 + 320 + 80) MB
         ≈ 1570 + 720 MB
         ≈ 2.29 GB
  1. 全模型激活值显存
总显存 = 单层显存 × 模型层数
       = 2.29 GB × 40
       ≈ 91.6 GB

关键影响因素解析

  1. 批次大小:呈线性增长

    • 批次大小越大,激活值显存越大
    • 4 → 8,显存大致翻倍
  2. 序列长度:呈平方级增长

    • 序列长度增加,显存增长更快
    • 2048 → 4096,显存增长约4倍
  3. 隐藏层维度:呈线性增长

    • 维度越大,每个激活值占用显存越多

实践建议

  1. 动态调整批次大小控制显存
  2. 使用梯度检查点减少显存消耗
  3. 考虑混合精度训练
  4. 根据硬件灵活调整

总结

对于14B参数的LLaMA模型,在典型配置下:

  • 单层激活值显存:约2.29 GB
  • 全模型激活值显存:约91.6 GB

推理阶段,所需要的显存:

推理阶段显存 = 模型参数显存 + KV Cache显存 + 激活值显存

KV Cache显存 = 批次大小 × 模型层数 × 序列长度 × (键矩阵大小 + 值矩阵大小) × 2字节

单层KV Cache显存 = 1 × 序列长度 × 隐藏层维度 × 2
                  = 1 × 2048 × 5120 × 2
                  = 20,971,520字节
                  ≈ 20 MB

总KV Cache显存 = 单层KV Cache显存 × 模型层数
               = 20 MB × 40
               ≈ 800 MB

激活值显存 = 批次大小 × 序列长度 × (注意力层显存 + 前馈网络层显存)

注意力层激活值显存 ≈ 1 × 2048 × 5120 × 2字节
                   ≈ 20 MB

前馈网络层激活值显存 ≈ 1 × 2048 × (5120 × 4) × 2字节
                     ≈ 80 MB

总激活值显存 ≈ 100 MB

总显存 = 模型参数显存 + KV Cache显存 + 激活值显存
       = 28 GB + 0.8 GB + 0.1 GB
       ≈ 29 GB

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值