影响因素概述
模型训练框架
- PyTorch:CUDA context占用几百MB显存,与版本有关。
- 模型参数大小:例如7B模型以FP16格式占用14GB显存。
- 临时Tensor:前向计算中产生,用于反向传播。
- 梯度:反向传播计算得到。
- 优化器状态:
- 全量微调:梯度与参数一样大。
- 普通SGD:无动量。
- 一阶动量优化器:如momentum-SGD,参数大小与模型一样。
- 二阶动量优化器:如Adam,参数大小为模型两倍。
前向计算临时Tensor显存占用
self-attention显存占用
-
输入矩阵I:形状[b, s, d],显存占用2bsd bytes。
-
Q, K, V:形状[b, s, d],QKT占用4bsd bytes。
-
softmax:形状[b, h, s, s],显存占用2bhs2 bytes。
-
dropout:mask矩阵形状[b, h, s, s],显存占用bhs2 bytes。
-
score * V:形状[b, h, s, s],显存占用2bhs2 bytes。
-
WO:形状[b, s, d],显存占用2bsd bytes。
-
dropout:mask矩阵形状[b, s, d],显存占用bsd bytes。
总计:11bsd + 5bhs2 bytes。
MLP显存占用
-
线性层:
- 第一个:形状[b, s, d],显存占用2bsd bytes。
- 第二个:形状[b, s, 4d],显存占用8bsd bytes。
-
激活函数:形状[b, s, 4d],显存占用8bsd bytes。
-
dropout:mask矩阵形状[b, s, d],显存占用bsd bytes。
总计:19bsd bytes。
梯度和优化器显存占用
模型训练过程
- 混合精度训练:
- 前向传递和反向传播使用float16,计算梯度。
- 优化器更新时使用float32。
- 每个参数总计20bytes,总计显存占用20P bytes。
- 普通训练:
- 所有步骤使用float32。
- 每个参数总计24bytes,总计显存占用24P bytes。
模型推理过程
- float16:显存占用约2P bytes。
- float32:显存占用约4P bytes。
参考文章:知乎文章