自动混合精度(AMP)方法autocast和GradScaler

AMP介绍

之前大多数的学习框架都是用float32的精度进行训练,AMP做的改进就是使用float32和float16相结合进行训练,并且相同的超参数实现了与使用float32几乎相同的精度

为什么使用混合精度

混合精度预示着有不止一种精度的Tensor,PyTorch的AMP有2种精度是torch.FloatTensor和torch.HalfTensor

因为使用自动混合精度其实一种在训练过程中对训练速度和显存的优化,众所周知神经网络模型的训练通常会大量占据显存,训练时间也特别长,因此做出优化是很有必要的。

调用方法是torch.cuda.amp.autocast()和torch.cuda.amp.GradScaler() 字面意思看只能在cuda上使用,事实上,这个功能正是NVIDIA的开发人员贡献到PyTorch项目中的。后面会用代码详细介绍这两种方法的使用

代码

scaler = torch.cuda.amp.GradScaler()

def train_one_epoch(loader, model, loss_fn, optimizer, scaler):
    for batch_idx, (data, targets) in enumerate(loop):
     /
     .......
     /
     // scaler的大小在每次迭代中动态估计,为了尽可能减少梯度,scaler应该更大;但太大,半精度浮点型又容易overflow(变成inf或NaN).所以,动态估计原理就是在不出现if或NaN梯度的情况下,尽可能的增大scaler值。在每次scaler.step(optimizer)中,都会检查是否有inf或NaN的梯度出现:
        with torch.cuda.amp.autocast():
            scores = model(data)
            loss = loss_fn(scores, targets)
        optimizer.zero_grad()
        scaler.scale(loss).backward() //反向传播梯度放大
        scaler.step(optimizer)//首先把梯度值unscale回来,如果梯度值不是inf或NaN,则调用optimizer.step()来更新权重,否则,忽略step调用,从而保证权重不更新。
        scaler.update()//看是否要增大scaler

torch.HalfTensor

torch.HalfTensor的优势就是存储小、计算快、更好的利用CUDA设备的Tensor Core。因此训练的时候可以减少显存的占用,同时训练速度更快。

torch.HalfTensor的劣势就是溢出错误,在开启 autocast 后,网络中间结果的数值类型会变成 float16,其对应的梯度自然也是 float16。而由于 float16 的数值范围比 float32 要小,所以如果遇到特别小的数(比如 loss、gradient),float16 就难以精确表达,这时候一般需要进行梯度放缩(Gradient Scaling)。做法是对网络的 loss 进行放大,从而使反向传播时网络中间结果对应的梯度也得到相同的放大,减少精度的损失,而梯度反传到参数时,仍会是 float32 的类型,在等比缩小之后,并不会影响参数的更新。

GradScaler

1如果出现inf或NaN, GradScaler会忽略此次权重更新,并将scaler的大小缩小

2如果没有出现inf或NaN,那么权重正常更新,并且当连续多次没有出现inf或NaN,则会将scaler的大小增加

总结

autocast 和 GradScaler 一般配合使用,起到作用一般就是减少显存占用,加快模型训练速度

  • 7
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

dzm1204

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值