图像分割任务中,Dice Loss
中的Dice
计算 与 评估指标中Dice
计算的区别 | TL;DR
: 训练 Loss可以近似,但评估指标要务必精确
最近在搭建自己的训练框架时发现:**Dice loss 和 Dice评估指标 虽然都是基于Dice系数计算,但是使用的时候由于身处计算环境和目的不一样,会有一些区别。**区别主要在于:
- 几乎所有的
Loss
都是以目标最小化为目的,所以DiceLoss
最后会有一个 最小化的步骤Dice Loss
从原理上讲,求dice
需要先将类别概率转换为类别标签,然后求交求并。但由于标签转换时会涉及argmax
操作,而argmax
操作会导致计算后的loss
失去梯度信息(argmax不可导)。所以为了不丢失梯度信息,DiceLoss
使用的是原始的预测概率分布。 个人认为DiceLoss
求得是一个近似值,其实也可以理解,作为引导训练的指标,只要方向对即可,不需要很精确。Dice系数
,作为一个评估模型训练性能的指标,需要务必精确,因此,在计算Dice系数
时,需要先将预测结果的原始预测概率先通过argmax
转换为预测标签,然后再求交求并求结果。
Ref:
https://www.cnblogs.com/lfri/p/15552933.html
https://discuss.pytorch.org/t/torch-argmax-cause-loss-backward-dont-work/64782
Dice Loss
Dice Loss
是一种常见的用于图像分割任务的, 基于Dice
系数的度量集合相似度的损失函数,特别是对于类别不平衡的数据时。
计算公式如下:
D i c e L o s s = 1 − 2 × ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ Dice Loss = 1- \frac {2 \times \vert X \cap Y \vert} {\vert X \vert + \vert Y \vert} DiceLoss=1−∣X∣+∣Y∣2×∣X∩Y∣
假如预测标签维度为[batch, 4, D, W, H]
, 真实标签维度为 [batch, D, W, H]
,其计算代码如下:
class DiceLoss(nn.Module):
"""Dice loss for image segmentation"""
def __init__(self, smooth=1e-6):
"""
初始化函数:
:param smooth: 平滑变量,防止分母为0
"""
super(DiceLoss, self).__init__()
self.smooth = smooth
def forward(self, y_pred, y_true):
"""
前向反馈
:param y_pred: 预测值 [batch, 4, D, W, H]
:param y_true: 真实值 [batch, D, W, H]
"""
# y_pred = torch.argmax(y_pred, dim=1).to(dtype=torch.int64) # ⚠️argmax会使得loss失去梯度信息
# y_pred = F.one_hot(y_pred, num_classes=4).permute(0, 4, 1, 2, 3).float()
y_true = F.one_hot(y_true, num_classes=4).permute(0, 4, 1, 2, 3).float()
# 计算Dice 系数
intersection = (y_pred * y_true).sum(dim=(2, 3, 4))
union = y_pred.sum(dim=(2, 3, 4)) + y_true.sum(dim=(2, 3, 4))
Diceloss = (2 * intersection + self.smooth) / (union + self.smooth)
return 1 - Diceloss.mean()
Dice
系数评估指标
Dice 系数是两个集合相似度的一种衡量指标,范围从0到1。
计算公式:
D i c e = 2 × ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ Dice = \frac {2 \times \vert X \cap Y \vert} {\vert X \vert + \vert Y \vert} Dice=∣X∣+∣Y∣2×∣X∩Y∣
Dice系数越高,表示标签结果与真实标签越接近。
假如预测标签维度为[batch, 4, D, W, H]
, 真实标签维度为 [batch, D, W, H]
,其计算代码如下:
def dice_coefficient(self, y_pred, y_true):
"""
计算Dice 系数
:param y_pred: 预测标签
:param y_true: 真实标签
:return: Dice 系数
"""
# 预处理
y_pred = torch.argmax(y_pred, dim=1).to(dtype=torch.int64) # 降维,选出概率最大的类索引值,即获取预测类别标签
y_pred = F.one_hot(y_pred, num_classes=4).permute(0, 4, 1, 2, 3).float() # one-hot
y_true = F.one_hot(y_true, num_classes=4).permute(0, 4, 1, 2, 3).float() # ont-hot
# 计算Dice 系数
intersection = (y_pred * y_true).sum(dim=(2, 3, 4))
union = y_pred.sum(dim=(2, 3, 4)) + y_true.sum(dim=(2, 3, 4))
dice = (2*(intersection + self.smooth)) / (union + self.smooth)
return dice.mean(dim=0)