Pytorch 混合精度训练(Automatic Mixed Precision)原理解析
1. Overview
默认情况下,大多数深度学习框架(比如 pytorch)都采用 32 位浮点算法进行训练。Automatic Mixed Precision(AMP, 自动混合精度)可以在神经网络训练过程中,针对不同的层,采用不同的数据精度进行计算,从而实现节省显存和加快速度的目的。
Pytorch AMP 是从 1.6.0 版本开始的,在此之前借助 NVIDIA 的 apex 可以实现 amp 功能。Pytorch 的 AMP 其实是从 apex 简化而来的,和 apex 的 O1 相当。
AMP 里面的 Mixed 的方式很多,但是这里仅仅讨论 Fp16 和 Fp32 的混合。另外 pytorch 支持 cpu gpu 等不同设备上的 AMP,这里仅仅讨论 GPU 的 AMP(也就是 torch.cuda.amp)。
AMP 的使用非常简单,这里重点介绍 AMP 的原理。如果你只是想知道怎么使用 AMP 可以移步 官方文档。
2. 原理
2.1 Fp16 V.S. Fp32
通过上面的对比不难发现,Fp16 相对于 Fp32 主要有以下优势:
- 对于 memory-limited 算子, Fp16 相对于 Fp32 可以减小一半的访存,能提升算子性能
- 减小模型显存占用,相同条件下,可以使用更大的 batch size,或者是更复杂的模型
- 对于计算密集的算子,比如 linear, conv 等通过 Tensor Cores 可以进行加速
而劣势主要在于:
- 表示的动态区间更小
- 精度不如 Fp32