Pytorch定位NaN

本文介绍了在PyTorch中如何检测和处理训练过程中出现的NaN损失,包括正向传播异常检测、反向传播异常检测以及使用assert进行断言检查。此外,还讨论了NaN产生的常见原因,如梯度爆炸、非法计算和脏数据,并提供了检查数据、设置异常检测和添加断言的建议步骤。
摘要由CSDN通过智能技术生成

https://blog.csdn.net/mch2869253130/article/details/111034068

https://www.zzsblog.top/coding/2021/08/07/pytorch%E5%AE%9A%E4%BD%8DNaN.html

按照下面的流程来判断。
...

loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)

optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)

# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)
————————————————
版权声明:本文为CSDN博主「风吹草地现牛羊的马」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/mch2869253130/article/details/111034068

三板斧

检查NaN有三板斧, 尽管调试NaN通常需要一定的经验和耐心, 但记住这三个至少不至于手足无措.

#1 正向传播异常侦测

torch.autograd.set_detect_anomaly(True)

如题, forward时出现NaN即时报错. 尽管说得好听, 但有的时候并不能准确地定位问题所在. 属于调试NaN的必要辅助.

#2 反向传播异常侦测

# loss = model(X)
with torch.autograd.detect_anomaly():
    loss.backward()

如题, backward时出现NaN时即时报错. 相比#1来说更难确切定位问题, 往往用于兜底, 即确保出现NaN时程序会尽快抛出异常.

#3 assert

assert是确保程序行为正确的重要手段. 对于一个算法来说, 出现NaN不管怎么说都意味着不正常. 同时, 对debug来说, 最重要的就是找到事发现场, 而assert正是寻找真正现场的利器.

在pytorch中, 检查NaN的函数为torch.isnan(T). 于是我们可以构造如下断言:

assert not torch.any(torch.isnan(T))

当然, 这么写其实有一点性能浪费, 但写python, 又是debug专用代码, 何必考虑这么多呢¯\_(ツ)_/¯

将这个断言加在你认为有可能出现NaN的步骤之后. 这样一旦出现NaN, 你至少能抓住一个现场. 哪怕这个现场已经漂移, 配合调试器你也能更有逻辑地找到真正的事发现场.

NaN的可能原因

讲完三板斧总得讲讲NaN的成因, 要不然就是光有方法没有理论(x 尤其是#3, 要求调试者非常充分且熟练地掌握NaN的可能成因.

梯度爆炸

梯度爆炸, 或者梯度消失都可能导致NaN. 这个问题往往会被#2 反向传播异常检测捕获, 但真正定位到问题却难上加难. 相对来说, 重新推导一遍自己的理论模型、寻找可能导致梯度爆炸的计算显得更有针对性.

计算不合法

这也是NaN最常见的成因. 毕竟大多数的网络, 尤其是复现、组合别人的网络结构一般不会碰到梯度爆炸的问题, 而NaN大多出现于loss计算的部分, 诞生于某个小小的不合法计算, 然后污染它参与计算的所有结果, 最后在你的loss值上表现出来.

常见套路:

  • $ log(x), x \leq 0 $
  • $ c/0 $

尚有其他的一些情况我自己没遇到过, 网上可能会有补充

这种问题运气好的话会被#1 正向异常检测直接找到, 但通常是找到一个漂移了亿点点的位置. 推荐用#3 assert的办法, 尤其是 自己写了loss时, 在关键位置放几个assert守门, 总归是没错的.

注意, 绝大多数时候, inf也是不合常理的存在. 因此你可能也需要同时寻找inf:

assert not torch.any(torch.isnan(T) + torch.isinf(T))

脏数据

NaN的次常见成因. 顾名思义, 出现NaN仅仅是因为数据里含有NaN. 通常来说直接读图片不会出现NaN, 往往是大意地处理数据后会出现这种情况.

随便举个例子.

mask = mask / mask.max()
# serialize mask

这句话看起来没问题, 把uint8{0, 255}转成float32[0, 1]. 相信很多人都这么写过. 正常来说不会有任何问题, 直到我遇到了一张纯黑的mask :P

毕竟谁也不会想到有一张图没标注还给放数据集里了是吧. 但不管怎么说, 此时我们犯了”除零”的错误. 这个mask会变成携带NaN的脏数据输入模型, 并在计算loss时将loss结果污染. 如果程序没有及时终止, 在仅仅一次反向传播之后, 你的模型参数将变为NaN, 其一切推导将得出NaN ¯\_(ツ)_/¯

检查NaN的一般步骤

  1. 检查数据
  2. 开启正向和反向异常检测
  3. 给模型的直接输出结果和最终loss加assert
  4. 通过经验、猜测、反推等方法逐步把assert加到之前的步骤, 直到触发的assert帮你找到了不合法计算
  5. 若计算loss的过程中没有发现问题, 且总是触发反向传播异常, 那可以考虑从理论上检查梯度爆炸和梯度消失
  • 13
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值