一、Focal Loss公式介绍
1、控制正负样本的权重
2、控制容易分类和难分类样本的权重
论文:
公式:
我们可以利用如下Pt简化交叉熵loss。
此时:
代码:
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
-想要降低负样本的影响,可以在常规的损失函数前增加一个系数αt。与Pt类似,当label=1的时候,αt=α;当label=otherwise的时候,αt=1 - α,a的范围也是0到1。此时我们便可以通过设置α实现控制正负样本对loss的贡献。
公式:
其中:
分解开就是:
样本属于某个类,且预测结果中该类的概率越大,其越容易分类 ,在二分类问题中,正样本的标签为1,负样本的标签为0,p代表样本为1类的概率。
对于正样本而言,1-p的值越大,样本越难分类。
对于负样本而言,p的值越大,样本越难分类。
Pt的定义如下
所以利用1-Pt就可以计算出每个样本属于容易分类或者难分类。
具体实现方式如下。
通过如下公式就可以实现控制正负样本的权重和控制容易分类和难分类样本的权重。
分解开就是:
二、Focal Loss代码实现
import torch
import torch.nn as nn
import torch.functional as F
class WeightedFocalLoss(nn.Module):
"Non weighted version of Focal Loss"
def __init__(self, alpha=.25, gamma=2):
super(WeightedFocalLoss, self).__init__()
# --------------#
# 平衡正负样本系数
# --------------#
self.alpha = torch.tensor([alpha, 1-alpha]).cuda()
# --------------#
# 平衡难易样本系数
# --------------#
self.gamma = gamma
def forward(self, inputs, targets):
# --------------#
# 分类交叉熵损失
# --------------#
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
# --------------#
# 标签GT
# --------------#
targets = targets.type(torch.long)
# --------------#
# 计算at
# --------------#
at = self.alpha.gather(0, targets.data.view(-1))
# --------------#
# 计算pt: BEC_loss = -log(pt) --> pt = torch.exp(-BCE_loss)
# --------------#
pt = torch.exp(-BCE_loss)
# --------------#
# 计算Focal Loss
# --------------#
F_loss = at*(1-pt)**self.gamma * BCE_loss
return F_loss.mean()
说明