Dice Loss/Score: A Step-by-Step Explanation and Implementation

Let's start with the code:

import torch
import torch.nn.functional as F

def dice_loss(inputs, target, smooth=1e-5):
    """
    Compute the Dice Loss.

    Args:
    - inputs: model output (logits), shape [batch_size, num_classes, height, width]
    - target: ground-truth labels, shape [batch_size, 1, height, width],
      whose elements are integer class indices
    - smooth: smoothing term for numerical stability

    Returns:
    - the mean Dice Loss over all classes
    """
    n_classes = inputs.shape[1]  # number of classes

    # Apply softmax to obtain a per-pixel probability distribution over classes
    input_softmax = F.softmax(inputs, dim=1)

    # Reshape target from [batch_size, 1, height, width] to [batch_size, height, width]
    target = target.squeeze(1)  # drop the redundant channel dimension

    # Convert target to one-hot form: [batch_size, num_classes, height, width]
    target_one_hot = F.one_hot(target, num_classes=n_classes).permute(0, 3, 1, 2).float()

    # Intersection and "union" (sum of cardinalities), computed per class
    intersection = (input_softmax * target_one_hot).sum(dim=[0, 2, 3])  # intersection per class
    cardinality = (input_softmax + target_one_hot).sum(dim=[0, 2, 3])   # cardinality per class

    # Dice coefficient per class
    dice_coeff = (2. * intersection + smooth) / (cardinality + smooth)

    # Return 1 minus the mean Dice coefficient as the loss
    return 1 - dice_coeff.mean()
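As a quick sanity check that the function runs end to end, here is a minimal usage sketch; the shapes and random tensors are made up purely for illustration:

# Hypothetical example: batch of 2, 3 classes, 4x4 images
logits = torch.randn(2, 3, 4, 4)             # raw model output (pre-softmax)
labels = torch.randint(0, 3, (2, 1, 4, 4))   # integer class indices, dtype long
print(dice_loss(logits, labels))             # a scalar tensor in [0, 1)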

Example: inputs has shape [2, 3, 2, 2], i.e. batch_size=2, num_classes=3, height=2, width=2.

Suppose the probability distribution input_softmax after softmax is:

input_softmax = torch.tensor(
[
    # Batch 1
    [
        [[0.002, 0.003], [0.004, 0.005]],  # Class 0
        [[0.010, 0.012], [0.014, 0.016]],  # Class 1
        [[0.988, 0.985], [0.982, 0.979]]   # Class 2
    ],
    # Batch 2
    [
        [[0.001, 0.002], [0.003, 0.004]],  # Class 0
        [[0.010, 0.012], [0.014, 0.016]],  # Class 1
        [[0.989, 0.986], [0.983, 0.980]]   # Class 2
    ]
])

The label tensor target has shape [2, 1, 2, 2]:

target = torch.tensor(
    [
        # Batch 1
        [[[0, 1], [2, 1]]],
        # Batch 2
        [[[1, 2], [0, 2]]]
    ]
)
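Squeezing the channel dimension and one-hot encoding this target (exactly as the function does) gives the target_one_hot used in the calculations below; written out per batch and class:

target_one_hot = F.one_hot(target.squeeze(1), num_classes=3).permute(0, 3, 1, 2).float()
# Batch 1:  Class 0: [[1, 0], [0, 0]]   Class 1: [[0, 1], [0, 1]]   Class 2: [[0, 0], [1, 0]]
# Batch 2:  Class 0: [[0, 0], [1, 0]]   Class 1: [[1, 0], [0, 0]]   Class 2: [[0, 1], [0, 1]]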

The calculation proceeds as follows:

Intersection: computed by element-wise multiplication, input_softmax * target_one_hot:

Class 0:

  • Batch 1: 0.002×1 + 0.003×0 + 0.004×0 + 0.005×0 = 0.002

  • Batch 2: 0.001×0 + 0.002×0 + 0.003×1 + 0.004×0 = 0.003

  • Sum: 0.002 + 0.003 = 0.005

Class 1:

  • Batch 1: 0.010×0 + 0.012×1 + 0.014×0 + 0.016×1 = 0.028

  • Batch 2: 0.010×1 + 0.012×0 + 0.014×0 + 0.016×0 = 0.010

  • Sum: 0.028 + 0.010 = 0.038

Class 2:

  • Batch 1: 0.988×0 + 0.985×0 + 0.982×1 + 0.979×0 = 0.982

  • Batch 2: 0.989×0 + 0.986×1 + 0.983×0 + 0.980×1 = 1.966

  • Sum: 0.982 + 1.966 = 2.948

intersection = torch.tensor([0.005, 0.038, 2.948])
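These per-class sums can be reproduced in one line from the tensors defined above:

intersection = (input_softmax * target_one_hot).sum(dim=[0, 2, 3])
# tensor([0.0050, 0.0380, 2.9480])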

Similarly,

"Union": computed by element-wise addition, input_softmax + target_one_hot (strictly, this is the sum of the two sets' cardinalities rather than a true set union):

Class 0:

  • Batch 1: 0.002+1 + 0.003+0 + 0.004+0 + 0.005+0 = 1.014

  • Batch 2: 0.001+0 + 0.002+0 + 0.003+1 + 0.004+0 = 1.010

  • Sum: 1.014 + 1.010 = 2.024

Class 1:

  • Batch 1: 0.010+0 + 0.012+1 + 0.014+0 + 0.016+1 = 2.052

  • Batch 2: 0.010+1 + 0.012+0 + 0.014+0 + 0.016+0 = 1.052

  • Sum: 2.052 + 1.052 = 3.104

Class 2:

  • Batch 1: 0.988+0 + 0.985+0 + 0.982+1 + 0.979+0 = 3.934

  • Batch 2: 0.989+0 + 0.986+1 + 0.983+0 + 0.980+1 = 4.938

  • Sum: 3.934 + 4.938 = 8.872

cardinality = torch.tensor([2.024, 3.104, 8.872])
smooth = 1e-5
dice_coeff = (2. * intersection + smooth) / (cardinality + smooth)
loss = 1 - dice_coeff.mean()

Class 0: Dice_0 = (2×0.005 + 1e-5) / (2.024 + 1e-5) ≈ 0.00495

Class 1: Dice_1 = (2×0.038 + 1e-5) / (3.104 + 1e-5) ≈ 0.02449

Class 2: Dice_2 = (2×2.948 + 1e-5) / (8.872 + 1e-5) ≈ 0.6646

The final Dice Loss is:

Loss = 1 − (0.00495 + 0.02449 + 0.6646)/3 ≈ 0.7687
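Putting it all together, this self-contained snippet reproduces the hand calculation. Because the example already gives post-softmax probabilities, the softmax step inside dice_loss is skipped and the remaining lines are applied directly:

import torch
import torch.nn.functional as F

input_softmax = torch.tensor([
    [[[0.002, 0.003], [0.004, 0.005]],
     [[0.010, 0.012], [0.014, 0.016]],
     [[0.988, 0.985], [0.982, 0.979]]],
    [[[0.001, 0.002], [0.003, 0.004]],
     [[0.010, 0.012], [0.014, 0.016]],
     [[0.989, 0.986], [0.983, 0.980]]],
])
target = torch.tensor([[[[0, 1], [2, 1]]],
                       [[[1, 2], [0, 2]]]])

target_one_hot = F.one_hot(target.squeeze(1), num_classes=3).permute(0, 3, 1, 2).float()
intersection = (input_softmax * target_one_hot).sum(dim=[0, 2, 3])
cardinality = (input_softmax + target_one_hot).sum(dim=[0, 2, 3])
dice_coeff = (2. * intersection + 1e-5) / (cardinality + 1e-5)
print(dice_coeff)              # tensor([0.0049, 0.0245, 0.6646])
print(1 - dice_coeff.mean())   # tensor(0.7687)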

Computing the Dice score

The score is simply dice_coeff itself, without subtracting it from 1; the output is the per-class Dice for each of the three classes over the batch.
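A minimal dice_score variant of the function above (the name dice_score is my own; only the last line differs, dropping the "1 -" and the mean):

def dice_score(inputs, target, smooth=1e-5):
    """Return the per-class Dice coefficients, shape [num_classes]."""
    n_classes = inputs.shape[1]
    input_softmax = F.softmax(inputs, dim=1)
    target_one_hot = F.one_hot(target.squeeze(1), num_classes=n_classes).permute(0, 3, 1, 2).float()
    intersection = (input_softmax * target_one_hot).sum(dim=[0, 2, 3])
    cardinality = (input_softmax + target_one_hot).sum(dim=[0, 2, 3])
    return (2. * intersection + smooth) / (cardinality + smooth)  # no "1 -" here

For the worked example, these are exactly the three per-class values computed above: ≈0.00495, 0.02449, 0.6646.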
