关于多标签分类任务的损失函数和评价指标的一点理解

关于多标签分类任务的损失函数和评价指标的一点理解

之前有接触到多标签分类任务,但是主要关注点都放在模型结构中,最近关于多标签分类任务进行了一个讨论,发现其中有些细节不是太清楚,经过查阅资料逐渐理解,现在此记录。

多标签分类任务损失函数

在二分类、多分类任务中通常使用交叉熵损失函数,即Pytorch中的CrossEntorpy,但是在多标签分类任务中使用的是BCEWithLogitsLoss函数。

BCEWithLogitsLoss与CrossEntorpy的不同之处在于计算样本所属类别概率值时使用的计算函数不同:
1)CrossEntorpy使用softmax函数,即将模型输出值作为softmax函数的输入,进而计算样本属于每个类别的概率,softmax计算得到的类别概率值加和为1。
2)BCEWithLogitsLoss使用sigmoid函数,将模型输出值作为sigmoid函数的输入,计算得到的多个类别概率值加和不一定为1。

共同点是计算概率值后都继续计算预测概率值和真实标签之间的交叉熵作为最终的损失函数值。

为什么在多标签任务中使用BCEWithLogitsLoss(sigmoid)函数呢?个人理解如下:
1)二分类/多分类任务是在两个/多个类别中取出一个类别,并且各个类别之间是互斥的,因此要保证多个类别的概率值加和为1(在类别概率值加和为1的情况下,一个类别概率值增加时必然有其他类别概率值减小,体现了各个类别之间的互斥),并且最终取出概率值最大的类别。
2)多分类任务是在多个类别中取出一个或多个类别,各个类别之间不互斥,因此无需保证各类别概率加和为1,只需要计算样本属于每一个类别的概率,如果样本属于某一类别的概率高于阈值则代表样本属于该类别(此时类别概率值加和不一定为1),例如样本A经过BCEWithLogitsLoss函数计算后得到属于类别1、类别2、类别3和类别4的概率值分别为[0.6,0.7,0.3,0.4],阈值为0.5,则样本A同时属于类别1和类别2。

使用方法如下:

import torch
import torch.nn as nn

#创建输入
input = torch.tensor([[0.1,0.2,0.3],[0.4,0.5,0.6]])#共有两个输入样本
target = torch.tensor([[1,0,1],[0,1,1]])#每个样本的标签值都是多标签

#创建模型并计算
model = nn.Linear(input_dim = 3,hidden_dim=5)#此处随便写一个模型示意
model_out = model(input)

#计算损失函数值
loss = torch.nn.BCEWithLogitsLoss(model_out,target)

参考:
https://www.jianshu.com/p/ac3bec3dde3e
https://gombru.github.io/2018/05/23/cross_entropy_loss/

多标签任务评价指标

在二分类、多分类任务中评价指标使用sklearn中的classification_report,可以直观输出每个类别的准确率、召回率和F1值,如下图:
在这里插入图片描述
在多标签分类中仍然可以使用sklearn中的classification_report,结果如下图所示
在这里插入图片描述
但是在多标签分类任务的输出结果中增加了“micro avg”和“samples avg”两项指标值,减少了“accuracy”指标值,个人感觉“samples avg”指标值比较具有参考价值,是站在样本角度衡量模型效果,例如样本A预测标签是[1,1,0,0,1],真实标签是[1,0,1,0,1],samples avg就是一种样本标签预测正确程度的衡量指标,当然也可以采用预测正确标签在总标签中的所占比例代表样本标签预测正确程度(个人认为),不知道samples avg是否还有其他意义。
参考:
https://blog.csdn.net/weixin_26731327/article/details/109122687
https://www.it1352.com/1586638.html

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值