1. intro
Focal loss主要是为了解决样本不均衡问题,该损失函数降低了大量简单负样本在训练中所占的权重,也可理解为一种困难样本挖掘。
2. 基本原理
2.1 二分类损失函数
L = − y l o g y ′ − ( 1 − y ) log ( 1 − y ′ ) = { − log y ′ y = 1 − log ( 1 − y ′ ) , y = 0 \mathrm{L}=-\mathrm{ylogy}^{\prime}-(1-y) \log \left(1-y^{\prime}\right)=\left\{\begin{array}{ll} -\log y^{\prime} & y=1 \\ -\log \left(1-y^{\prime}\right), & y=0 \end{array}\right. L=−ylogy′−(1−y)log(1−y′)={−logy′−log(1−y′),y=1y=0
y ′ y^{\prime} y′ 是sigmoid 函数的输出, 值再 0-1 之间,。可见普通的交叉熵对于正样本而言,输出概率越大损失越小。对于负样本而言,输出概率越小则损失越小。此时的损失函数在大量简单样本的迭代过程中比较缓慢且可能无法优化至最优。
其实, sigmoid 函数输出 y ′ y^{\prime} y′ 在0-1 之间,令 y ′ > 0.5 y^{\prime} > 0.5 y′>0.5时属于正类,小于0.5时为负类,所以当样本属于正类时,其输出越大,即置信度越大,越可能是正类,所以损失函数越小,同理,对于负类输出概率越小,损失函数越小
2.2 Focal Loss
上文中二分类交叉熵损失为
C
E
(
p
,
y
)
=
{
−
log
(
p
)
if
y
=
1
−
log
(
1
−
p
)
otherwise
\mathrm{CE}(p, y)=\left\{\begin{array}{ll} -\log (p) & \text { if } y=1 \\ -\log (1-p) & \text { otherwise } \end{array}\right.
CE(p,y)={−log(p)−log(1−p) if y=1 otherwise
其中
p
∈
[
0
,
1
]
p\in[0, 1]
p∈[0,1]为模型预测正例概率值,令
p
t
=
{
p
if
y
=
1
1
−
p
otherwise
p_{\mathrm{t}}=\left\{\begin{array}{ll} p & \text { if } y=1 \\ 1-p & \text { otherwise } \end{array}\right.
pt={p1−p if y=1 otherwise
所以:
C
E
(
p
,
y
)
=
C
E
(
p
t
)
=
−
log
p
t
C E(p, y)=C E\left(p_{t}\right)=-\log p_{t}
CE(p,y)=CE(pt)=−logpt
Focal Loss 在交叉熵损失上增加一个调节因子
(
1
−
p
t
)
γ
\left(1-p_{t}\right)^{\gamma}
(1−pt)γ, FL的定义如下:
F
L
(
p
t
)
=
−
(
1
−
p
t
)
γ
log
p
t
F L\left(p_{t}\right)=-\left(1-p_{t}\right)^{\gamma} \log p_{t}
FL(pt)=−(1−pt)γlogpt
当
p
t
p_t
pt很小时, 调节因子值很接近1, loss不受影响, 当
p
t
p_t
pt趋于1时, 调节因子接近0, 这样已经能正确分类的简单样例 loss 大大降低。超参数
γ
\gamma
γ 为0时,FL等价于CE,论文中发现取2
时是最好的,此时若一个样本的
p
t
p_t
pt 为0.9,其对应的CE loss是FL的100倍,可见FL相比CE可以大大降低简单例子的loss,使模型训练更关注于难例。
举例:
例如gamma为2,对于正类样本而言,预测结果为0.95肯定是简单样本,所以(1-0.95)的
γ
\gamma
γ 次方就会很小,这时损失函数值就变得更小。而预测概率为0.3的样本其损失相对很大。对于负类样本而言同样,预测0.1的结果应当远比预测0.7的样本损失值要小得多。对于预测概率为0.5时,损失只减少了0.25倍,所以更加关注于这种难以区分的样本。这样减少了简单样本的影响,大量预测概率很小的样本叠加起来后的效应才可能比较有效
此外,加入平衡因子alpha,用来平衡正负样本本身的比例不均:文中 α \alpha α 取0.25,即正样本要比负样本占比小,这是因为负例易分。
L f l = { − α ( 1 − p ) γ log p y = 1 − ( 1 − α ) p γ log ( 1 − p ) , y = 0 L_{f l}=\left\{\begin{array}{ll} -\alpha\left(1-p\right)^{\gamma} \log p & y=1 \\ -(1-\alpha) p^{ \gamma} \log \left(1-p\right), & y=0 \end{array}\right. Lfl={−α(1−p)γlogp−(1−α)pγlog(1−p),y=1y=0
添加 α \alpha α可以平衡正负样本的重要性,但是无法解决简单与困难样本的问题。
γ \gamma γ 调节简单样本权重降低的速率,当 γ \gamma γ 为0时即为交叉熵损失函数,当 γ \gamma γ 增加时,调整因子的影响也在增加。实验发现 γ \gamma γ 为2是最优。
3. 程序实现
def sigmoid_focal_loss(
inputs: torch.Tensor,
targets: torch.Tensor,
alpha: float = -1,
gamma: float = 2,
reduction: str = "none",
) -> torch.Tensor:
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples. Default = -1 (no weighting).
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
reduction: 'none' | 'mean' | 'sum'
'none': No reduction will be applied to the output.
'mean': The output will be averaged.
'sum': The output will be summed.
Returns:
Loss tensor with the reduction option applied.
"""
p = torch.sigmoid(inputs)
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = p * targets + (1 - p) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
if reduction == "mean":
loss = loss.mean()
elif reduction == "sum":
loss = loss.sum()
return loss