1. 加自动监测函数找异常值
import torch
# 正向传播时:开启自动求导的异常侦测
torch.autograd.set_detect_anomaly(True)
# 反向传播时:在求导时开启侦测
with torch.autograd.detect_anomaly():
loss.backward()
2. 自定义的loss注意除法时分母不是0
注意分母加eps保证计算稳定性
3. log(0), sqrt(0) 会导致nan,0 * inf也会变成nan
判断nan不能用 == 或 is, 要用numpy.isnan() / torch.isnan()
4. 脏数据
有可能只有某些input会导致nan,一开始就先shuffle=False保证每次读取数据的顺序是一样的,方便定位
5. fp16
曾经gradient & loss上来就是nan,求助万票,一条评论问我有没有用fp16, 当时还不理解这跟半精度有什么关系。后来发现bug还真的跟半精度有关。args.fp16没有控制到一层LayerNorm的初始化,导致我明明args.fp16=False,仍然用了半精度的FusedLayerNorm,然后当时我的batch_size又很小,所以假如batch内的divergence不够的话,有可能两个差异很小的数半精度了之后差异就变成0了,那如果某个地方需要divide by variance就直接nan了。。。