在训练过程中,在确保数据没有异常的情况。由于自定义loss中出现了除数为0或对数为0的情况,导致无法计算得到数字就会得到NAN,然后loss.backward()就会导致整个网络的权重数值都变成NAN。直接导致网络无法计算。
所以在网络训练过程中需要对NAN进行检测和处理。
NAN检测
如果只是一个简单的标量,直接使用isnan进行判据
torch.isnan(loss)
如果只是一个相对复杂的矢量,则需要使用结合.int().sum()对nan进行计数,判据大于0
torch.isnan(loss).int().sum()