1 amp 模块的作用
amp : 全称为 Automatic mixed precision,自动混合精度,可以在神经网络推理过程中,针对不同的层,采用不同的数据精度进行计算,从而实现节省显存和加快速度的目的
自动预示着Tensor的dtype类型会自动变化,也就是框架按需自动调整tensor的dtype
混合精度预示着有不止一种精度的Tensor :
torch.FloatTensor(浮点型 32位)(torch默认的tensor精度类型是torch.FloatTensor)
torch.HalfTensor(半精度浮点型 16位)
2 使用自动混合精度 (amp) 的原因
torch.HalfTensor:
torch.HalfTensor的优势就是存储小、计算快、更好的利用CUDA设备的Tensor Core。因此训练的时候可以减少显存的占用(可以增加batchsize了),同时训练速度更快
torch.HalfTensor的劣势就是:数值范围小(更容易Overflow / Underflow)、舍入误差(Rounding Error,导致一些微小的梯度信息达不到16bit精度的最低分辨率,从而丢失)
3 解决方案
当有优势的时候就用torch.HalfTensor,而为了消除torch.HalfTensor的劣势,有两种解决方案:
梯度scale,这正是上一小节中提到的torch.cuda.amp.GradScaler,通过放大loss的值来防止梯度消失underflow(这只是BP的时候传递梯度信息使用,真正更新权重的时候还是要把放大的梯度再unscale回去);
回落到torch.FloatTensor,这就是混合一词的由来。那怎么知道什么时候用torch.FloatTensor,什么时候用半精度浮点型呢?这是PyTorch框架决定的,AMP上下文中,一些常用的操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor(如:conv1d、conv2d、conv3d、linear、prelu等)
4 GradScaler()
在训练最开始之前使用amp.GradScaler实例化一个GradScaler对象
# Initialize the gradient scaler
scaler = amp.GradScaler()