解决PyTorch半精度(AMP)训练nan问题

点击下方卡片,关注“CVer”公众号

AI/CV重磅干货,第一时间送达、

作者:可可哒  |  已授权转载(源:知乎)

https://zhuanlan.zhihu.com/p/443166496

本文主要是收集了一些在使用pytorch自带的amp下loss nan的情况及对应处理方案。

Why?

如果要解决问题,首先就要明确原因:为什么全精度训练时不会nan,但是半精度就开始nan?这其实分了三种情况:

  1. 计算loss 时,出现了除以0的情况

  2. loss过大,被半精度判断为inf

  3. 网络参数中有nan,那么运算结果也会输出nan

1&2我想放到后面讨论,因为其实大部分报nan都是第三种情况。这里来先看看3。什么情况下会出现情况3?这个讨论给出了不错的解释:

Nan Loss with torch.cuda.amp and CrossEntropyLoss

https://discuss.pytorch.org/t/nan-loss-with-torch-cuda-amp-and-crossentropyloss/108554/17

给大家翻译翻译:在使用ce loss 或者 bceloss的时候,会有log的操作,在半精度情况下,一些非常小的数值会被直接舍入到0,log(0)等于啥?——等于nan啊!

于是逻辑就理通了:回传的梯度因为log而变为nan->网络参数nan-> 每轮输出都变成nan。(;´Д`)

How?

问题定义清楚,那解决方案就非常简单了,只需要在涉及到log计算时,把输入从half精度转回float32:

x = x.float()
x_sigmoid = torch.sigmoid(x)

一些思考&废话

这里我接着讨论下我第一次看到nan之后,企图直接copy别人的解决方案,但解决不掉时踩过的坑。比如:

  1. 修改优化器的eps

有些blog会建议你从默认的1e-8 改为 1e-3,比如这篇:pytorch1.1 半精度训练 Adam RMSprop 优化器 Nan 问题

https://blog.csdn.net/gwb281386172/article/details/104705195

经过上面的分析,我们就能知道为什么这种方法不行——这个方案是针对优化器的数值稳定性做的修改,而loss计算这一步在优化器之前,如果loss直接nan,优化器的eps是救不回来的(托腮)。

那么这个方案在哪些场景下有效?——在loss输出不是nan时(感觉说了一句废话)。optimizer的eps是保证在进行除法backwards时,分母不出现0时需要加上的微小量。在半精度情况下,分母加上1e-8就仿佛听君一席话,因此,需要把eps调大一点。

2. 聊聊amp的GradScaler

GradScaler是autocast的好伙伴,在官方教程上就和autocast配套使用:

from torch.cuda.amp import autocast, GradScaler
...
scaler = GradScaler()
for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        scaler.step(optimizer)
        scaler.update()

具体原理不是我这篇文章讨论的范围,网上很多教程都说得很清楚了,比如这个就不错:

Gemfield:PyTorch的自动混合精度(AMP)

https://zhuanlan.zhihu.com/p/165152789

但是我这里想讨论另一点:scaler.step(optimizer)的运行原理。

在初始化GradScaler的时候,有一个参数enabled,值默认为True。如果为True,那么在调用scaler方法时会做梯度缩放来调整loss,以防半精度状况下,梯度值过大或者过小从而被nan或者inf。而且,它还会判断本轮loss是否是nan,如果是,那么本轮计算的梯度不会回传,同时,当前的scale系数乘上backoff_factor,缩减scale的大小

那么,为什么这一步已经判断了loss是不是nan,还是会出现网络损失持续nan的情况呢?

这时我们就得再往前思考一步了:为什么loss会变成nan?回到文章一开始说的:

(1)计算loss 时,出现了除以0的情况;

(2)loss过大,被半精度判断为inf;

(3)网络直接输出了nan。

(1)&(2),其实是可以通过scaler.step(optimizer)解决的,分别由optimizer和scaler帮我们捕捉到了nan的异常。但(3)不行,(3)意味着部分甚至全部的网络参数已经变成nan了。这可能是在更之前的梯度回传过程中除以0导致的——首先【回传的梯度不是nan】,所以scaler不会捕捉异常;其次,由于使用了半精度,optimizer接收到了【已经因为精度损失而变为nan的loss】,nan不管加上多大的eps,都还是nan,所以optimizer也无法处理异常,最终导致网络参数nan。

所以3,只能通过本文一开始提出的方案来解决。其实,大部分分类问题在使用半精度时出现nan的情况都是第3种情况,也只能通过把精度转回为float32,或者在计算log时加上微小量来避免(但这样会损失精度)。

参考

Nan Loss with torch.cuda.amp and CrossEntropyLoss

ICCV和CVPR 2021论文和代码下载

后台回复:CVPR2021,即可下载CVPR 2021论文和代码开源的论文合集

后台回复:ICCV2021,即可下载ICCV 2021论文和代码开源的论文合集

后台回复:Transformer综述,即可下载最新的3篇Transformer综述PDF
CVer-PyTorch交流群成立
扫码添加CVer助手,可申请加入CVer-PyTorch 微信交流群,方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch和TensorFlow等群。
一定要备注:研究方向+地点+学校/公司+昵称(如PyTorch+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群

▲长按加小助手微信,进交流群
▲点击上方卡片,关注CVer公众号

整理不易,请点赞和在看66437797c32d491a36dc06115735fe4a.gif

  • 17
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
根据你提供的代码,报错信息显示 index 的形状是 [112, 1],而 x 的形状是 [16, 7],因此 index 与 x 的形状不匹配,导致了这个错误。 为了解决这个问题,你可以尝试将 LDAMLoss 类的 forward 函数的 index 的形状改为 [x.size(0), x.size(1)],这样就能够与 x 的形状匹配。同时,由于 target.data 已经被弃用,建议改为 target。修改后的代码如下: ``` class LDAMLoss(nn.Module): def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30): super(LDAMLoss, self).__init__() m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list)) m_list = m_list * (max_m / np.max(m_list)) m_list = torch.cuda.FloatTensor(m_list) self.m_list = m_list assert s > 0 self.s = s self.weight = weight def forward(self, x, target): index = torch.zeros_like(x, dtype=torch.uint8) index.scatter_(1, target.view(-1, 1).type(torch.int64), 1) index_float = index.type(torch.cuda.FloatTensor) batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0,1)) batch_m = batch_m.view((-1, 1)) x_m = x - batch_m output = torch.where(index, x_m, x) return F.cross_entropy(self.s*output, target, weight=self.weight) ``` 另外,你的 init 方法也存在一些错误。super 函数应该在 init 方法的第一行调用,而不是在函数体内。因此,init 方法应该改为: ``` def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30): super(LDAMLoss, self).__init__() m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list)) m_list = m_list * (max_m / np.max(m_list)) m_list = torch.cuda.FloatTensor(m_list) self.m_list = m_list assert s > 0 self.s = s self.weight = weight ```
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值