在训练网络的时候,发现输出全都是nan,这个时候很大可能是数值不稳定,比如除数太小,不稳定,或者是log里面的参数太小,不稳定,这个时候在可能出现运算不稳定的地方增加一些稳定系数就好了,比如:
1.在分母的位置增加一个稳定数
exp_st_sum = exp_st_with_target.sum(dim=1, keepdim=True) + 1e-6
exp_st_rate = exp_st_with_target / exp_st_sum
2.在log里增加稳定系数
loss_gt = F.nll_loss(torch.log(exp_gt_rate + 1e-6), labels)
这样就能解决出现的问题了