torch.cuda.amp自动混合精度训练 —— 节省显存并加快推理速度
文章目录
1、什么是amp?
amp:Automatic mixed precision,自动混合精度,可以在神经网络推理过程中,针对不同的层,采用不同的数据精度进行计算,从而实现节省显存和加快速度的目的。
自动混合精度的关键词有两个:自动、混合精度。这是由PyTorch 1.6的torch.cuda.amp
模块带来的:
from torch.cuda import amp
混合精度预示着有不止一种精度的Tensor,那在PyTorch的AMP模块里是几种呢?2种:torch.FloatTensor
(浮点型 32位)和torch.HalfTensor
(半精度浮点型 16位);
自动预示着Tensor的dtype类型会自动变化,也就是框架按需自动调整tensor的dtype(其实不是完全自动,有些地方还是需要手工干预);
注意
torch.cuda.amp
的名字意味着这个功能只能在cuda上使用。- torch默认的tensor精度类型是
torch.FloatTensor
2、为什么需要自动混合精度(amp)?
也可以这么问:为什么需要自动混合精度,也就是torch.FloatTensor
和torch.HalfTensor
的混合,而不全是torch.FloatTensor
?或者全是torch.HalfTensor
?
原因: 在某些上下文中torch.FloatTensor
有优势,在某些上下文中torch.HalfTenso
r有优势。
torch.HalfTensor
torch.HalfTenso
r的优势就是存储小、计算快、更好的利用CUDA设备的Tensor Core。因此训练的时候可以减少显存的占用(可以增加batchsize了),同时训练速度更快;torch.HalfTensor
的劣势就是:数值范围小(更容易Overflow / Underflow)、舍入误差(Rounding Error,导致一些微小的梯度信息达不到16bit精度的最低分辨率,从而丢失)。
可见,当有优势的时候就用torch.HalfTensor
,而为了消除torch.HalfTensor
的劣势,我们带来了两种解决方案:
- 梯度scale,这正是上一小节中提到的
torch.cuda.amp.GradScaler
,通过放大loss的值来防止梯度消失underflow(这只是BP的时候传递梯度信息使用,真正更新权重的时候还是要把放大的梯度再unscale回去) - 回落到
torch.FloatTensor
,这就是混合一词的由来。那怎么知道什么时候用torch.FloatTensor
,什么时候用半精度浮点型呢?这是PyTorch框架决定的,AMP上下文中,一些常用的操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor
(如:conv1d、conv2d、conv3d、linear、prelu等)
3、如何在PyTorch中使用自动混合精度?
答案是 autocast + GradScaler
3.1 autocast
使用torch.cuda.amp
模块中的autocast 类。
from torch.cuda import amp
# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters()