定位造成梯度为nan的代码:
import torch
# 异常检测开启
torch.autograd.set_detect_anomaly(True)
# 反向传播时检测是否有异常值,定位code
with torch.autograd.detect_anomaly():
loss.backward()
仅此来记录遇到的问题,快速有力的解决办法,来自https://blog.csdn.net/sini2018/article/details/112088749
import torch
# 异常检测开启
torch.autograd.set_detect_anomaly(True)
# 反向传播时检测是否有异常值,定位code
with torch.autograd.detect_anomaly():
loss.backward()
仅此来记录遇到的问题,快速有力的解决办法,来自https://blog.csdn.net/sini2018/article/details/112088749