keras计算Generalized Dice Loss(GDL)的代码解析

本文详细介绍了如何将Keras中的通用Dice系数和损失函数转换为TensorFlow版本,通过示例代码展示了如何计算类别加权的Dice系数,并用于衡量预测结果与真实标签的一致性。涉及了图像分类任务中常见的评价指标计算方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

这个代码在很多地方都能找到,贴出其中一个链接https://zhuanlan.zhihu.com/p/103426335
在这里插入图片描述
其中的代码增加上我个人理解后的注释为:

# keras在安装tensorflow后还需要单独安装
from keras import backend as K	
# y_pred表示预测的结果(Batch, Height, Width, Num_Class)
# y_true表示真实结果 (Batch, Height, Width, Num_Class)
def generalized_dice_coeff(y_true, y_pred):
    Ncl = y_pred.shape[-1]  # N classes 类别总数
    w = K.zeros(shape=(Ncl,))  # (0,0, ... ) (一共Ncl个0)
    w = K.sum(y_true, axis=(0,1,2)) # 在类别这一维度方向上进行求和,求和的结果是
     (Num_Class,), 即计算该类别上所有图片对应的标签之和,换句话说也就是统计图片上属于该类别的像素点的总数。
    w = 1/(w**2+0.000001) # 利用类别对应的像素点的总数来计算权重,总数越多反而权重越小
    # Compute gen dice coef:
    numerator = y_true*y_pred  # (Batch, Height, Width, Num_Class)
    numerator = w*K.sum(numerator,(0,1,2,3)) # (Num_Class,) * (,) = (Num_Class,),  每个类别的权重 *  True Positive总数(或者说是所有像素点对应的预测概率和真实概率之积)
    numerator = K.sum(numerator)  # 加权求和
    denominator = y_true+y_pred # (Batch, Height, Width, Num_Class)
    denominator = w*K.sum(denominator,(0,1,2,3)) # 每个类别的权重 * (所有像素点对应的预测概率和真实概率之和)
    denominator = K.sum(denominator) # 加权求和
    gen_dice_coef = 2*numerator/denominator
    return gen_dice_coef
    
def generalized_dice_loss(y_true, y_pred):
    return 1 - generalized_dice_coeff(y_true, y_pred)

自己更改为tf版本为

下面展示一些 内联代码片

# logits  (B,H,W, C)
# labels (B, H, W,1)
def generalized_dice_coeff(logits, labels, classes):
    Ncl = classes             # 计算类别总数
    w = tf.zeros(shape=(Ncl,))
    label = labels[..., 0]
    onehot_labels = tf.one_hot(label,Ncl)
    w = tf.reduce_sum(onehot_labels, axis=(0,1,2))     # 计算所有的和,而不考虑计算某个维度上的和
    eps = 0.000001
    w = 1/(w**2+eps)

    # Compute gen dice coef:
    numerator = logits * onehot_labels
    numerator = w * tf.reduce_sum(numerator)
    numerator = tf.reduce_sum(numerator)
    denominator = logits + onehot_labels
    denominator = w*tf.reduce_sum(denominator)
    denominator = tf.reduce_sum(denominator)
    gen_dice_coef = 2*numerator/denominator
    return gen_dice_coef

def generalized_dice_loss(logits, labels, classes):
    return 1 - generalized_dice_coeff(logits, labels, classes)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值