深度学习| DiceLoss解决图像数据不平衡问题

本文讨论了图像数据不平衡在二分类和多分类问题中的影响,特别是医学图像处理中的应用。着重介绍了交叉熵损失函数如何导致问题出现,以及DiceLoss如何通过mask操作关注重要类别来缓解不平衡。DiceLoss的不稳定性和使用注意事项也进行了探讨。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

图像数据不平衡问题

图像数据不平衡:在进行图像分割时,二分类问题中,背景过大,前景过小;多分类问题中,某一类别的物体体积过小。在很多图像数据的时候都会遇到这个情况,尤其是在医学图像处理的时候,例如分割官腔轮廓、识别肿瘤、识别斑块等等。

图像数据不平衡会产生的问题:占据绝大多数的类别会支配模型的训练过程,导致模型只需要把占比大的类别分准损失就收敛了,占比小的类别反而分的很差,而我们很多时候需要分准的其实是占比小的类别。

这个问题的产生也和交叉熵损失函数有关。

交叉熵损失函数:通常进行图像分割的时候都会使用交叉熵损失函数,交叉熵的特点就是“平等”地看待每一个样本,无论什么类别权重都是一样的。所以当正负样本不均衡时,大量简单的负样本会占据主导地位,少量的难样本和正样本就会分不出来。

Dice Loss

公式

之前在介绍深度学习指标的时候,提到过Dice。

Dice可以计算集合的相似程度,取值范围在[0,1],公式如下所示:
D i c e ( X , Y ) = 2 ∗ ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ Dice(X,Y)=\frac{2*|X\cap Y|}{|X|+|Y|} Dice(X,Y)=X+Y2XY

Dice Loss表达式:
1 − D i c e ( X , Y ) = 1 − 2 ∗ ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ 1-Dice(X,Y)=1-\frac{2*|X\cap Y|}{|X|+|Y|} 1Dice(X,Y)=1X+Y2XY

为什么能解决图像数据不平衡问题

进行图像二分类问题的时候,X看作是Label(标签)像素点值集合,Y看作是Prediction(预测)像素点值集合,将前景真实值设为1,背景真实值设为0。这样在计算Dice的时候,求X和Y交集和并集就会把真实值为0的背景过滤掉,相当于做了个mask(掩码)操作,从而使得训练更关注我们要前景。

在这里插入图片描述

使用注意

在实际使用的时候,损失函数并不会单纯使用Dice Loss,通常都会和其他Loss结合起来用,会给其他Loss和Dice Loss分别上不同的权重作为损失函数。

为什么不只简单Dice Loss:在训练模型的时候,我们通常都是要把模型训练到损失收敛才停下。而Dice Loss本身并不稳定,Dice Loss是mask了背景的,在前景目标很小的情况下,一旦有少部分分类错误,就会导致Dice Loss产生严重的震荡。而且背景预测不正确,但是前景都预测涵盖了,会导致Loss也非常低,但实际上预测的并不对。

### 处理样本数据平衡的损失函数 在面对样本数据平衡的情况时,传统的损失函数可能无法有效应对少数类别的样本问题。为了改善这种情况,可以考虑使用特定设计来处理平衡数据集的损失函数。 #### 类别加权交叉熵损失 (Weighted Cross Entropy Loss) 类别加权交叉熵损失是对标准二元或多元交叉熵损失的一个改进版本。通过对同类别的样本赋予同的权重,使得模型更加关注于较少数量的类别。具体而言,在计算损失的过程中给稀有事件更高的惩罚系数,从而让模型更倾向于正确识别这些重要但罕见的例子[^4]。 ```python import torch.nn as nn class_weight = [0.1, 0.9] # 假设有两个类别,其中第二类较为少见 criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weight)) ``` #### Focal Loss Focal loss 是一种专门为了解决极度不均衡数据分布而提出的新型损失函数。其主要特点是引入了一个调节因子 \((1-p_t)^γ\) ,当某个样本被容易地分对时 (\(p_t>0.5\)) , 这个项将会变得很小;而对于那些难以区分的样本,则会使该项增大,进而增加它们在整个批次中的影响程度。这样做的好处是可以使训练过程中更多地聚焦于困难负例上而是简单正例,有助于提高整体性能[^3]。 ```python def focal_loss(input_values, gamma=2.0): p = torch.sigmoid(input_values) ce_loss = F.binary_cross_entropy_with_logits(input_values, target, reduction="none") pt = p * target + (1 - p) * (1 - target) FL = ((1-pt)**gamma) * ce_loss return FL.mean() ``` #### Dice Loss 和 Tversky Loss Dice loss 及其变体Tversky loss 主要应用于医学图像分割等领域,但在其他类型的分类任务中同样适用。这两种方法都基于集合论的思想构建而成,旨在最大化预测结果与实际标签之间的交并比(IOU),特别适合用来缓解因类别间数目差异过大而导致的问题[^5]。 ```python def dice_loss(inputs, targets, smooth=1e-6): inputs = torch.sigmoid(inputs).view(-1) targets = targets.view(-1) intersection = (inputs * targets).sum() dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) return 1 - dice ```
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值