1. 问题描述
代码如下:
norm_ff = ff / (ff**2).sum(0, keepdim=True).sqrt()
coef_mat = torch.mm(norm_ff.t(), norm_ff)
coef_mat.div_(self.tau2)
L_fd = F.cross_entropy(coef_mat, y)
这是因为计算图中存在.sqrt(),这样会导致在第一个iteration之后出现nan,第一次iteration之内,还是可以看到loss不为nan的。
2. 解决方法
2.1 不开方,因为开方的求导会出现在分母上,因此需要避免分母为0!
2.2 norm_ff = ff / ((ff**2).sum(0, keepdim=True) + 1e-8).sqrt()
.给sqrt()项加一个很小的余项。