使用fp16的时候,容易报上面的错误
解决方法:
(1)检查自己的代码实现,数组是否越界
BCELoss之前有没有转到0~1之间
(2)
这个通常是产生了nan导致数组越界,可以通过如下方式定位nan出现的位置:
with torch.autograd.detect_anomaly():
loss.backward()
一般来说是分母为0或者exp的值过大导致的
来自mmdetection的isue下大佬的点评,方便debug
(3)
实现上提高数值稳定性
容易出现0**0的情况那么就在底数上面加一个1e-6就可以解决问题