Multi-label classification: the loss keeps increasing

This post looks at a loss that keeps increasing in a multi-label classification task, in both PyTorch and TensorFlow. Labels have shape [batch_size, num_class]; for example, one sample may belong to 3 of 10 classes. A softmax-based cross-entropy loss does not fit this setup, so a sigmoid-based loss is used instead.

label has shape [batch_size, num_class]
logits has shape [batch_size, num_class]

Each label is a multi-hot vector such as [0,0,1,0,0,0,1,0,1,0], meaning 3 of the 10 classes are correct.
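As a minimal sketch of how such labels can be built, the helper below turns a list of class indices into a multi-hot vector. The name `to_multi_hot` and the sample batch are hypothetical and only serve as illustration.

```python
def to_multi_hot(class_indices, num_class):
    """Return a list of length num_class with 1 at every index in class_indices."""
    vec = [0] * num_class
    for idx in class_indices:
        vec[idx] = 1
    return vec


# Hypothetical batch: each entry lists the classes one sample belongs to.
batch = [[2, 6, 8], [0], [1, 9]]
labels = [to_multi_hot(ids, 10) for ids in batch]  # shape [batch_size, num_class]
```

The first row of `labels` matches the example vector above; stacking the rows gives the [batch_size, num_class] layout.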

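Do not use tf.nn.softmax_cross_entropy_with_logits: softmax treats the classes as mutually exclusive and expects each label row to be a probability distribution that sums to 1.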
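The following pure-Python sketch illustrates why softmax is a poor fit here. Softmax probabilities sum to 1, so with three correct classes each can receive at most about 1/3, and the cross-entropy against a multi-hot label cannot fall below 3·log 3 ≈ 3.30. Independent sigmoids have no such floor. The helper functions and the logit values are assumptions chosen for illustration, not library code.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


label = [0, 0, 1, 0, 0, 0, 1, 0, 1, 0]
# Logits that strongly favor exactly the three correct classes.
logits = [20.0 if t == 1 else -20.0 for t in label]

# Softmax cross-entropy: -sum(t * log p), with p forced to sum to 1.
soft_probs = softmax(logits)
softmax_ce = -sum(t * math.log(p) for t, p in zip(label, soft_probs))

# Sigmoid cross-entropy: one independent binary decision per class.
sig_probs = [sigmoid(x) for x in logits]
sigmoid_ce = -sum(
    t * math.log(p) + (1 - t) * math.log(1 - p)
    for t, p in zip(label, sig_probs)
)
```

Even with near-perfect logits, `softmax_ce` stays at its floor while `sigmoid_ce` is close to zero.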

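In PyTorch, use torch.nn.BCELoss on sigmoid outputs, or torch.nn.BCEWithLogitsLoss directly on the raw logits.
In TensorFlow, use tf.losses.sigmoid_cross_entropy (TensorFlow 1.x), or tf.nn.sigmoid_cross_entropy_with_logits, which returns the per-element loss.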
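A short PyTorch sketch of the two options, assuming random logits and a fixed multi-hot target (both made up for illustration): `BCEWithLogitsLoss` takes raw logits, while `BCELoss` expects probabilities, so a sigmoid has to be applied first. The two should agree up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, num_class = 4, 10
logits = torch.randn(batch_size, num_class)    # raw model outputs (assumed)
labels = torch.zeros(batch_size, num_class)    # targets must be float
labels[:, [2, 6, 8]] = 1.0                     # 3 of 10 classes are correct

# Option 1: pass raw logits; the sigmoid is applied inside the loss.
loss_from_logits = nn.BCEWithLogitsLoss()(logits, labels)

# Option 2: apply the sigmoid yourself, then use BCELoss on probabilities.
loss_from_probs = nn.BCELoss()(torch.sigmoid(logits), labels)
```

`BCEWithLogitsLoss` is usually the safer choice because it combines the sigmoid and the log in a numerically stable way.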

Multi-label problems arise in various domains such as multi-topic document categorization, protein function prediction, and automatic image annotation. One natural way to deal with such problems is to construct a binary classifier for each label, resulting in a set of independent binary classification problems. Since multiple labels share the same input space, and the semantics conveyed by different labels are usually correlated, it is essential to exploit the correlation information contained in different labels. In this paper, we consider a general framework for extracting shared structures in multi-label classification. In this framework, a common subspace is assumed to be shared among multiple labels. We show that the optimal solution to the proposed formulation can be obtained by solving a generalized eigenvalue problem, though the problem is nonconvex. For high-dimensional problems, direct computation of the solution is expensive, and we develop an efficient algorithm for this case. One appealing feature of the proposed framework is that it includes several well-known algorithms as special cases, thus elucidating their intrinsic relationships. We further show that the proposed framework can be extended to the kernel-induced feature space. We have conducted extensive experiments on multi-topic web page categorization and automatic gene expression pattern image annotation tasks, and results demonstrate the effectiveness of the proposed formulation in comparison with several representative algorithms.
Multi-label focal dice loss is a loss function for multi-label classification that combines the characteristics of focal loss and dice loss. A sketch of an implementation follows; it applies a per-class sigmoid rather than a softmax, since each label is an independent binary decision:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiLabelFocalDiceLoss(nn.Module):
    def __init__(self, gamma=2, alpha=None, size_average=True):
        super().__init__()
        self.gamma = gamma
        # alpha: weight for positive labels, either a scalar or a per-class list
        self.alpha = None if alpha is None else torch.as_tensor(alpha, dtype=torch.float)
        self.size_average = size_average

    def forward(self, input, target):
        # input: raw logits; target: multi-hot labels with the same shape as input
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)    # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                           # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))      # N,H*W,C => N*H*W,C
        if target.dim() > 2:
            target = target.view(target.size(0), target.size(1), -1)  # N,C,H,W => N,C,H*W
            target = target.transpose(1, 2)                           # N,C,H*W => N,H*W,C
            target = target.contiguous().view(-1, target.size(2))     # N,H*W,C => N*H*W,C
        target = target.float()

        # focal loss, computed per label from the element-wise sigmoid cross-entropy
        bce = F.binary_cross_entropy_with_logits(input, target, reduction="none")
        pt = torch.exp(-bce)  # probability assigned to the true value of each label
        loss = (1 - pt) ** self.gamma * bce
        if self.alpha is not None:
            alpha = self.alpha.to(input)  # match dtype and device of the logits
            at = alpha * target + (1 - alpha) * (1 - target)
            loss = at * loss

        # dice loss, computed over all labels in the batch
        smooth = 1
        iflat = torch.sigmoid(input).reshape(-1)
        tflat = target.reshape(-1)
        intersection = (iflat * tflat).sum()
        A_sum = torch.sum(iflat * iflat)
        B_sum = torch.sum(tflat * tflat)
        dice = (2. * intersection + smooth) / (A_sum + B_sum + smooth)

        loss = loss + (1 - dice)
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()
```

Both the focal loss and the dice loss are computed in the forward function. It first flattens the input and target so that the last dimension is the class dimension, then computes the focal loss per label and the dice loss over the whole batch, and adds the two to form the final loss. The shape comments in the code describe each flattening step.