处理数据长尾分布:Focal loss和E focal loss(加权)【即插即用】

文章介绍了FocalLoss,一种针对多分类问题的损失函数,特别考虑了样本不平衡和困难样本的处理。它包括MultiClassFocalLossWithAlpha类,具有可调整的类别权重和gamma参数,以及EqualizedFocalLossCW类,引入了梯度加权机制以增强模型性能。
摘要由CSDN通过智能技术生成

Focal loss

class MultiClassFocalLossWithAlpha(nn.Module):
    def __init__(self, alpha=[1.-8434/28760, 1.-8069/28760, 1.-578/28760, 1.-3903/28760, 1.-7806/28760], gamma=2, reduction='mean'):
        """
        :param alpha: 权重系数列表,三分类中第0类权重0.2,第1类权重0.3,第2类权重0.5
        :param gamma: 困难样本挖掘的gamma
        :param reduction:
        """
        super(MultiClassFocalLossWithAlpha, self).__init__()
        self.alpha = torch.tensor(alpha)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, pred, target):
        alpha = torch.index_select(self.alpha.to(target.device), 0, target.view(-1))  # 为当前batch内的样本,逐个分配类别权重,shape=(bs), 一维向量
        log_softmax = torch.log_softmax(pred, dim=1) # 对模型裸输出做softmax再取log, shape=(bs, 3)
        logpt = torch.gather(log_softmax, dim=1, index=target.view(-1, 1))  # 取出每个样本在类别标签位置的log_softmax值, shape=(bs, 1)
        logpt = logpt.view(-1)  # 降维,shape=(bs)
        ce_loss = -logpt  # 对log_softmax再取负,就是交叉熵了
        pt = torch.exp(logpt)  #对log_softmax取exp,把log消了,就是每个样本在类别标签位置的softmax值了,shape=(bs)
        focal_loss = alpha * (1 - pt) ** self.gamma * ce_loss  # 根据公式计算focal loss,得到每个样本的loss值,shape=(bs)
        if self.reduction == "mean":
            return torch.mean(focal_loss)
        if self.reduction == "sum":
            return torch.sum(focal_loss)
        return focal_loss

E focal loss

class EqualizedFocalLossCW:
    def __init__(self, alpha=[1.-8434/28760, 1.-8069/28760, 1.-578/28760, 1.-3903/28760, 1.-7806/28760], gamma_b=2, scale_factor=8, reduction="mean"):
        self.gamma_b = gamma_b
        self.scale_factor = scale_factor
        self.reduction = reduction
        self.alpha = alpha

    def __call__(self, logits, targets):
        ce_loss = F.cross_entropy(logits, targets, reduction="none")
        outputs = F.cross_entropy(logits, targets)  # 求导使用,不能带 reduction 参数
        log_pt = -ce_loss
        pt = torch.exp(log_pt)  # softmax 函数打分

        targets = targets.view(-1, 1)  # 多加一个维度,为使用 gather 函数做准备
        grad_i = torch.autograd.grad(outputs=-outputs, inputs=logits)[0]  # 求导
        grad_i = grad_i.gather(1, targets)  # 每个类对应的梯度
        pos_grad_i = F.relu(grad_i).sum()
        neg_grad_i = F.relu(-grad_i).sum()
        neg_grad_i += 1e-9  # 防止除数为0
        grad_i = pos_grad_i / neg_grad_i
        grad_i = torch.clamp(grad_i, min=0, max=1)  # 裁剪梯度

        dy_gamma = self.gamma_b + self.scale_factor * (1 - grad_i)
        dy_gamma = dy_gamma.view(-1)  # 去掉多的一个维度
        # weighting factor
        wf = dy_gamma / self.gamma_b
        weights = wf * (1 - pt) ** dy_gamma

        if self.alpha is not None:
            # Apply class-specific weights
            alpha_tensor = torch.tensor(self.alpha, device=logits.device)
            alpha_weights = alpha_tensor[targets.view(-1)]
            weights = weights * alpha_weights

        efl = weights * ce_loss

        if self.reduction == "sum":
            efl = efl.sum()
        elif self.reduction == "mean":
            efl = efl.mean()
        else:
            raise ValueError(f"reduction '{self.reduction}' is not valid")
        return efl

多类别焦点损失(multi-class focal loss)是一种用于解决多类别不平衡问题的损失函数。它是在类别级别上对长尾数据进行平衡,并挖掘难分类数据的一种方法。与传统的交叉熵损失函数相比,多类别焦点损失更加关注难以分类的样本,通过对误分类样本施加更大的惩罚,以提高模型对于难分类样本的学习能力。 多类别焦点损失的核心思想是引入焦点因子(focal factor),用于调整不同类别样本的权重。焦点因子可以根据样本的难易程度进行动态调整,对于容易分类的样本,焦点因子较小,对于难分类的样本,焦点因子较大。这样可以使模型更加关注难以分类的样本,提高模型对于少数类别的学习效果。 多类别焦点损失的具体计算方式可以参考类别级别的焦点损失(focal loss),通过对每个类别的损失进行加权实现对尾部类别上过量负样本梯度的抑制,并对误分类样本进行惩罚。同时,可以结合其他的损失函数,如GIoU损失或Triple Loss,来进一步提升模型的性能。 总之,多类别焦点损失是一种用于解决多类别不平衡问题的损失函数,通过对难以分类的样本进行加权和惩罚,提高模型对于少数类别的学习效果。 #### 引用[.reference_title] - *1* *3* [多标签分类问题的损失函数与长尾问题](https://blog.csdn.net/bigtailhao/article/details/121015794)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down28v1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [YoLo系列、SoftNMS、FasterRCNN、DETR系列、GIoU、Dice、GLIP、Kosmos系列、Segment Anything](https://blog.csdn.net/taoqick/article/details/131842147)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down28v1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值