Focal loss 中两个加权参数的原理和产生的影响

首先需要明确一个在损失函数中的加权细节:想要在损失函数中对样本进行加权,那么加权的思路应该要是逆向的。因为损失函数的优化目标是越小越好,所以你越想保护的部分应该给予小权重,使得这部分可以大。而越想惩罚的部分,应该给予大权重,这样强制让他们只能是小的。

 

Focal loss :FL(p_t)=\begin{cases} -\alpha (1-p)^{\gamma}log(p) & \text{ if } y=1 \\ -(1-\alpha)p^{\gamma}log(1-p) & \text{ if } otherwise \end{cases} 。里面最核心的两个参数 \alpha 和 \gamma。 

 

其中 \alpha 类似与class weight 给类别加权重。如果 y = 1 类样本个数大于 y = 0, 那么 \alpha 应该小于 0.5,保护样本少的类,而多惩罚样本多的类。结论是样本越不平衡,\alpha 应该越靠近 0 或者 1。

而 \gamma 的作用是竟然把难例分开,这个参数越大,导致的后果是预测的概率值越偏向于0~1的两端。具体推理如下图所示:

 

  • 8
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
以下是 PyTorch 实现 Focal Loss 的代码: ```python import torch import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): def __init__(self, gamma=2, alpha=None, size_average=True): super(FocalLoss, self).__init__() self.gamma = gamma self.alpha = alpha if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha]) if isinstance(alpha, list): self.alpha = torch.Tensor(alpha) self.size_average = size_average def forward(self, input, target): if input.dim() > 2: input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W input = input.transpose(1, 2) # N,C,H*W => N,H*W,C input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C target = target.view(-1, 1) logpt = F.log_softmax(input) logpt = logpt.gather(1, target) logpt = logpt.view(-1) pt = logpt.data.exp() if self.alpha is not None: if self.alpha.type() != input.data.type(): self.alpha = self.alpha.type_as(input.data) at = self.alpha.gather(0, target.data.view(-1)) logpt = logpt * at loss = -1 * (1 - pt) ** self.gamma * logpt if self.size_average: return loss.mean() else: return loss.sum() ``` 可以看到,该代码首先定义了一个 `FocalLoss` 类,该类继承自 PyTorch 的 `nn.Module` 类,因此我们可以直接使用该类来定义我们的 Focal Loss 模型。 在 `__init__()` 方法,我们定义了两个参数 `gamma` 和 `alpha`。其 `gamma` 的值默认为 2,即 Focal Loss 调节因子。`alpha` 表示每个类别的权重,如果 `alpha` 是一个浮点数,则表示正样本的权重,负样本的权重为 1 - `alpha`。如果 `alpha` 是一个列表,则它的长度应该等于类别数,每个元素表示每个类别的权重。 在 `forward()` 方法,我们首先将输入的 `input` 和 `target` 二者都展平成一维向量,然后计算损失函数。具体而言,我们首先对 `input` 进行 softmax 操作,然后取出对应类别的概率值 `pt`,接着根据 `alpha` 权重计算加权的对数概率值 `logpt`。最后根据 Focal Loss 的公式计算损失,并返回平均值或总和。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值