https://zhuanlan.zhihu.com/p/165152789 自动混合精度
为什么需要自动混合精度?
这个问题其实暗含着这样的意思:为什么需要自动混合精度,也就是torch.FloatTensor和torch.HalfTensor的混合,而不全是torch.FloatTensor?或者全是torch.HalfTensor?
如果非要以这种方式问,那么答案只能是,在某些上下文中torch.FloatTensor有优势,在某些上下文中torch.HalfTensor有优势呗。答案进一步可以转化为,相比于之前的默认的torch.FloatTensor,torch.HalfTensor有时具有优势,有时劣势不可忽视。
torch.HalfTensor的优势就是存储小、计算快、更好的利用CUDA设备的Tensor Core。因此训练的时候可以减少显存的占用(可以增加batchsize了),同时训练速度更快;
torch.HalfTensor的劣势就是:数值范围小(更容易Overflow / Underflow)、舍入误差(Rounding Error,导致一些微小的梯度信息达不到16bit精度的最低分辨率,从而丢失)。
可见,当有优势的时候就用torch.HalfTensor,而为了消除torch.HalfTensor的劣势,我们带来了两种解决方案:
1,梯度scale,这正是上一小节中提到的torch.cuda.amp.GradScaler,通过放大loss的值来防止梯度的underflow(这只是BP的时候传递梯度信息使用,真正更新权重的时候还是要把放大的梯度再unscale回去);
2,回落到torch.FloatTensor,这就是混合一词的由来。那怎么知道什么时候用torch.FloatTensor,什么时候用半精度浮点型呢?这是PyTorch框架决定的,在PyTorch 1.6的AMP上下文中,如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:
- __matmul__
- addbmm
- addmm
- addmv
- addr
- baddbmm
- bmm
- chain_matmul
- conv1d
- conv2d
- conv3d
- conv_transpose1d
- conv_transpose2d
- conv_transpose3d
- linear
- matmul
- mm
- mv
- prelu
如何在PyTorch中使用自动混合精度?
答案就是autocast + GradScaler。
from torch.cuda.amp import autocast as autocast
# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 在训练最开始之前实例化一个GradScaler对象
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
# 前向过程(model + loss)开启 autocast
with autocast():
output = model(input)
loss = loss_fn(output, target)
# Scales loss. 为了梯度放大.
scaler.scale(loss).backward()
# scaler.step() 首先把梯度的值unscale回来.
# 如果梯度的值不是 infs 或者 NaNs, 那么调用optimizer.step()来更新权重,
# 否则,忽略step调用,从而保证权重不更新(不被破坏)
scaler.step(optimizer)
# 准备着,看是否要增大scaler
scaler.update()
scaler的大小在每次迭代中动态的估计,为了尽可能的减少梯度underflow,scaler应该更大;但是如果太大的话,半精度浮点型的tensor又容易overflow(变成inf或者NaN)。所以动态估计的原理就是在不出现inf或者NaN梯度值的情况下尽可能的增大scaler的值——在每次scaler.step(optimizer)中,都会检查是否又inf或NaN的梯度出现:
1,如果出现了inf或者NaN,scaler.step(optimizer)会忽略此次的权重更新(optimizer.step() ),并且将scaler的大小缩小(乘上backoff_factor);
2,如果没有出现inf或者NaN,那么权重正常更新,并且当连续多次(growth_interval指定)没有出现inf或者NaN,则scaler.update()会将scaler的大小增加(乘上growth_factor)。