前言:最近在解决图像类别不平衡的问题,之前介绍了DiceLoss,试了代码虽然又改善但还没解决问题。我要处理图像样本类别属于极度不均衡,了解到FocalLoss也能解决这个问题,于是就想写这篇文章作为记录。
Focal Loss
介绍
解决什么问题:Focal Loss解决的是深度学习遇到类别不平衡的问题,直接用交叉熵损失函数计算损失函数,会使得最终结果偏向于常见类别。
如何解决这个问题:Focal Loss在交叉熵函数的基础上引入了超参数,增大类别少的样本的权重,以及调整易分类样本和困难样本之间的权重关系。
原理和公式
Focal Loss其实是在交叉熵损失函数(Cross Entropy Loss)上改进过来的。
交叉熵损失函数(Cross Entropy Loss):
H
(
y
,
y
^
)
=
−
1
N
∑
i
=
1
N
[
y
i
l
o
g
(
y
^
i
)
+
(
1
−
y
i
)
l
o
g
(
1
−
y
^
i
)
]
H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N[y_ilog(\widehat{y}_i)+(1-y_i)log(1-\widehat{y}_i)]
H(y,y
)=−N1i=1∑N[yilog(y
i)+(1−yi)log(1−y
i)]
这是一个二分类的CE公式,其中y是真实标签,
y
^
\widehat{y}
y
是预测值,N是样本的数量。
原理上,每个样本都会计算一个损失,然后对所有样本的损失求平均。
对于图像来说,这里的N可以看作是图像像素点的个数,
y
^
\widehat{y}
y
是预测每个像素点的值,y是每个像素点标签的值,一张图像的交叉熵其实就是计算每个像素点预测值和标签插值的平均。
这个CE公式也可以写成如下形式:
C
E
(
p
t
)
=
−
l
o
g
(
p
t
)
CE(pt)=-log(pt)
CE(pt)=−log(pt)
p
t
=
{
p
,
y
=
1
1
−
p
,
o
t
h
e
r
w
i
s
e
p_t= \begin{cases} \ p, & y=1\\ \ 1-p, & otherwise \end{cases}
pt={ p, 1−p,y=1otherwise
p
t
p_t
pt表示预测值和真实值之间的差。
Focal Loss公式:
在CE的基础上引入了超参数
γ
\gamma
γ和
α
\alpha
α,每个样本的损失构成了如下公式:
F
L
(
p
t
)
=
−
α
(
1
−
p
t
)
γ
l
o
g
(
p
t
)
=
α
(
1
−
p
t
)
γ
C
E
(
p
t
)
FL(p_t)=-\alpha(1-p_t)^\gamma log(p_t) =\alpha(1-p_t)^\gamma CE(pt)
FL(pt)=−α(1−pt)γlog(pt)=α(1−pt)γCE(pt)
其中
p
t
p_t
pt是该样本某个类别的预测值,Focal Loss类别一般采用one-hot编码;
α
\alpha
α是给不同类别样本加的权重,对于正样本比较少,就可以加大权重;
γ
\gamma
γ的作用在于如果当前样本预测值
p
t
p_t
pt比较大,就是易分类样本,就会使得
(
1
−
p
t
)
γ
(1-p_t)^\gamma
(1−pt)γ减小。
其实也就相当于计算每个样本交叉熵前面加多了两个权重,一个是类别权重,一个是样本难易权重。类别权重可以更重视类别占比小的;样本难以权重可以更加关注困难样本。
所以实际上Focal Loss是解决了两个问题:样本不均+难易样本。
γ \gamma γ和 α \alpha α如何确定
在Focal Loss论文中,作者通过搜索一个范围来确定两个参数的最优解,最后给出的结果是 γ = 2 \gamma=2 γ=2和 α = 0.25 \alpha=0.25 α=0.25。在该论文任务中,正样本是大大少于负样本的,而正样本参数 α = 0.25 \alpha=0.25 α=0.25,负样本参数 α = 0.75 \alpha=0.75 α=0.75,非常反直觉。经过 ( 1 − p t ) γ (1-p_t)^\gamma (1−pt)γ和 p t γ {p_t}^\gamma ptγ后,正负样本之间的形式会逆转,还要通过 α \alpha α给正样本降权。
所以 γ \gamma γ和 α \alpha α的确定更多还是实验经验的结果,没有什么理论上的方法。
代码
class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
self.cross_entropy_loss = CrossEntropyLoss2d()
def forward(self, inputs, targets):
# CE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')# 要求inputs和targets张量形状一样
CE_loss = self.cross_entropy_loss(inputs, targets)# inputs可以是NxCxHxW,targets可以是NxHxW,会自动对其张量
pt = torch.exp(-CE_loss) # 预测正确的概率
F_loss = self.alpha * (1-pt)**self.gamma * CE_loss
if self.reduction == 'mean':
return torch.mean(F_loss)
elif self.reduction == 'sum':
return torch.sum(F_loss)
else:
return F_loss
代码理解上没什么难的,基本就是照着Focal Loss的公式照着写的。