import torch
# 正向传播
torch.autograd.set_detect_anomaly(True)
# TODO
# 反向传播
with torch.autograd.detect_anomaly():
loss.backward()
运行后会输出nan值得代码位置及原因
import torch
# 正向传播
torch.autograd.set_detect_anomaly(True)
# TODO
# 反向传播
with torch.autograd.detect_anomaly():
loss.backward()
运行后会输出nan值得代码位置及原因