神经网络得到 NaN 的排查流程

排查深度学习模型中NaN错误的步骤
文章介绍了在深度学习模型中遇到NaN值时的排查方法,包括检查数据异常,确认损失函数输入是否正确,检查计算过程中的错误,以及在反向传播阶段可能出现的梯度爆炸问题。建议通过设置断言、使用torch.autograd检测异常和进行梯度截断来解决NaN问题。

太长不看版:

检查输入 -> 检查损失计算过程 -> 检查反向传播

详细版:

分析原因

首先分析产生 NaN 的原因:

  1. 数据异常:从原始的数据就有问题,整个网络里都有 NaN

  1. 损失函数输入异常:数据不符合损失函数的输入要求(格式、范围、精度等)

  1. 计算过程有误:0 作为除数、≤0作为对数输入、计算过程中出现对 Inf 值的复杂操作, 计算数值超过当前精度的上限/下限

  1. 梯度爆炸:常见于具有 rnn 结构的网络,梯度累计过大就爆炸

逐步排查

一、首先考虑数据异常

确定输入中没有 NaN,可以在输入后增加一个判断:

if numpy.any(numpy.isnan(input_data)):
    print('Input data has NaN!')
    exit()

因为这个判断并不影响实际的运行,所以可以先在输入后加上。

也可以通过 torch 自带故障检测模块来查找,但是定位并不精确:

torch.autograd.set_detect_anomaly(True)

二、排查出现 NaN 的位置

通过如下的断言判断可以精确查找在运行过程中 NaN 出现的位置

assert not torch.any(torch.isnan(T))

每个模块的输出都可以加以检查,tips: 二分法最快: )

这一步发现问题后,可能是出现了原因分析中的第2、3条:【损失函数输入异常】和【计算过程有误】

从上面两个方向加以排查,应该就能有所收获。

【损失函数输入异常】常出现于自定义的损失函数,需要考虑其中的计算bug、越界、数值稳定等问题。

【计算过程有误】常见于与 0, Inf 相关的操作,以及 log 与非正数(太小也可能数值不稳定)的组合。此外,还需要考虑计算过程中数值精度是否够用,避免上溢出或下溢出成为 NaN

三、排查反向传播

可以使用如下语句检查反向传播:

with torch.autograd.detect_anomaly():
    loss.backward()

反向传播中产生 NaN 多数是由于 rnn 等结构产生的梯度爆炸。

对 loss 加入梯度截断(gradient clipping)或者降低网络的学习率(learning rate)都是解决的办法

这里还有个给萌新的福利,【反向传播前】或者【梯度更新后】要记得把梯度归零,即

xxx.optimizer.zero_grad()

不然算着算着就会爆炸: )

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值