OhemCELoss用于图像分割的OHEM交叉熵损失函数


前言

OhemCELoss(Online Hard Example Mining Cross Entropy Loss)是一种在深度学习模型训练中应对类别不平衡问题的损失函数,通过在线困难样本挖掘(Online Hard Example Mining,OHEM)的策略来关注那些难以分类的样本,以增强模型对困难样本的学习效果。


一、原理

以像素点为计算单位,计算每个像素点的交叉熵损失,将所有像素的损失由大到小进行排序,根据设置的loss阈值(一般为0.7)取损失值大于阈值的loss作为困难样本, 因为有可能出现没有损失值大于阈值的情况,所以还要设置一个min_kept最少保留数,最少保留数中的最小值和设置的阈值数谁小就设谁为真正的阈值,大于这个值的所有loss才参与之后的训练。

二、步骤

  1. 计算交叉熵损失:对每个像素,计算模型预测值与真实标签的交叉熵损失函数
  2. 将一个batch的损失展平,并由大到小进行排序
  3. 如果最少保留数中的最小值大于阈值,则大于阈值的所有loss都参与计算
  4. 如果最少保留数中的最小值小于阈值,则取前最少保留数个loss都参与计算,这样既保证每个batch都至少有最少保留数个loss参与计算
  5. 最后,计算所有困难样本的平均损失。

三、代码

class OhemCELoss(nn.Module):
    """
    Online hard example mining cross-entropy loss:在线难样本挖掘
    if loss[self.n_min] > self.thresh: 最少考虑 n_min 个损失最大的 pixel,
    如果前 n_min 个损失中最小的那个的损失仍然大于设定的阈值,
    那么取实际所有大于该阈值的元素计算损失:loss=loss[loss>thresh]。
    否则,计算前 n_min 个损失:loss = loss[:self.n_min]
    """
    def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
        super(OhemCELoss, self).__init__()
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()     # 将输入的概率 转换为loss值
        self.n_min = n_min
        self.ignore_lb = ignore_lb
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')   #交叉熵损失
 
    def forward(self, logits, labels):
        N, C, H, W = logits.size()
        loss = self.criteria(logits, labels).view(-1)
        loss, _ = torch.sort(loss, descending=True)    #由大到小排序
        if loss[self.n_min] > self.thresh:       
            loss = loss[loss>self.thresh]
        else:
            loss = loss[:self.n_min]
        return torch.mean(loss)

总结

在图像分割问题中,以每个 pixel 的损失为最小单元,而不是 batch 中每张图片。因此排序时需要把 batch 中所有 pixel 拉成一个长向量,再取其中大于阈值的 pixel 作为 hard example。同时,n_min 的设置保证了每个 batch 中都有至少 n_min 个pixel 参与训练,从而一定程度巩固了训练结果,让前向传播不至于空耗。

Reference

https://www.jianshu.com/p/24376b18e5c7

  • 6
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值