gradient 为nan可能原因:
- 梯度爆炸
- 学习率太大
- 数据本身有问题
- backward时,某些方法造成0在分母上, 如:使用方法sqrt()
定位造成nan的代码:
import torch
# 异常检测开启
torch.autograd.set_detect_anomaly(True)
# 反向传播时检测是否有异常值,定位code
with torch.autograd.detect_anomaly():
loss.backward()
gradient 为nan可能原因:
定位造成nan的代码:
import torch
# 异常检测开启
torch.autograd.set_detect_anomaly(True)
# 反向传播时检测是否有异常值,定位code
with torch.autograd.detect_anomaly():
loss.backward()