loss突然变为nan?你可能踩了sqrt()的坑

这几天在写一个网络模型,需要自定义损失函数,关于gps距离误差计算的,写完之后,噩梦开始了…
训练过程中loss总是莫名其妙的突然变为nan,网上查阅了许多资料,做了各种尝试,比如调整学习率、调整batch大小、调整网络复杂度、梯度裁剪、过滤脏数据、检查是否存在除0、log(0), 加入BatchNormalization层等,无奈还是会出现loss变为nan的问题。
后来分析loss函数本身,发现唯一可能出现问题的地方是下面这行代码里的tf.sqrt()函数:

return K.mean(tf.sqrt(tf.add(tf.square(lx), tf.square(ly))))

于是,又上网查发现tensorflow或者pytorch在loss函数中使用sqrt可能导致loss训练变为nan的问题,原因如下:
sqrt()即x^1/2,在x=0处不可导,前向传播过程中,loss的计算不会出问题,但在反向传播进行梯度计算的时候可能会遇到在0处求导的情况,这也是loss突然变为nan的原因,在sqrt()添加一个极小数之后得到解决:

return K.mean(tf.sqrt(tf.add(tf.add(tf.square(lx), tf.square(ly)), 1e-10)))
pytorch部分代码如下:train_loss, train_acc = train(model_ft, DEVICE, train_loader, optimizer, epoch,model_ema) for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device, non_blocking=True), Variable(target).to(device,non_blocking=True) # 3、将数据输入mixup_fn生成mixup数据 samples, targets = mixup_fn(data, target) # 4、将上一步生成的数据输入model,输出预测结果,再计算loss output = model(samples) # 5、梯度清零(将loss关于weight的数变成0) optimizer.zero_grad() # 6、若使用混合精度 if use_amp: with torch.cuda.amp.autocast(): # 开启混合精度 loss = torch.nan_to_num(criterion_train(output, targets)) # 计算loss scaler.scale(loss).backward() # 梯度放大 torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD) if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks or global_forward_hooks or global_forward_pre_hooks): return forward_call(input, **kwargs) class LDAMLoss(nn.Module): def init(self, cls_num_list, max_m=0.5, weight=None, s=30): super(LDAMLoss, self).init() m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list)) m_list = m_list * (max_m / np.max(m_list)) m_list = torch.cuda.FloatTensor(m_list) self.m_list = m_list assert s > 0 self.s = s self.weight = weight def forward(self, x, target): index = torch.zeros_like(x, dtype=torch.uint8) index.scatter(1, target.data.view(-1, 1).type(torch.int64), 1) index_float = index.type(torch.cuda.FloatTensor) batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0,1)) batch_m = batch_m.view((-1, 1)) x_m = x - batch_m output = torch.where(index, x_m, x) return F.cross_entropy(self.soutput, target, weight=self.weight) 报错:RuntimeError: Expected index [112, 1] to be smaller than self [16, 7] apart from dimension 1 帮我看看如何修改源代码
06-10
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值