以14B参数量的llama为例,在训练阶段,模型所需要的显存由下列几个部分组成:
- 模型参数存储计算
- 计算公式:参数总量 × 每个参数的存储大小
- 对于14B模型,约需28 GB显存
- 激活值显存计算
- 考虑批次大小、序列长度和模型结构
- 主要包括注意力层和前馈神经网络层
- 最少约需20-30 GB显存
- 优化器状态计算
- Adam优化器通常需要4倍参数显存
- 对于14B模型,约需112 GB显存
- 额外开销
- 预留20-30%额外显存
- 约需30-50 GB显存
总体来说,一个14B参数的模型在训练时可能需要150-200 GB显存。
激活值显存计算略复杂,展开如下:
模型基本参数
- 批次大小:4
- 序列长度:2048
- 隐藏层维度:5120
- 注意力头数:40
- 模型层数:40
- 数据精度:float16(每个数值占2字节)
注意力层激活值计算
- 查询(Q)、键(K)、值(V)矩阵显存
显存 = 批次大小 × 序列长度 × 隐藏层维度 × 3 × 2字节
= 4 × 2048 × 5120 × 3 × 2
= 4 × 2048 × 5120 × 6
= 252,182,528 字节
≈ 240 MB
- 注意力分数矩阵显存
显存 = 批次大小 × 注意力头数 × 序列长度² × 2字节
= 4 × 40 × 2048 × 2048 × 2
= 4 × 40 × 4,194,304 × 2
= 1,342,177,280 字节
≈ 1.25 GB
- 注意力输出矩阵显存
显存 = 批次大小 × 序列长度 × 隐藏层维度 × 2字节
= 4 × 2048 × 5120 × 2
= 4 × 2048 × 10,240
= 84,082,688 字节
≈ 80 MB
前馈神经网络层激活值计算
- 第一个线性变换(隐藏层扩展到中间层)
中间层维度 = 隐藏层维度 × 4 = 5120 × 4 = 20,480
显存 = 批次大小 × 序列长度 × 中间层维度 × 2字节
= 4 × 2048 × 20,480 × 2
= 4 × 2048 × 40,960
= 336,330,752 字节
≈ 320 MB
- 激活函数(通常是GELU)
显存 = 批次大小 × 序列长度 × 中间层维度 × 2字节
= 4 × 2048 × 20,480 × 2
= 336,330,752 字节
≈ 320 MB
- 第二个线性变换(中间层压缩回隐藏层)
显存 = 批次大小 × 序列长度 × 隐藏层维度 × 2字节
= 4 × 2048 × 5120 × 2
= 84,082,688 字节
≈ 80 MB
总体激活值显存计算
- 单层激活值显存
单层显存 = 注意力层显存 + 前馈神经网络层显存
≈ (240 + 1250 + 80) + (320 + 320 + 80) MB
≈ 1570 + 720 MB
≈ 2.29 GB
- 全模型激活值显存
总显存 = 单层显存 × 模型层数
= 2.29 GB × 40
≈ 91.6 GB
关键影响因素解析
-
批次大小:呈线性增长
- 批次大小越大,激活值显存越大
- 4 → 8,显存大致翻倍
-
序列长度:呈平方级增长
- 序列长度增加,显存增长更快
- 2048 → 4096,显存增长约4倍
-
隐藏层维度:呈线性增长
- 维度越大,每个激活值占用显存越多
实践建议
- 动态调整批次大小控制显存
- 使用梯度检查点减少显存消耗
- 考虑混合精度训练
- 根据硬件灵活调整
总结
对于14B参数的LLaMA模型,在典型配置下:
- 单层激活值显存:约2.29 GB
- 全模型激活值显存:约91.6 GB
推理阶段,所需要的显存:
推理阶段显存 = 模型参数显存 + KV Cache显存 + 激活值显存
KV Cache显存 = 批次大小 × 模型层数 × 序列长度 × (键矩阵大小 + 值矩阵大小) × 2字节
单层KV Cache显存 = 1 × 序列长度 × 隐藏层维度 × 2
= 1 × 2048 × 5120 × 2
= 20,971,520字节
≈ 20 MB
总KV Cache显存 = 单层KV Cache显存 × 模型层数
= 20 MB × 40
≈ 800 MB
激活值显存 = 批次大小 × 序列长度 × (注意力层显存 + 前馈网络层显存)
注意力层激活值显存 ≈ 1 × 2048 × 5120 × 2字节
≈ 20 MB
前馈网络层激活值显存 ≈ 1 × 2048 × (5120 × 4) × 2字节
≈ 80 MB
总激活值显存 ≈ 100 MB
总显存 = 模型参数显存 + KV Cache显存 + 激活值显存
= 28 GB + 0.8 GB + 0.1 GB
≈ 29 GB