太长不看版:
检查输入 -> 检查损失计算过程 -> 检查反向传播
详细版:
分析原因
首先分析产生 NaN 的原因:
数据异常:从原始的数据就有问题,整个网络里都有 NaN
损失函数输入异常:数据不符合损失函数的输入要求(格式、范围、精度等)
计算过程有误:0 作为除数、≤0作为对数输入、计算过程中出现对 Inf 值的复杂操作, 计算数值超过当前精度的上限/下限
梯度爆炸:常见于具有 rnn 结构的网络,梯度累计过大就爆炸
逐步排查
一、首先考虑数据异常
确定输入中没有 NaN,可以在输入后增加一个判断:
if numpy.any(numpy.isnan(input_data)):
print('Input data has NaN!')
exit()因为这个判断并不影响实际的运行,所以可以先在输入后加上。
也可以通过 torch 自带故障检测模块来查找,但是定位并不精确:
torch.autograd.set_detect_anomaly(True)二、排查出现 NaN 的位置
通过如下的断言判断可以精确查找在运行过程中 NaN 出现的位置
assert not torch.any(torch.isnan(T))每个模块的输出都可以加以检查,tips: 二分法最快: )
这一步发现问题后,可能是出现了原因分析中的第2、3条:【损失函数输入异常】和【计算过程有误】
从上面两个方向加以排查,应该就能有所收获。
【损失函数输入异常】常出现于自定义的损失函数,需要考虑其中的计算bug、越界、数值稳定等问题。
【计算过程有误】常见于与 0, Inf 相关的操作,以及 log 与非正数(太小也可能数值不稳定)的组合。此外,还需要考虑计算过程中数值精度是否够用,避免上溢出或下溢出成为 NaN
三、排查反向传播
可以使用如下语句检查反向传播:
with torch.autograd.detect_anomaly():
loss.backward()反向传播中产生 NaN 多数是由于 rnn 等结构产生的梯度爆炸。
对 loss 加入梯度截断(gradient clipping)或者降低网络的学习率(learning rate)都是解决的办法
这里还有个给萌新的福利,【反向传播前】或者【梯度更新后】要记得把梯度归零,即
xxx.optimizer.zero_grad()不然算着算着就会爆炸: )
排查深度学习模型中NaN错误的步骤
文章介绍了在深度学习模型中遇到NaN值时的排查方法,包括检查数据异常,确认损失函数输入是否正确,检查计算过程中的错误,以及在反向传播阶段可能出现的梯度爆炸问题。建议通过设置断言、使用torch.autograd检测异常和进行梯度截断来解决NaN问题。
1万+

被折叠的 条评论
为什么被折叠?



