Errors:
- RuntimeError: Function 'MulBackward0' returned nan values in its 0th output.
- RuntimeError: Function 'ExpBackward0' returned nan values in its 0th output.
- and other NaN errors
Take the following code as an example:
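Errors of the form "Function 'XxxBackward0' returned nan values in its 0th output" are raised by PyTorch's autograd anomaly detection (`torch.autograd.detect_anomaly()` or `torch.autograd.set_detect_anomaly(True)`). When it is enabled, it checks the output of every backward function for NaN. The minimal sketch below reproduces one such error. The function name `trigger_nan_backward` and the toy tensor with a negative entry are illustrative assumptions, not part of the original code:

```python
import torch

def trigger_nan_backward() -> str:
    """Run a backward pass that produces NaN while anomaly detection is on."""
    x = torch.tensor([-1.0, 4.0], requires_grad=True)
    # detect_anomaly makes autograd check each backward function's output for NaN
    with torch.autograd.detect_anomaly():
        y = x ** 0.25  # (-1.0) ** 0.25 is NaN for real-valued tensors
        try:
            y.sum().backward()
        except RuntimeError as err:
            return str(err)
    return "no error"

print(trigger_nan_backward())
```

In this toy case, the error should name the power op's backward (`PowBackward0`). In real training code, it names whichever backward function first emits NaN, such as `MulBackward0` or `ExpBackward0`.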
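```python
import torch.nn as nn

out = model(input)
pred = out ** (1/4)  # NaN wherever out < 0; the gradient is infinite wherever out == 0
loss = 2*nn.MSELoss()(pred, y) + 0.1*nn.L1Loss()(pred, y)
loss.backward()
optimizer.step()
```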
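This example is prone to NaN for two reasons. For real-valued tensors, `out ** (1/4)` is NaN wherever `out` is negative. Its derivative, `0.25 * out ** (-3/4)`, is infinite wherever `out` is exactly 0. The toy check below (not the author's code) shows both cases. It also shows one common guard: clamping to a small positive floor before taking the root. The helper name `quarter_root` and the value `eps=1e-6` are arbitrary assumptions:

```python
import torch

def quarter_root(values, eps=None):
    """Return (pred, grad) for pred = x ** (1/4), optionally clamping x to [eps, inf) first."""
    x = torch.tensor(values, requires_grad=True)
    base = x if eps is None else x.clamp(min=eps)
    pred = base ** (1 / 4)
    pred.sum().backward()
    return pred.detach(), x.grad

# Unguarded: NaN output and gradient at -1.0, infinite gradient at 0.0
pred, grad = quarter_root([-1.0, 0.0, 16.0])
print(pred, grad)

# Guarded: clamp passes a zero gradient to the clamped entries, so nothing is NaN or inf
pred_safe, grad_safe = quarter_root([-1.0, 0.0, 16.0], eps=1e-6)
print(pred_safe, grad_safe)
```

Clamping only hides the symptom for values below the floor. Whether that is acceptable depends on whether negative model outputs are meaningful for the task.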
How do you locate the bug?
Add the following code right after loss.backward():
```python
import torch

loss.backward()
for name, param in model.named_parameters():
    # Report the first parameter whose gradient contains NaN, then stop
    if param.grad is not None and torch.isnan(param.grad).any():
        print("nan gradient found")
        print("name:", name)
        print("param:", param.grad)
        raise SystemExit
```
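This finds the first parameter whose gradient became NaN during backpropagation. It prints that parameter's name, which tells you the module it belongs to, and its gradient, then exits.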
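The same check can be wrapped in a small helper that returns every offending parameter name, rather than exiting on the first one. This is a minimal sketch. `nan_grad_params` is a hypothetical name, and the one-layer `nn.Linear` model with a forced negative output exists only to trigger NaN on purpose:

```python
import torch
import torch.nn as nn

def nan_grad_params(model: nn.Module) -> list[str]:
    """Names of parameters whose .grad contains at least one NaN (hypothetical helper)."""
    return [
        name
        for name, param in model.named_parameters()
        if param.grad is not None and torch.isnan(param.grad).any()
    ]

# Toy setup: force a negative output so that out ** (1/4) becomes NaN
model = nn.Linear(1, 1)
with torch.no_grad():
    model.weight.fill_(-1.0)
    model.bias.fill_(0.0)

out = model(torch.ones(1, 1))  # out = -1
pred = out ** (1 / 4)          # NaN
loss = nn.MSELoss()(pred, torch.ones(1, 1))
loss.backward()

print(nan_grad_params(model))
```

You can call this after `loss.backward()` and before `optimizer.step()`. A training loop could then log the offending names and skip that update instead of terminating.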
1. Find the loss terms that feed into that parameter, and set each of them to zero in turn:
out = model(input)
pred = out ** (1/4)
loss =