train函数中出现的loss不能反向传播

本文介绍了在图像分割任务中使用MS-SSIM+L1损失函数的方法,并解决了因使用argmax导致训练失效的问题。通过正确地将网络输出转换为目标形式,使得损失能够有效反向传播。

在图像分割任务中,我换了个MS-SSIM+L1损失函数,该损失函数的要求是输入两张图片,我的网络输出的是经过softmax的概率值。

为了将输出复现为图片与标签一同输入到损失中进行计算,我开始的处理是这样的:

#模型输出
outputs = self.model(images)

#复现预测图为灰度图像
pred = torch.argmax(outputs[0], dim=1).unsqueeze(1) #[1,1,H,W]
target = targets.unsqueeze(1) #[1,1,H,W]
loss_dict = dict(loss=self.criterion(pred, target))
losses = sum(loss for loss in loss_dict.values())

# xxxxxx somethings
losses.requires_grad_(True)
losses.backward()

上述train过程中使用了argmax函数,影响了losses反向传播的过程,导致训练没有任何效果(损失值固定不变)


这里的解决代码为:

#模型输出
outputs = self.model(images)

#复现预测图为灰度图像
target = F.one_hot(targets, num_classes=2).permute(0,3,1,2).float()
loss_dict = dict(loss=self.criterion(outputs[0], target))
losses = sum(loss for loss in loss_dict.values())

# xxxxxx somethings
losses.requires_grad_(True)
losses.backward()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Star星屹程序设计

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值