先上代码
def dice_loss(inputs, target, smooth=1e-5):
"""
计算Dice Loss。
参数:
- inputs: 模型输出,形状为 [batch_size, num_classes, height, width]
- target: 真实标签,形状为 [batch_size, 1, height, width],元素值为类别索引
- smooth: 用于数值稳定性的平滑项
返回:
- Dice Loss 的平均值
"""
n_classes = inputs.shape[1] # 类别数
#batch_size = inputs.shape[0] # 批次大小
# 应用softmax得到概率分布
input_softmax = F.softmax(inputs, dim=1)
# 将目标张量从 [batch_size, 1, height, width] 转换为 [batch_size, height, width]
target = target.squeeze(1) # 移除多余的通道维度
# 将目标张量转换为one-hot编码形式
target_one_hot = F.one_hot(target, num_classes=n_classes).permute(0, 3, 1, 2).float()
# 计算交集和并集(按类别)
intersection = (input_softmax * target_one_hot).sum(dim=[0, 2, 3]) # 对每个类别分别计算交集
cardinality = (input_softmax + target_one_hot).sum(dim=[0, 2, 3]) # 对每个类别分别计算并集
# 计算Dice系数
dice_coeff = (2. * intersection + smooth) / (cardinality + smooth)
# 返回1减去Dice系数的均值作为损失
return 1 - dice_coeff.mean()
例子:形状为 [2, 3, 2, 2],表示 batch_size=2, num_classes=3, height=2, width=2
假设经过softmax后 input的概率分布如下
input_softmax=torch.tensor(
[
# Batch 1
[
[[0.002, 0.003], [0.004, 0.005]], # Class 0
[[0.010, 0.012], [0.014, 0.016]], # Class 1
[[0.988, 0.985], [0.982, 0.979]] # Class 2
],
# Batch 2
[
[[0.001, 0.002], [0.003, 0.004]],
[[0.010, 0.012], [0.014, 0.016]],
[[0.989, 0.986], [0.983, 0.980]]
]
]
label 形状为[2,1,2,2]
target = torch.tensor(
[
# Batch 1
[[[0, 1], [2, 1]]],
# Batch 2
[[[1, 2], [0, 2]]]
]
)
具体计算如下:
交集:通过逐元素相乘 input_softmax * target_one_hot 来计算交集:
Class 0:
-
Batch 1: 0.002×1+0.003×0+0.004×0+0.005×0=0.002
-
Batch 2:0.001×0+0.002×0+0.003×1+0.004×0=0.003
-
总和:0.002+0.003=0.005
Class 1:
-
Batch 1:0.010×0+0.012×1+0.014×0+0.016×1=0.028
-
Batch 2:0.010×1+0.012×0+0.014×0+0.016×0=0.010
-
总和:0.028+0.010=0.038
Class 2:
-
Batch 1:0.988×0+0.985×0+0.982×1+0.979×0=0.982
-
Batch 2:0.989×0+0.986×1+0.983×0+0.980×1=1.966
-
总和:0.982+1.966=2.948
intersection = torch.tensor([0.005, 0.038, 2.948])
类似的
并集 :通过逐元素相加 input_softmax + target_one_hot 来计算并集
Class 0:
-
Batch 1: 0.002+1+0.003+0+0.004+0+0.005+0=1.014
-
Batch 2: 0.001+0+0.002+0+0.003+1+0.004+0=1.010
-
总和:1.014+1.010=2.024
Class 1:
-
Batch 1: 0.010+0+0.012+1+0.014+0+0.016+1=2.052
-
Batch 2: 0.010+1+0.012+0+0.014+0+0.016+0=1.052
-
总和:2.052+1.052=3.104
Class 2:
-
Batch 1: 0.988+0+0.985+0+0.982+1+0.979+0=3.934
-
Batch 2: 0.989+0+0.986+1+0.983+0+0.980+1=4.938
-
总和:3.934+4.938=8.872
cardinality = torch.tensor([2.024, 3.104, 8.872])
smooth = 1e-5
dice_coeff = (2. * intersection + smooth) / (cardinality + smooth)
loss = 1 - dice_coeff.mean()
Class 0: Dice 0=(2×0.005+1e−5)/(2.024+1e-5)≈0.00494
Class 1: Dice 1= (2×0.038+1e−5)/(3.104+1e−5)≈0.02448
Class 2: Dice 2= (2×2.948+1e−5)/(8.872+1e−5)≈0.6648
最终的Dice Loss为:
Loss=1− (0.00494+0.02448+0.6648)/3≈0.4369
Dice score的计算
score的计算就是dice_coeff,不用1去减,输出就是一个batch中三个类别的分别的dice