出现NaN的原因
在使用DeepSEED训练自己的数据集时出现了loss为NaN的原因
一般loss为nan有以下几个原因:
- 梯度爆炸
- 出现除零、对数函数自变量为负值等数学问题
- 出现脏数据
我的是第二种情况 源代码使用的是focalloss的损失
class FocalLoss(nn.Module):
    """Detection loss: binary focal classification + SmoothL1 box regression.

    Each anchor row is ``[conf, z, h, w, d]``.  ``conf > 0.5`` marks a
    positive anchor, ``conf < -0.5`` a negative anchor; anything in
    between is ignored.  Negative labels are stored as ``-1`` and shifted
    to ``0`` before being passed to the classification loss.
    """

    # Keep probabilities strictly inside (0, 1): the focal loss takes
    # log(p) / log(1 - p), and a saturated sigmoid (p == 0 or 1, e.g. when
    # dropout pushes logits to extremes) would otherwise yield -inf -> NaN.
    _EPS = 1e-6

    def __init__(self, num_hard=0):
        """num_hard: hard negatives kept per batch sample (0 disables mining)."""
        super(FocalLoss, self).__init__()
        self.sigmoid = nn.Sigmoid()
        # BinaryFocalLoss / hard_mining are project-local helpers.
        self.classify_loss = BinaryFocalLoss(gamma=2, alpha=0.5, size_average=False)
        self.regress_loss = nn.SmoothL1Loss()
        self.num_hard = num_hard

    def forward(self, output, labels, train=True):
        """Compute the combined loss.

        Args:
            output: predictions, reshaped to (-1, 5) anchor rows.
            labels: targets of matching shape; column 0 is the confidence.
            train:  enables hard-negative mining when ``num_hard > 0``.

        Returns:
            ``[loss, classify_loss, rz, rh, rw, rd,
               pos_correct, pos_total, neg_correct, neg_total]``
        """
        batch_size = labels.size(0)
        output = output.view(-1, 5)
        labels = labels.view(-1, 5)

        # Positive anchors: confidence label > 0.5.
        pos_idcs = labels[:, 0] > 0.5
        pos_idcs = pos_idcs.unsqueeze(1).expand(pos_idcs.size(0), 5)
        pos_output = output[pos_idcs].view(-1, 5)
        pos_labels = labels[pos_idcs].view(-1, 5)

        # Negative anchors: confidence label < -0.5 (stored as -1).
        neg_idcs = labels[:, 0] < -0.5
        neg_output = output[:, 0][neg_idcs]
        neg_labels = labels[:, 0][neg_idcs]

        if self.num_hard > 0 and train:
            neg_output, neg_labels = hard_mining(
                neg_output, neg_labels, self.num_hard * batch_size)
        # Clamp away from {0, 1} so log terms in the focal loss stay finite.
        neg_prob = self.sigmoid(neg_output).clamp(self._EPS, 1.0 - self._EPS)

        if len(pos_output) > 0:
            pos_prob = self.sigmoid(pos_output[:, 0]).clamp(self._EPS, 1.0 - self._EPS)
            pz, ph, pw, pd = (pos_output[:, 1], pos_output[:, 2],
                              pos_output[:, 3], pos_output[:, 4])
            lz, lh, lw, ld = (pos_labels[:, 1], pos_labels[:, 2],
                              pos_labels[:, 3], pos_labels[:, 4])

            regress_losses = [
                self.regress_loss(pz, lz),
                self.regress_loss(ph, lh),
                self.regress_loss(pw, lw),
                self.regress_loss(pd, ld)]
            regress_losses_data = [l.item() for l in regress_losses]
            classify_loss = self.classify_loss.forward(
                pos_prob, pos_labels[:, 0]) + self.classify_loss.forward(
                neg_prob, neg_labels + 1)
            # max(..., 1) guards against a zero denominator (no anchors).
            classify_loss = classify_loss / max(len(pos_prob) + len(neg_prob), 1)
            pos_correct = (pos_prob.data >= 0.5).sum()
            pos_total = len(pos_prob)
        else:
            # No positive anchors in this batch: classification on negatives only.
            regress_losses = [0, 0, 0, 0]
            classify_loss = self.classify_loss.forward(
                neg_prob, neg_labels + 1)
            # Guard the original unconditional "/ len(neg_prob)", which was
            # NaN/inf when both pos and neg sets were empty.
            classify_loss = classify_loss / max(len(neg_prob), 1)
            pos_correct = 0
            pos_total = 0
            regress_losses_data = [0, 0, 0, 0]
        classify_loss_data = classify_loss.item()

        loss = classify_loss
        for regress_loss in regress_losses:
            loss += regress_loss

        neg_correct = (neg_prob.data < 0.5).sum()
        neg_total = len(neg_prob)

        return ([loss, classify_loss_data] + regress_losses_data
                + [pos_correct, pos_total, neg_correct, neg_total])
在 pos_idcs = labels[:, 0] > 0.5 这一步出现了问题,没有任何 anchor 满足该条件,导致
pos_output = output[pos_idcs].view(-1, 5) 中 pos 的长度为 0
debug到网络结构中发现在最后的输出阶段加入了dropout,注释掉之后就不会出现这种情况了。
具体为什么加入了dropout会出现loss为NaN的情况,我还没有得出结论。