Tensorlfow2.0 二分类和多分类focal loss实现和在文本分类任务效果评估

本文介绍了如何在TensorFlow2.0中实现二分类和多分类的focal loss,并在 Sentence Bert 和网易云音乐评论情绪分类任务中进行测试。实验结果显示,focal loss能有效抑制过拟合,提高二分类模型的性能,但在多分类任务中效果不明显,适用性取决于数据集的类别分布。
摘要由CSDN通过智能技术生成

Tensorlfow2.0 二分类和多分类focal loss实现和在文本分类任务效果评估


前言

最近看了focal loss的文章,正好在做文本分类的项目,一个是Sentence Bert句子匹配,一个是网易云音乐评论的情绪分类。本人用的框架是tensorflow2.0,所以想尝试实践一下focal loss,但是翻遍了网上的文章,不是代码报错就是错误实现。最后就自己根据focal loss的公式写了一个,试跑了代码确认无误。

tensorflow :2.0.0(GPU上跑)
transformers :3.1


二分类 focal loss

二分类 focal loss公式

from tensorflow.python.ops import array_ops
def binary_focal_loss(target_tensor,prediction_tensor, alpha=0.25, gamma=2):
    zeros = array_ops.zeros_like(prediction_tensor, dtype=prediction_tensor.dtype)
    target_tensor = tf.cast(target_tensor,prediction_tensor.dtype)
    pos_p_sub = array_ops.where(target_tensor > zeros, target_tensor - prediction_tensor, zeros)
    neg_p_sub =<
  • 4
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值