编写对每个类别/实例加权的自定义loss

最近需要一种自定义loss,可以对每个实例的loss进行不同的加权。在网上找到的代码,没有我想要的,因此首先对torch的loss进行了研究。

torch的loss有包装在nn.Module之中的,有在nn.functional之中的,两种的区别就是前者需要torch对参数进行维护,如果没有parameter不用算梯度的话,就是多占了几个参数的区别而已。torch本身的nn.BCEloss就是调用了一个functional叫binary_cross_entropy,代码非常简单。

class BCELoss(_WeightedLoss):
    __constants__ = ['reduction', 'weight']

    def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
        super(BCELoss, self).__init__(weight, size_average, reduce, reduction)

    def forward(self, input, target):
        return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)

而torch之中是有两个loss的大类的,分别是_Loss和_WeightedLoss

class _Loss(Module):
    def __init__(self, size_average=None, reduce=None, reduction='mean'):
        super(_Loss, self).__init__()
        if size_average is not None or reduce is not None:
            self.reduction = _Reduction.legacy_get_string(size_average, reduce)
        else:
            self.reduction = reduction


class _WeightedLoss(_Loss):
    def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
        super(_WeightedLoss, self).__init__(size_average, reduce, reduction)
        self.register_buffer('weight', weight)

其中的区别仅仅在于一个weight,这个weight挺鸡肋的,因为他是不同batch之间进行加权。引用torch官方的说明:

在计算loss的时候,对于不同的batch进行了加权。但是一般人谁用得到这个功能……

图中的yn和xn就是每个batch的数据,在我这里,由于我做多标签分类(10标签),就是每个2维的图片(0-1之间的数值),大小是[157,10]。

由于torch使用了矩阵运算,因此就靠一个式子就能计算一张图片的loss。得到的结果是一个[batch,157,10]的结果,然后使用reduction参数,如果是mean就求均值,使用sum就求和,默认求均值。这里的均值不是每一个图片的loss均值,而是每一个点的loss均值,这解释了为什么loss都那么小。

因此,我只需要在求均值之前,不让其求均值,直接得到所有点的loss,然后乘以我的weight矩阵,然后自己再手动求均值就可以了。代码和测试结果如下:

import torch.nn as nn
import torch
import torch.nn.functional as F

class InstanceWeightedBCELoss(nn.Module):
    def __init__(self):

        super().__init__()

    def forward(self, input, target, instance_weight):
        unweightedloss = F.binary_cross_entropy(input, target,  reduction="none")
        weightedloss = unweightedloss * instance_weight
        loss = torch.mean(weightedloss)
        return loss

# test code
Loss = InstanceWeightedBCELoss()
a = torch.rand((1,157,10))
b = torch.rand((1,157,10))
weight = torch.rand(1,157,10)/5+0.8
print("weight",weight)
loss = Loss(a,b,weight)
print("loss",loss)

#result:
"""
weight tensor([[[0.8668, 0.8081, 0.9909,  ..., 0.9227, 0.9832, 0.8670],
         [0.8636, 0.8764, 0.9606,  ..., 0.8045, 0.9127, 0.8106],
         [0.8062, 0.8840, 0.8495,  ..., 0.9784, 0.8501, 0.9222],
         ...,
         [0.9620, 0.8217, 0.9606,  ..., 0.9674, 0.9211, 0.8560],
         [0.9315, 0.8104, 0.9081,  ..., 0.9596, 0.9536, 0.9976],
         [0.9907, 0.8936, 0.9285,  ..., 0.8556, 0.9208, 0.8183]]])
loss tensor(0.9109)
"""

有个实例加权的loss,那么class加权的loss就很容易了。由于我这个class在最后一维,因此由于tensor的广播机制,直接放入10维的tensor即可,代码和结果如下:

import torch.nn as nn
import torch
import torch.nn.functional as F

class ClassWeightedBCELoss(nn.Module):
    def __init__(self):

        super().__init__()

    def forward(self, input, target, class_weight):
        unweightedloss = F.binary_cross_entropy(input, target,  reduction="none")
        weightedloss = unweightedloss * class_weight
        loss = torch.mean(weightedloss)
        return loss


# test code
Loss = ClassWeightedBCELoss()
a = torch.rand((1,157,10))
b = torch.rand((1,157,10))
weight = torch.rand(10)/5+0.8
print("weight",weight)
loss = Loss(a,b,weight)
print("loss",loss)

#result:
"""
weight tensor([0.8591, 0.9359, 0.8241, 0.9032, 0.9452, 0.8897, 0.9429, 0.9727, 0.9250,
        0.8846])
loss tensor(0.9067)
"""

代码其实很好懂,但是我很纳闷为什么torch官方没有这个自带功能,网上也没有找到很好的代码。

谢谢你的阅读,如果有什么问题和想要讨论的欢迎留言。

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Qt/C++是一种开发桌面应用程序的跨平台开发框架,它提供了丰富的工具和类库,能够方便快捷地编写定义控件源码。 首先,我们需要创建一个继承自QWidget或QFrame的类来实现自定义控件。在这个类中,我们可以重载一些事件处理函数来实现控件的特定功能,比如绘制事件函数paintEvent()、鼠标事件函数mousePressEvent()等等。通过这些函数,我们可以控制控件的外观、响应用户输入等。 在实现自定义控件的外观时,可以利用Qt提供的各种绘图工具和API。例如,可以使用QPainter类来绘制各种形状、图像、文字等,还可以使用QPen和QBrush类来设置绘制的样式和颜色。通过这些工具,我们可以实现各种个性化的外观效果,如圆角、渐变、阴影等。 对于自定义控件的功能实现,可以根据需求使用Qt提供的各种功能模块。比如,使用QTimer类实现定时器功能,使用QMediaPlayer类实现音视频播放功能等等。此外,Qt还提供了一系列的信号和槽机制,可以方便地实现控件之间的交互和通信。 在自定义控件的使用方面,可以通过在其他QWidget中使用该控件的对象的方式来使用它。将自定义控件放入项目中,然后在界面中添加该控件的实例对象,即可展示该控件,并与其交互。也可以通过在UI界面设计软件中将该控件拖拽到需要的位置上,然后使用信号槽机制来实现与其他控件的交互。 总之,Qt/C++编写定义控件源码需要熟悉Qt的基本概念和API,并结合自身的需求来设计和实现控件的外观和功能。通过合理的设计和编码,可以创建出各种各样的自定义控件,丰富应用程序的界面和功能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值