混合精度训练,使用半精度加速训练
为什么要选择混合精度
深度学习在训练阶段往往操作一系列浮点值进行运算,而受制于显卡的有限计算单元,计算不同规格的浮点数的效率截然不同(例如计算一个32位浮点数的单元可以拆分为两个计算16位浮点数的单元),因此选择更低精度的浮点数将会带来巨大的效率提升。
下图为A100在不同进度下的性能:
在训练的过程中,中间参数的精度需求实际并不高,我们可以牺牲一部分精度来换取更快的学习速度和更少的显存占用。
尽管使用FP16(或TF16)可以带来显著的效率和内存优势,但它也面临一些挑战,如精度溢出和舍入误差。为了解决这些问题,通常在前向和反向传播过程中使用FP16,而在累积梯度和更新模型参数时使用FP32(Ampere架构后为TF32),以保证数值稳定性和精度。这种做法在保持计算精度的同时,也提高了计算效率和减少了内存占用。
混合精度训练
Pytorch 中,autocast 可以十分方便的开启混合精度计算。
为防止下溢或溢出,还需使用 GradScaler 对梯度进行适当缩放来适应半精度浮点数的范围。
- 导入GradScaler 和 autocast:
from torch