首先通过以下代码,对问题进行定位
with torch.autograd.detect_anomaly():
loss.backward()
然后,发现问题出在损失函数上面了:
RuntimeError: Function 'CtcLossBackward' returned nan values in its 0th output.
检查CTC Loss的参数设置,由于我没有修改原始config中图像初始长宽,而我自己使用的数据集label都比较长,这导致CTC Loss中length比preds_size要更长了。
我对应修改了config中图像的宽度,这样preds_size也会对应变大,问题解决。
类似的问题其实在工程对应gitlab上也有人提出:
在找解决方法的过程中,也尝试了很多其他方法,一并总结在这里。
1. 调小学习速率。
把lr调整至0发现问题依然存在,排除lr的问题。
2. 梯度裁剪
torch.nn.utils.clip_grad_norm(model.parameters(),1.0)
3. 另外,之前也有遇到过类似的情况。当时是输出logit中包含0,经过log(0)后,出现nan。可以对logits加上一个小数1e-6,解决问题