Focal_loss 的实现

本文探讨了Focal Loss的实现,分别介绍了在TensorFlow 1.13和Pytorch版本下的实现方法。通过设置alpha为0.5,gamma为0,验证了实现与交叉熵损失函数的一致性。
摘要由CSDN通过智能技术生成

TensorFlow 1.13版本

def focal_loss(logits,onehot_labels,  gamma=2.0, alpha=4.0):
    """
    focal loss for multi-classification
    FL(p_t)=-alpha(1-p_t)^{gamma}ln(p_t)
    Notice: logits is probability after softmax
    gradient is d(Fl)/d(p_t) not d(Fl)/d(x) as described in paper
    d(Fl)/d(p_t) * [p_t(1-p_t)] = d(Fl)/d(x)
    Lin, T.-Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2017).
    Focal Loss for Dense Object Detection, 130(4), 485–491.
    https://doi.org/10.1016/j.ajodo.2005.02.022
    :param labels: ground truth labels, shape of [batch_size]
    :param logits: model's output, shape of [batch_size, num_cls]
    :param gamma:
    :param alpha:正样本*alpha 负样本*(1-alpha) alpha越高 正样本训练的就越好,负样本训练的就越差
    :return: shape of [batch_size]
    """
    epsilon = 1.e-9
    logits = tf.convert_to_tensor(logits, tf.float32)
    model_out = tf.add(logits, epsilo
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值