主要检查两方面:
(1)数据处理过程中是否使input或者label中出现Nan值
if len(np.unique(np.isnan(input))) > 1:
print(name, np.unique(np.isnan(input)))
(2)自定义loss中存在除数为0或者开根号的数为0等情况,一般加一个极小数就可解决
使用自定义loss时,训练几代后就出现问题,loss计算中用了torch.sqrt()来开根号,在网络训练初期是没有什么问题的loss也都正常下降,但是训练到一半会出现NAN。
loss本身计算时不出现Nan,但是网络输出为Nan和-inf值
参考博客,初步认为是开根号中值为0的情况,修改代码:
e = 1e-6
torch.sqrt(a + e)
改完后,仍然存在Nan的问题
进一步检查后,发现训练数据的label存在Nan值