1 出现nan的理论分析
从本质上来说,“出现nan”现象主要是因为数值超出当前数据类型的表示范围,其含义是指 not a number ,常在浮点数运算中出现;
目前知道nan的出现由以下四种来源:
inf-inf | inf/inf | 0*inf | 0/0
2 可能引起nan的原因
2.1 学习率过大,出现梯度爆炸,从而导致loss过大,使得数值溢出;
2.2 在运算过程中,由于出现“除0”运算,导致出现nan
2.3 在使用AdamW优化器时,开启torch.amp
混合精度运算,出现nan
这是因为optim.AdamW
算法存在如下除法运算:
θ
t
←
θ
t
−
γ
m
t
^
/
(
v
t
^
+
ϵ
)
\theta_{t} \leftarrow \theta_{t}-\gamma \widehat{m_{t}} /\left(\sqrt{\widehat{v_{t}}}+\epsilon\right)
θt←θt−γmt
/(vt
+ϵ)
torch
中提供了epsilon参数(默认值eps=1e-08
)来防止除0,但是在半精度下近似成了0,所以可能会导致nan;
3 调试方法
判断loss值是否出现nan:
torch.isnan(loss)
Paddle论文复现教程中快速复现NaN的技巧:
保存出现NaN前的模型权重及输入,保证能够在短时间内复现NaN的问题,然后做模型前向对齐;