Tensorlfow2.0 二分类和多分类focal loss实现和在文本分类任务效果评估
前言
最近看了focal loss的文章,正好在做文本分类的项目,一个是Sentence Bert句子匹配,一个是网易云音乐评论的情绪分类。本人用的框架是tensorflow2.0,所以想尝试实践一下focal loss,但是翻遍了网上的文章,不是代码报错就是错误实现。最后就自己根据focal loss的公式写了一个,试跑了代码确认无误。
tensorflow :2.0.0(GPU上跑)
transformers :3.1
二分类 focal loss
from tensorflow.python.ops import array_ops
def binary_focal_loss(target_tensor,prediction_tensor, alpha=0.25, gamma=2):
zeros = array_ops.zeros_like(prediction_tensor, dtype=prediction_tensor.dtype)
target_tensor = tf.cast(target_tensor,prediction_tensor.dtype)
pos_p_sub = array_ops.where(target_tensor > zeros, target_tensor - prediction_tensor, zeros)
neg_p_sub =<