1. 不同操作使用的数据类型
a) 模型参数和激活值: FP16
- 原因: 减少内存使用,允许更大批量或更大模型
- 优势: FP16只需FP32一半的内存空间
b) 梯度计算: FP16
- 原因: 加速反向传播,减少内存使用
- 注意: 需要使用损失缩放防止梯度消失
c) 主要计算(如矩阵乘法): FP16
- 原因: 提高计算速度,减少内存带宽需求
- 优势: 现代GPU对FP16运算有硬件级优化
d) 关键操作(损失计算和权重更新): FP32
- 原因: 保持数值稳定性和训练精度
- 重要性: 这些操作对训练收敛至关重要,需要高精度
2. 权重更新过程
- 梯度以FP16计算和存储
- 使用损失缩放防止FP16梯度下溢
- 更新前,FP16梯度转换为FP32
- 权重更新在FP32精度下进行
- 维护一份FP32"主权重"
- 更新后的FP32权重转回FP16用于下一轮计算
3. FP32主权重的内存考虑
- 虽然保存了FP32主权重,但混合精度训练仍有显著内存优势
- 主要内存节省来自激活值和梯度的FP16存储
- 通常可减少30-50%的总体内存使用
4. 优化器状态使用FP32存储的原因
- 提高数值稳定性,尤其对于累积的小数值
- 允许表示更精确的小数值,确保微小更新不被忽略
- 避免长期训练中的下溢和上溢问题
- 与FP32权重更新保持一致性
- 保证优化算法的行为与原始设计一致
- 优化器状态的内存占用相对较小,影响不大
5. 混合精度训练的整体策略
- 在内存效率、计算速度和训练稳定性之间取得平衡
- 大部分操作使用FP16以提高效率
- 关键操作保持FP32以确保精度和稳定性
- 需要仔细管理不同精度间的转换和数值范围