在训练过程中,在确保数据没有异常的情况。由于自定义loss中出现了除数为0或对数为0的情况,导致无法计算得到数字就会得到NAN,然后loss.backward()就会导致整个网络的权重数值都变成NAN。直接导致网络无法计算。
所以在网络训练过程中需要对NAN进行检测和处理。
NAN检测
如果只是一个简单的标量,直接使用isnan进行判据
torch.isnan(loss)
如果只是一个相对复杂的矢量,则需要使用结合.int().sum()对nan进行计数,判据大于0
torch.isnan(loss).int().sum()
numpy类型的计数
np.isnan(frame_fix).sum()
处理
如果存在NAN就需要处理掉这个数,一般可以把赋值为一个常数,或者剔除掉
if torch.isnan(loss):
loss=Constant
使用np.nanmean等方法,直接得到均值,最大值或最小值
np.nanmean(loss)
np.nanmax(loss)
np.nanmin(loss)
由于没有查到直接的删除nan的API,因此通过np.isnan,np.where, np.delete三个关键词剔除掉nan数据
if np.isnan(frame[:,0]).sum(): #矩阵中存在nan
nan_row = np.isnan(frame[:,0]) #找到对应的行
row = np.where(nan_row==True) #行号
frame_fix = np.delete(frame, row, axis=0) #删除行