图像分割损失函数OhemCELoss

OhemCELoss函数简介

OhemCELoss函数( Online hard example mining cross-entropy loss 的缩写)

分割任务中的OhemCELoss函数:其实就是分类任务的交叉熵函数—>每个像素计算分类交叉熵---->根据loss选取难样本,一步一步扩展得到。

在语义分割网络中常用的损失函数,这里大概记录几个需要留意的点:

1)计算交叉熵损失时,是以一个像素点为计算单位,计算出每个像素点的交叉熵分类损失。

2)ohem难样本挖掘时,根据给定的阈值选取前n_min个像素点的loss值。

源码:

使用pytorch框架OhemCELoss函数的代码实现

class OhemCELoss(nn.Module):
    """
    Online hard example mining cross-entropy loss:在线难样本挖掘
    if loss[self.n_min] > self.thresh: 最少考虑 n_min 个损失最大的 pixel,
    如果前 n_min 个损失中最小的那个的损失仍然大于设定的阈值,
    那么取实际所有大于该阈值的元素计算损失:loss=loss[loss>thresh]。
    否则,计算前 n_min 个损失:loss = loss[:self.n_min]
    """
    def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
        super(OhemCELoss, self).__init__()
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()     # 将输入的概率 转换为loss值
        self.n_min = n_min
        self.ignore_lb = ignore_lb
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')   #交叉熵

    def forward(self, logits, labels):
        N, C, H, W = logits.size()
        loss = self.criteria(logits, labels).view(-1)
        loss, _ = torch.sort(loss, descending=True)     # 排序
        if loss[self.n_min] > self.thresh:       # 当loss大于阈值(由输入概率转换成loss阈值)的像素数量比n_min多时,取所以大于阈值的loss值
            loss = loss[loss>self.thresh]
        else:
            loss = loss[:self.n_min]
        return torch.mean(loss)

一篇讲述很详细的博客

详细的内容直接参见下面的博客就好,不必重复码字造车。

图像分割 OhemCELoss - 简书icon-default.png?t=M0H8https://www.jianshu.com/p/24376b18e5c7

  • 6
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值