在图像分割任务中,我换了个MS-SSIM+L1损失函数,该损失函数的要求是输入两张图片,我的网络输出的是经过softmax的概率值。
为了将输出复现为图片与标签一同输入到损失中进行计算,我开始的处理是这样的:
#模型输出
outputs = self.model(images)
#复现预测图为灰度图像
pred = torch.argmax(outputs[0], dim=1).unsqueeze(1) #[1,1,H,W]
target = targets.unsqueeze(1) #[1,1,H,W]
loss_dict = dict(loss=self.criterion(pred, target))
losses = sum(loss for loss in loss_dict.values())
# xxxxxx somethings
losses.requires_grad_(True)
losses.backward()
上述train过程中使用了argmax函数,影响了losses反向传播的过程,导致训练没有任何效果(损失值固定不变)
这里的解决代码为:
#模型输出
outputs = self.model(images)
#复现预测图为灰度图像
target = F.one_hot(targets, num_classes=2).permute(0,3,1,2).float()
loss_dict = dict(loss=self.criterion(outputs[0], target))
losses = sum(loss for loss in loss_dict.values())
# xxxxxx somethings
losses.requires_grad_(True)
losses.backward()

本文介绍了在图像分割任务中使用MS-SSIM+L1损失函数的方法,并解决了因使用argmax导致训练失效的问题。通过正确地将网络输出转换为目标形式,使得损失能够有效反向传播。
最低0.47元/天 解锁文章
1086

被折叠的 条评论
为什么被折叠?



