按照如下设置异常侦测,在出现NaN异常时程序会报错,可直接定位错误代码:
import torch
# 正向传播时:开启自动求导的异常侦测
torch.autograd.set_detect_anomaly(True)
# 反向传播时:在求导时开启侦测
with torch.autograd.detect_anomaly():
loss.backward()
参考资料:pytorch_梯度出现NaN
按照如下设置异常侦测,在出现NaN异常时程序会报错,可直接定位错误代码:
import torch
# 正向传播时:开启自动求导的异常侦测
torch.autograd.set_detect_anomaly(True)
# 反向传播时:在求导时开启侦测
with torch.autograd.detect_anomaly():
loss.backward()
参考资料:pytorch_梯度出现NaN