调试模式下运行代码,并按以下代码设置torch配置,
torch.autograd.set_detect_anomaly(True)
with torch.autograd.detect_anomaly():
loss = loss_func()
发生梯度爆炸时,torch会显示存在梯度爆炸的代码,如
调试模式下运行代码,并按以下代码设置torch配置,
torch.autograd.set_detect_anomaly(True)
with torch.autograd.detect_anomaly():
loss = loss_func()
发生梯度爆炸时,torch会显示存在梯度爆炸的代码,如