图像多label分割的dice损失函数

目标:

实现图像中多个物体的分割,多个物体的标注方式为0,1,2,3,,,,,0表示背景,1表示一种物体,2表示另一种物体,假设我们现在的分割任务里面有5个目标需要,如肺叶分割,5个肺叶的标注方式为:0表示背景,1表示右上叶,2表示右中叶,3表示右下叶,4表示左上页,5表示左下叶。

前提:

首先我们拿到的标注应该是单通道的,如大小为96*96*64*1,但在网络模型设计中输出通道数应该是分割类别数,即网络模型最后的输出大小为96*96*64*6(这里最后一维6指背景+5个目标),这样在训练的时候网络输出和标注之间大小不匹配,不能进行损失值计算,所以需要将原来单通道的标注图像变成6个通道,这就需要用到One-hot编码,在每个通道上(即将每一类进行二值化)。

正文:

One-hot编码:

以上是肺叶的标注图像,可以看到背景像素值是0,其余叶标注值分别为1,2,3,4,5,以矩阵表示大概为(二维图像举例,并不是这个图的真正矩阵):

[00000000
 01211110
 01222330
 01123340
 01223440
 01233550
 02234450
 02334550]

经过One-hot编码后变成6个通道:

[11111111      
 10000001
 10000001
 10000001
 10000001
 10000001
 10000001
 10000001] 
[00000000
 01011110
 01000000
 01100000
 01000000
 01000000
 00000000
 00000000]
[00000000
 00100000
 00111000
 00010000
 00110000
 00100000
 01100000
 01000000]
[00000000
 00000000
 00000110
 00001100
 00001000
 00011000
 00010000
 00110000]
[00000000
 00000000
 00000000
 00000010
 00000110
 00000000
 00001100
 00001000]
[00000000
 00000000
 00000000
 00000000
 00000000
 00000110
 00000010
 00000110]

One-hot代码实现:

def mask_to_onehot(mask, palette):
    """
    Converts a segmentation mask (H, W, C) to (H, W, K) where the last dim is a one
    hot encoding vector, C is usually 1, and K is the number of segmented class.
    eg:
    mask:单通道的标注图像
    palette:[[0],[1],[2],[3],[4],[5]]
    """
    semantic_map = []
    for colour in palette:
        equality = np.equal(mask, colour)
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1).astype(np.float32)
    return semantic_map

多label分割dice损失函数代码实现:

将背景通道的权重设置为0

def categorical_dice(self, Y_pred, Y_gt, weight_loss):
    """
    multi label dice loss with weighted
    Y_pred: [None, self.image_depth, self.image_height, self.image_width,
                                                       self.numclass],Y_pred is softmax result
    Y_gt:[None, self.image_depth, self.image_height, self.image_width,
                                                       self.numclass],Y_gt is one hot result
    weight_loss: numpy array of shape (C,) where C is the number of classes,eg:[0,1,1,1,1,1]
    :return:
    """
    # print('Y_pred.shape',Y_pred.shape)
    # print('Y_gt.shape',Y_gt.shape)
    weight_loss = np.array(weight_loss)
    smooth = 1.e-5
    smooth_tf = tf.constant(smooth, tf.float32)
    Y_pred = tf.cast(Y_pred, tf.float32)
    Y_gt = tf.cast(Y_gt, tf.float32)
    # Compute gen dice coef:
    numerator = Y_gt * Y_pred
    # print('intersection shape',numerator.shape) intersection shape (?, 64, 96, 96, 6)
    numerator = tf.reduce_sum(numerator, axis=(1, 2, 3))
    # print('after reduce_sum intersection shape', numerator.shape) after reduce_sum intersection shape (?, 6)
    denominator = Y_gt + Y_pred
    denominator = tf.reduce_sum(denominator, axis=(1, 2, 3))
    gen_dice_coef = tf.reduce_mean(2. * (numerator + smooth_tf) / (denominator + smooth_tf), axis=0)
    # print('gen_dice_coef',gen_dice_coef.shape) gen_dice_coef (6,)
    loss = -tf.reduce_mean(weight_loss * gen_dice_coef)
    return loss

最后可视化预测结果的时候,需要将6个通道的预测结果再转为单通道:

One-hot实现:

def onehot_to_mask(self, mask, palette):
    """
    Converts a mask (H, W, K) to (H, W, C)
	K is the number of segmented class,C is usually 1  
    """
    x = np.argmax(mask, axis=-1)
    colour_codes = np.array(palette)
    x = np.uint8(colour_codes[x.astype(np.uint8)])
    return x

 

 

 

 

 

 

  • 9
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 12
    评论
多分类Dice损失函数是一种用于语义分割任务的损失函数。它是基于Dice系数的度量,用于衡量模型预测结果与真实标签之间的相似度。Dice损失函数可以将预测结果与真实标签进行对比,并优化模型参数以最大化Dice系数。 在多分类任务中,每个类别都有一个对应的Dice损失函数。常见的做法是使用多个Dice损失函数对每个类别进行独立的分割,然后将这些损失函数整合到一个总的损失函数中。这个总的损失函数被称为Generalized Dice损失函数。 Generalized Dice损失函数的计算方式如下: 1. 计算每个类别的Dice系数:对于每个类别i,将模型预测结果与真实标签进行相交运算并计算相交区域的像素数量,然后计算相交区域的大小与预测区域和真实区域大小之和的比值,得到Dice系数Di。 2. 计算类别权重:对于每个类别i,计算其在真实标签中的像素数量与总像素数量的比值,得到类别权重Wi。 3. 将Dice系数与类别权重相乘并求和:将每个类别的Dice系数Di与对应的类别权重Wi相乘,并将所有类别的结果求和,得到Generalized Dice损失函数L。 通过最小化Generalized Dice损失函数,模型可以更好地适应多分类语义分割任务,提高预测结果的准确性。 参考文献: - 引用: 【损失函数合集】超详细的语义分割中的Loss大盘点 - 引用: Tensorflow入门教程(四十七)——语义分割损失函数总结 - 引用: 论文地址:A survey of loss functions for semantic segmentation code地址:Semantic-Segmentation-Loss-Functions

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值