训练深度学习网络的过程中出现 loss nan总是让人觉得头疼,本人这次是在pytorch的半精度amp.autocast, amp.GradScaler训练时候出现了loss nan。
loss nan 常见的一般几种情况有:
lr学习率设置的过大,导致loss nan,调小lr;
训练数据中有脏数据,也会导致loss,这个需要自己排查,batch设置为1,遍历所有训练数据遇到nan就打印数据路径再退出;
if np.isnan(loss):
sys.exit()
网络计算过程中可能存在nan,但这种可能比较少见。等等;
计算loss时候出现nan,特别是众多交叉熵损失中,核心原因应该是log(0)导致的。
笔者是在pytorch的半精度amp.autocast, amp.GradScale训练时候出现了loss nan,而且Lr设置合理,且没有脏数据,想到应该是半精度把一些很小的数表示为0了,计算loss时候把输出fp16—> fp32,问题解决。
out = out.float()
下面用交叉熵损失验证了一下fp16,fp64的结果
import numpy as np
import os
out = np.array([0.00000001]).astype(np.float16)
lab = np.array([0]).astype(np.float16)
loss = lab * np.log(out) - (1-lab) * np.log(1-out)
print(loss)
# [nan]
out = np.array([0.00000001]).astype(np.float64)
lab = np.array([0]).astype(np.float64)
loss = lab * np.log(out) - (1-lab) * np.log(1-out)
print(loss)
#[1.00000001e-08]
后面的思考,因为已经用了amp.autocast, amp.GradScale为啥还会在训练到一半的时候出现这个问题呢
scaler = GradScaler()
with autocast():
out = model(inputs)
out = out.float()
loss = criterion(out, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
应该是再回传中如果计算loss 时,出现了除以0的情况以及loss过大,被半精度判断为inf这种情况会被捕捉到,但是因为半精度的原因导致网络的输出变为nan,这时scaler.scale(loss).backward()没法捕获,因为回传的梯度并不是nan,这时候scaler.step(optimizer)也没法处理,因为已经是nan再加一个极小的eps仍然是nan,所以直接在loss计算前out = out.float()。
后续更新:
训练后又出现loss nan了,查了一下发现计算loss前,网络的输出层已经全部是nan了,输出层fp32也不好使了, 直接去掉混合精度训练一了百了,宁愿慢一点,也不愿loss nan。