nan排查
最近调试代码时,发现一个loss
全部变为nan
。网上主流的解释大多千篇一律,比如
1.学习率太高。
2.loss函数
3.对于回归问题,可能出现了除0 的计算,加一个很小的余项可能可以解决
4.数据本身,是否存在Nan,可以用numpy.any(numpy.isnan(x))检查一下input和target
...
我这儿问题可以确定,是由于数学运算造成的nan。最后具体定位到,是由于一个池化操作造成的。在pycharm中单步调试,使用如下表达式判断x中是否有nan,因为nan!=nan
结果是True
,而其他任何实数!=实数本身
结果均为False
:
torch.sum(x!=x)
torch.sum(x[:1000] != x[:1000]) # 依次这样二分法查找第一次出现nan的那个位置
最后发现是对负数开方将会产生nan
,如下
p = 3.
x_i = (torch.mean(x_i**p, dim=-1) + 1e-12)**(1/p) # 如果x_i中有负数,那么开方之后,对应位置就会为nan
解释
主流的计算器,编程语言均是不支持负数的开方的,即便是数学上成立,计算器还是会给出虚数解, 对于pytorch来说,就成了nan。比如:
>>>(-27)**(1/3)
(1.5000000000000004+2.598076211353316j)
-27的1/3次是存在实数解的,但是计算结果还是一个虚数解。所以以后要开方的话,一定要对负数特别处理,比如把负号拿出来:
>>>-(1)**(1/3)
-1.0
>>>-(27)**(1/3)
-3.0
这篇博文也提到了python中这个问题,实际上一般的计算器也是这样的。