主要检查两方面:
(1)数据处理过程中是否使input或者label中出现Nan值
if len(np.unique(np.isnan(input))) > 1:
print(name, np.unique(np.isnan(input)))
(2)自定义loss中存在除数为0或者开根号的数为0等情况,一般加一个极小数就可解决
使用自定义loss时,训练几代后就出现问题,loss计算中用了torch.sqrt()来开根号,在网络训练初期是没有什么问题的loss也都正常下降,但是训练到一半会出现NAN。
loss本身计算时不出现Nan,但是网络输出为Nan和-inf值
参考博客,初步认为是开根号中值为0的情况,修改代码:
e = 1e-6
torch.sqrt(a + e)
改完后,仍然存在Nan的问题
进一步检查后,发现训练数据的label存在Nan值
如上图,就是一个 label 例子,我对上图 label 用 reScale() 函数转换后得到的新 label 出现NaN值
def reScale(data):
return 0.1 + (data / (np.max(data)-np.min(data)))*0.9
##改代码为
def reScale(data):
_range = np.max(data) - np.min(data) + 1e-5
return 0.1 + ((data - np.min(data)) / _range)*0.9
修改 reScale 函数(主要是要确保 _range 不为零后问题解决。
发现我碰到的出现NaN值得情况基本上就是除0了。