Choosing a Loss Function for Image Segmentation Projects

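Preface

In image segmentation, the most basic and most common loss is cross-entropy (CE) loss. Ongoing research has produced many losses that outperform cross-entropy in specific settings, and in real projects cross-entropy is rarely used on its own any more.

Scenario: real projects very often run into one common problem, class imbalance.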

1. Focal loss

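Focal loss tackles the training problems caused by class imbalance from the angle of how easy or hard each sample is to classify.
  Typically, the trouble with imbalanced data is that the under-represented classes are hard to tell apart (some samples are also intrinsically hard to classify or segment). Focal loss therefore concentrates on the hard samples: it lets them dominate the gradient, so that learning focuses on them.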
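To make the "hard samples dominate" idea concrete, here is a minimal sketch in plain Python. It assumes the standard per-sample form of focal loss, -alpha * (1 - p_t)^gamma * log(p_t), with gamma = 2 and alpha = 1; the function names and the two probabilities are illustrative only.

```python
import math

def ce_term(p_t: float) -> float:
    """Cross-entropy for one pixel whose true class has predicted probability p_t."""
    return -math.log(p_t)

def focal_term(p_t: float, gamma: float = 2.0, alpha: float = 1.0) -> float:
    """Focal loss for the same pixel: CE scaled by the modulating factor (1 - p_t)**gamma."""
    return alpha * (1.0 - p_t) ** gamma * ce_term(p_t)

easy, hard = 0.9, 0.1  # illustrative probabilities for an easy and a hard pixel

# How much more a hard pixel contributes than an easy one, under each loss.
ratio_ce = ce_term(hard) / ce_term(easy)
ratio_focal = focal_term(hard) / focal_term(easy)
print(f"CE ratio: {ratio_ce:.1f}, focal ratio: {ratio_focal:.1f}")
```

With these numbers the hard pixel weighs roughly 22 times more than the easy one under CE, and roughly 1770 times more under focal loss, because the modulating factor adds another factor of 0.81 / 0.01 = 81.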

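Discussion

   During training, focal loss is in effect a dynamic selection of samples, and that selection is not stable. This is why focal loss sometimes does worse than plain CE loss. As a rule, to keep samples from flipping frequently between easy and hard, choose a small learning rate.

The code is as follows (example):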

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """
    copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
    This is an implementation of Focal Loss with smooth label cross entropy supported which is proposed in
    'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
        Focal_Loss = -1 * alpha * (1 - pt)**gamma * log(pt)
    :param apply_nonlin: (callable, optional) extra non-linearity applied to the network output
    :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
    :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
                    focus on hard misclassified example
    :param smooth: (float,double) smooth value when cross entropy
    :param balance_index: (int) balance class index, should be specific when alpha is float
    :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
    """

    def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-1, size_average=True):
        super(FocalLoss, self).__init__()
        self.apply_nonlin = apply_nonlin
        self.alpha = alpha
        self.gamma = gamma
        self.balance_index = balance_index
        self.smooth = smooth
        self.size_average = size_average

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

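    def forward(self, logit, target):
        # Recompute the per-class weights from the current batch. `enet_weighing` is an
        # external helper that this post does not define; it must be supplied separately,
        # and it overrides any alpha passed to __init__.
        N = logit.shape[1]
        self.alpha = enet_weighing(target, N).to(logit.device)  # follow the input's device instead of forcing .cuda()

        # Convert raw logits to class probabilities. Leave apply_nonlin as None here,
        # otherwise a second normalisation is applied on top of the softmax.
        logit = F.softmax(logit, dim=1)
        if self.apply_nonlin is not None:
            logit = self.apply_nonlin(logit)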
        num_class = logit.shape[1]
        if logit.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        target = torch.squeeze(target, 1)
        target = target.view(-1, 1)
        # print(logit.shape, target.shape)
        #
        alpha = self.alpha

        if alpha is None:
            alpha = torch.ones(num_class, 1)
        elif isinstance(alpha, (list, np.ndarray)):
            assert len(alpha) == num_class
            alpha = torch.FloatTensor(alpha).view(num_class, 1)
            alpha = alpha / alpha.sum()
        elif isinstance(alpha, float):
            alpha = torch.ones(num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[self.balance_index] = self.alpha

        # else:
        #     raise TypeError('Not support alpha type')

        if alpha.device != logit.device:
            alpha = alpha.to(logit.device)

        idx = target.cpu().long()

        one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        if one_hot_key.device != logit.device:
            one_hot_key = one_hot_key.to(logit.device)

        if self.smooth:
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + self.smooth
        logpt = pt.log()

        gamma = self.gamma

        alpha = alpha[idx]
        alpha = torch.squeeze(alpha)
        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt

        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss

# Training loop
focal = FocalLoss()
FocalLoss1 = focal(out, label)  # out: model output (logits), label: ground-truth mask
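The forward pass above calls `enet_weighing`, which the post does not define. The sketch below is a hypothetical stand-in, not the author's helper: it assumes the ENet-style class weighting w_k = 1 / ln(c + p_k), with c = 1.02 and p_k the fraction of pixels of class k in the current batch. The signature `(target, num_classes)` is inferred from the call site only.

```python
import torch

def enet_weighing(target: torch.Tensor, num_classes: int, c: float = 1.02) -> torch.Tensor:
    """Hypothetical per-batch ENet-style class weights: w_k = 1 / ln(c + p_k)."""
    # Count how many pixels of each class appear in the batch of label maps.
    counts = torch.bincount(target.reshape(-1).long(), minlength=num_classes).float()
    freq = counts / counts.sum()       # p_k: share of pixels belonging to class k
    return 1.0 / torch.log(c + freq)   # rarer classes receive larger weights

# One 2x3 mask with five background pixels and one foreground pixel.
label = torch.tensor([[[0, 0, 0], [0, 0, 1]]])
weights = enet_weighing(label, num_classes=2)
print(weights)  # the foreground weight is the larger of the two
```

Because the weights are recomputed for every batch, they vary from step to step; computing them once over the whole training set is a common alternative.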

2. Dice loss

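Dice loss suits extremely imbalanced data. Used on its own, however, it tends to hurt backpropagation and makes training unstable. (Note: Dice loss is particularly unfavourable for small objects. With only foreground and background, a few mispredicted pixels on a small object swing the Dice score sharply, which produces large gradient changes and unstable training.) For this reason, Dice loss is usually trained as an auxiliary loss alongside a main loss, for example Dice loss + CE loss or Dice loss + focal loss.

The code is as follows (example):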
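The small-object sensitivity can be checked with simple arithmetic. The sketch below uses the hard-mask definition Dice = 2|A∩B| / (|A| + |B|); the object sizes (4 and 400 pixels) and the two missed pixels are made-up numbers for illustration.

```python
def dice(intersection: int, pred_size: int, target_size: int) -> float:
    """Dice coefficient for hard masks: 2 * |A ∩ B| / (|A| + |B|)."""
    return 2 * intersection / (pred_size + target_size)

# The prediction misses exactly two foreground pixels in both cases.
small = dice(intersection=2, pred_size=2, target_size=4)          # 4-pixel object
large = dice(intersection=398, pred_size=398, target_size=400)    # 400-pixel object
print(f"small object: {small:.3f}, large object: {large:.3f}")
```

The same two-pixel error drops the Dice score of the small object to about 0.667 while the large object stays near 0.997, which is the instability described above.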

import torch
from torch import Tensor
import torch.nn.functional as F

def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    # Average of Dice coefficient for all batches, or for a single mask
    assert input.size() == target.size()
    if input.dim() == 2 and reduce_batch_first:
        raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')

    if input.dim() == 2 or reduce_batch_first:
        inter = torch.dot(input.reshape(-1), target.reshape(-1))
        sets_sum = torch.sum(input) + torch.sum(target)
        if sets_sum.item() == 0:
            sets_sum = 2 * inter

        return (2 * inter + epsilon) / (sets_sum + epsilon)
    else:
        # compute and average metric for each batch element
        dice = 0
        for i in range(input.shape[0]):
            dice += dice_coeff(input[i, ...], target[i, ...], epsilon=epsilon)  # pass epsilon through to the per-sample call
        return dice / input.shape[0]


def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    # Average of Dice coefficient for all classes
    assert input.size() == target.size()
    dice = 0
    for channel in range(input.shape[1]):
        dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)

    return dice / input.shape[1]


def dice_loss(input: Tensor, target: Tensor, multiclass: bool = True):
    # Dice loss (objective to minimize) between 0 and 1
    assert input.size() == target.size()
    fn = multiclass_dice_coeff if multiclass else dice_coeff
    return 1 - fn(input, target, reduce_batch_first=True)

# Training loop (out: model logits, lb: integer label map, n_classes: number of classes)
lossp = dice_loss(F.softmax(out, dim=1).float(),
                 F.one_hot(lb, n_classes).permute(0, 3,1,2).contiguous().float(),  multiclass=True)
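Since Dice loss is normally paired with a main loss, a combined CE + Dice objective might look like the sketch below. It is self-contained, so it uses a compact soft Dice (`soft_dice_loss`) as a stand-in for the `dice_loss` function above; `ce_dice_loss` and the `dice_weight` parameter are hypothetical names, and the weighting of 1.0 is an assumption rather than a recommendation from the post.

```python
import torch
import torch.nn.functional as F

def soft_dice_loss(probs: torch.Tensor, target_1h: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # probs, target_1h: (N, C, H, W). Dice is computed per class over the whole batch.
    dims = (0, 2, 3)
    inter = (probs * target_1h).sum(dims)
    denom = probs.sum(dims) + target_1h.sum(dims)
    return 1 - ((2 * inter + eps) / (denom + eps)).mean()

def ce_dice_loss(logits: torch.Tensor, labels: torch.Tensor, dice_weight: float = 1.0) -> torch.Tensor:
    # logits: (N, C, H, W) raw scores; labels: (N, H, W) integer class indices.
    n_classes = logits.shape[1]
    ce = F.cross_entropy(logits, labels)
    target_1h = F.one_hot(labels, n_classes).permute(0, 3, 1, 2).float()
    dice = soft_dice_loss(F.softmax(logits, dim=1), target_1h)
    return ce + dice_weight * dice

torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4, requires_grad=True)  # fake model output
labels = torch.randint(0, 3, (2, 4, 4))               # fake label map
loss = ce_dice_loss(logits, labels)
loss.backward()  # gradients flow through both terms
```

CE supplies a stable per-pixel gradient, while the Dice term pushes on region overlap; `dice_weight` controls the balance between the two.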

3. Binary segmentation

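Binary image segmentation is usually handled in one of two ways:
(1) Treat it like a multi-class task, but set the number of output channels num_class to 2, so the output is a two-channel map. The binary label is a single-channel map containing only the values 0 and 1. To push the output toward the label, the output first goes through a softmax, which normalises its values to (0,1) so that the two channels at each position sum to 1. The label is one-hot encoded into an image with num_class=2 channels. The loss is then computed between the output map and the encoded label. The overall flow is: model output -> softmax -> loss against the one-hot label.
[Figure: flow diagram of the steps above]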
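Notes:

1) For a binary task, after softmax the two channel values at each position sum to 1; for a multi-class task, the values across all channels sum to 1.

2) One-hot encoding of a binary label turns 0 into [1,0] and 1 into [0,1]: the 1 sits at the index equal to the class value, which is what F.one_hot produces. For a multi-class task, say 4 classes, the label map contains the pixel values [0,1,2,3], and the one-hot encoding is:
0 -> [1,0,0,0]
1 -> [0,1,0,0]
2 -> [0,0,1,0]
3 -> [0,0,0,1]

3) CrossEntropyLoss and FocalLoss handle this conversion internally, so nothing needs to change: pass the output map and the label in directly, as in the code above: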

focal = FocalLoss()
FocalLoss1 = focal(out, label)  # out: model output (logits), label: ground-truth mask

criterion = torch.nn.CrossEntropyLoss()
loss = criterion(out, label)

For Dice loss, you have to prepare the inputs yourself, as in the code above:

lossp = dice_loss(F.softmax(out, dim=1).float(),
 F.one_hot(lb, n_classes).permute(0, 3, 1, 2).contiguous().float(), multiclass=True)
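To check the one-hot layout from note 2) and the tensor shape that the `dice_loss` call expects, the following minimal sketch encodes a tiny 4-class label map. The label values are made up for illustration.

```python
import torch
import torch.nn.functional as F

label = torch.tensor([[0, 1], [2, 3]])        # a 2x2 label map with 4 classes
one_hot = F.one_hot(label, num_classes=4)     # shape (2, 2, 4): class axis comes last
channels_first = one_hot.permute(2, 0, 1)     # shape (4, 2, 2): C x H x W, like a model output

print(one_hot[0, 0].tolist())  # encoding of class 0
print(one_hot[1, 1].tolist())  # encoding of class 3
```

F.one_hot appends the class axis at the end, which is why the training code permutes it to the channel position before comparing against the softmax output.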

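(2) The second approach, common in salient object detection, outputs a single channel, i.e. num_class=1. Here a sigmoid normalises the output map to (0,1); since the output map and the label are both single-channel, the loss can be computed between them directly. See the loss commonly used in salient object detection papers: BCE + IoU (BCE looks at individual pixels and IoU at the overall structure, so using both is roughly equivalent to CE + Dice).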

Note: with torch.nn.BCELoss() you must apply the sigmoid to the output map yourself; BCEWithLogitsLoss() applies the sigmoid internally, so you should not add one.
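The single-channel setup can be sketched as follows. The snippet checks that BCEWithLogitsLoss on raw logits matches BCELoss on sigmoid outputs, then adds a soft IoU term. `soft_iou_loss` is a hypothetical helper written for this example, using one common soft formulation (intersection over the sum minus the intersection); the post does not specify which IoU variant it has in mind.

```python
import torch
import torch.nn as nn

def soft_iou_loss(probs: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # probs, target: (N, 1, H, W) with values in [0, 1]; IoU is computed per image.
    inter = (probs * target).sum(dim=(1, 2, 3))
    union = (probs + target - probs * target).sum(dim=(1, 2, 3))
    return (1 - (inter + eps) / (union + eps)).mean()

torch.manual_seed(0)
logits = torch.randn(2, 1, 4, 4)                      # fake single-channel model output
target = torch.randint(0, 2, (2, 1, 4, 4)).float()    # fake binary mask

bce_from_logits = nn.BCEWithLogitsLoss()(logits, target)        # sigmoid applied internally
bce_from_probs = nn.BCELoss()(torch.sigmoid(logits), target)    # sigmoid applied by hand
total = bce_from_logits + soft_iou_loss(torch.sigmoid(logits), target)
```

The two BCE values agree up to floating-point error; BCEWithLogitsLoss is generally the numerically safer choice because it fuses the sigmoid with the log.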

To be continued

I will keep adding the loss functions I use in future projects.
