该怎么办?
最近一直在挣扎这个问题,5月份写的网络就是复现不了了,如今总算解决
首先nan在计算机中的含义是非数,即未定义或不可表示的数。而loss中出现NAN有多种原因:
- 输入数据和输出数据存在脏数据,可用下面程序检查
if torch.any(torch.isnan(output)):
break
-
自己设计的损失函数可能存在问题,检查能否正常反向传播,并对输入的数据保持同一个类型
-
学习率太大,减小学习率
-
设置梯度截断
nn.utils.clip_grad_value_(model.parameters(), clip_value=5)
-
如果使用sqrt(),log(),除0操作要特别注意
-
直接定位出现NAN的位置
torch.autograd.set_detect_anomaly(True)
with torch.autograd.detect_anomaly():
loss.backward()
RuntimeError: Function ‘MulBackward0’ returned nan values in its 0th output.