图像分割任务中,`Dice Loss`中的`Dice`计算 与 评估指标中`Dice`计算的区别 | `TL;DR`: 训练 Loss可以近似,但评估指标要务必精确

图像分割任务中,Dice Loss中的Dice计算 与 评估指标中Dice计算的区别 | TL;DR: 训练 Loss可以近似,但评估指标要务必精确

最近在搭建自己的训练框架时发现:**Dice loss 和 Dice评估指标 虽然都是基于Dice系数计算,但是使用的时候由于身处计算环境和目的不一样,会有一些区别。**区别主要在于:

  1. 几乎所有的Loss都是以目标最小化为目的,所以DiceLoss最后会有一个 最小化的步骤
  2. Dice Loss从原理上讲,求dice需要先将类别概率转换为类别标签,然后求交求并。但由于标签转换时会涉及argmax操作,而argmax操作会导致计算后的loss失去梯度信息(argmax不可导)。所以为了不丢失梯度信息, DiceLoss使用的是原始的预测概率分布。 个人认为DiceLoss求得是一个近似值,其实也可以理解,作为引导训练的指标,只要方向对即可,不需要很精确。
  3. Dice系数,作为一个评估模型训练性能的指标,需要务必精确,因此,在计算Dice系数时,需要先将预测结果的原始预测概率先通过argmax转换为预测标签,然后再求交求并求结果。

Ref:

https://www.cnblogs.com/lfri/p/15552933.html

https://discuss.pytorch.org/t/torch-argmax-cause-loss-backward-dont-work/64782

Dice Loss

Dice Loss 是一种常见的用于图像分割任务的, 基于Dice系数的度量集合相似度的损失函数,特别是对于类别不平衡的数据时。

计算公式如下:

D i c e L o s s = 1 − 2 × ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ Dice Loss = 1- \frac {2 \times \vert X \cap Y \vert} {\vert X \vert + \vert Y \vert} DiceLoss=1X+Y2×XY

假如预测标签维度为[batch, 4, D, W, H], 真实标签维度为 [batch, D, W, H] ,其计算代码如下:

class DiceLoss(nn.Module):
    """Dice loss for image segmentation"""
    def __init__(self, smooth=1e-6):
        """
        初始化函数:
        :param smooth: 平滑变量,防止分母为0
        """
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        """
        前向反馈
        :param y_pred: 预测值 [batch, 4, D, W, H]
        :param y_true: 真实值 [batch, D, W, H]
        """
        # y_pred = torch.argmax(y_pred, dim=1).to(dtype=torch.int64)   # ⚠️argmax会使得loss失去梯度信息
        # y_pred = F.one_hot(y_pred, num_classes=4).permute(0, 4, 1, 2, 3).float()
        y_true = F.one_hot(y_true, num_classes=4).permute(0, 4, 1, 2, 3).float()

        # 计算Dice 系数
        intersection = (y_pred * y_true).sum(dim=(2, 3, 4))
        union = y_pred.sum(dim=(2, 3, 4)) + y_true.sum(dim=(2, 3, 4))
        Diceloss = (2 * intersection + self.smooth) / (union + self.smooth)
        return 1 - Diceloss.mean()

Dice系数评估指标

Dice 系数是两个集合相似度的一种衡量指标,范围从0到1。

计算公式:

D i c e = 2 × ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ Dice = \frac {2 \times \vert X \cap Y \vert} {\vert X \vert + \vert Y \vert} Dice=X+Y2×XY

Dice系数越高,表示标签结果与真实标签越接近。

假如预测标签维度为[batch, 4, D, W, H], 真实标签维度为 [batch, D, W, H] ,其计算代码如下:

    def dice_coefficient(self, y_pred, y_true):
        """
        计算Dice 系数
        :param y_pred: 预测标签
        :param y_true: 真实标签
        :return: Dice 系数
        """
        # 预处理

        y_pred = torch.argmax(y_pred, dim=1).to(dtype=torch.int64) # 降维,选出概率最大的类索引值,即获取预测类别标签
        y_pred = F.one_hot(y_pred, num_classes=4).permute(0, 4, 1, 2, 3).float() # one-hot
        y_true = F.one_hot(y_true, num_classes=4).permute(0, 4, 1, 2, 3).float() # ont-hot
        
        # 计算Dice 系数
        intersection = (y_pred * y_true).sum(dim=(2, 3, 4))
        union = y_pred.sum(dim=(2, 3, 4)) + y_true.sum(dim=(2, 3, 4))
        dice = (2*(intersection + self.smooth)) / (union + self.smooth)
        return dice.mean(dim=0)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值