【PyTorch】Balanced_CE_loss 实现

Pytorch中Balance binary cross entropy自定义实现

balance binary cross entropy损失函数在分割任务中很有用,因为分割任务会遇到正负样本不均的问题,甚至在边缘的分割任务重,样本不均衡达到了很高的比例。

故此,个人在基于分割任务中,自实现了该损失函数,亲测有效!

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss


def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
    # element-wise losses
    loss = F.cross_entropy(pred, label, reduction='none')
    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
    return loss


def _expand_binary_labels(labels, label_weights, label_channels):
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1

    if label_weights is None:
        bin_label_weights = None
    else:
        bin_label_weights = label_weights.view(-1, 1).expand(
            label_weights.size(0), label_channels)

    return bin_labels, bin_label_weights


def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None):
    
    if pred.dim() != label.dim():
        label, weight = _expand_binary_labels(label, weight, pred.size(-1))

    # weighted element-wise losses
    if weight is not None:
        weight = weight.float()
    loss = F.binary_cross_entropy_with_logits(pred, label.float(), weight, reduction='none')
    # do the reduction for the weighted loss
    loss = weight_reduce_loss(loss, reduction=reduction, avg_factor=avg_factor)

    return loss


def balanced_mask_cross_entropy(pred, label, mask=None, negative_ratio=3.0, eps=1e-10):
    positive = label.byte()
    negative = (1-label).byte()
    positive_count = int(positive.float().sum())
    negative_count = min(int(negative.float().sum()), int(positive_count * negative_ratio))
    loss = F.binary_cross_entropy(pred, label, reduction='none')[:,0,:,:]
    positive_loss = loss * positive.float()
    negative_loss = loss * negative.float()
    negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)

    balance_loss = (positive_loss.sum() + negative_loss.sum()) / (positive_count + negative_count + eps)
    return balance_loss



@LOSSES.register_module()
class BalancedCrossEntropyLoss(nn.Module):

    def __init__(self,
                 negative_ratio=3.0,
                 eps=1e-10,
                 loss_weight=1.0):
        super(BalancedCrossEntropyLoss, self).__init__()
        self.negative_ratio = negative_ratio
        self.eps = eps
        self.loss_weight = loss_weight
        self.cls_criterion = balanced_mask_cross_entropy

    def forward(self,
                pred,
                label,
                mask=None,
                **kwargs):
        
        loss_cls = self.loss_weight * self.cls_criterion(
                pred, label, mask=None, negative_ratio=self.negative_ratio, eps=self.eps, **kwargs
            )

        return loss_cls

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

libo-coder

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值