转:https://zhuanlan.zhihu.com/p/266023273
背景
Focal loss是最初由何恺明提出的,最初用于图像领域解决数据不平衡造成的模型性能问题。本文试图从交叉熵损失函数出发,分析数据不平衡问题,focal loss与交叉熵损失函数的对比,给出focal loss有效性的解释。
交叉熵损失函数
L o s s = L ( y , p ) = − y l o g ( p ) − ( 1 − y ) l o g ( 1 − p ) Loss=L(y, p)=-ylog(p)-(1-y)log(1-p) Loss=L(y,p)=−ylog(p)−(1−y)log(1−p)
p是预测的概率,y是标签,二分类中对应0,1
L c e ( y , p ) = { − l o g ( p ) , i f y = 1 − l o g ( 1 − p ) , i f y = 0 (1) L_{ce}(y,p)=\begin{cases} -log(p), \; if \; y=1 \\ -log(1-p), \; if \;y=0 \end{cases} \tag{1} Lce(y,p)={−log(p),ify=1−log(1−p),ify=0(1)
样本不平衡问题
对于所有样本,损失函数为:
L = 1 N ∑ i = 1 N l ( y , p ) L=\frac{1}{N}\sum^N_{i=1}l(y,p) L=N1i=1∑Nl(y,p)
对于二分类问题,损失函数可以写为:
L = 1 N ( ∑ y i = 1 m − l o g ( p ) + ∑ y i = 0 n − l o g ( 1 − p ) ) L = \frac{1}{N}(\sum^m_{y_i=1}-log(p)+\sum^n_{y_i=0}-log(1-p)) L=N1(yi=1∑m−log(p)+yi=0∑n−log(1−p))
其中m为正样本个数,n为负样本个数,N为样本总数,m+n=N。
当样本分布失衡时,在损失函数L的分布也会发生倾斜,如m<<n时,负样本就会在损失函数占据主导地位。由于损失函数的倾斜,模型训练过程中会倾向于样本多的类别,造成模型对少样本类别的性能较差。
平衡交叉熵函数(balanced cross entropy)
基于样本非平衡造成的损失函数倾斜,一个直观的做法就是在损失函数中添加权重因子,提高少数类别在损失函数中的权重,平衡损失函数的分布。如在上述二分类问题中,添加权重参数 α ∈ [ 0 , 1 ] \alpha \in [0,1] α∈[0,1]和 1 − α 1-\alpha 1−α
L = 1 N ( ∑ y i = 1 m − α l o g ( p ) + ∑ y i = 0 n − ( 1 − α ) l o g ( 1 − p ) ) L=\frac{1}{N}(\sum^m_{y_i=1} - \alpha log(p)+\sum^n_{y_i=0}-(1-\alpha)log(1-p)) L=N1(yi=1∑m−αlog(p)+yi=0∑n−(1−α)log(1−p))
其中 α 1 − α = n m \frac{\alpha}{1-\alpha}=\frac{n}{m} 1−αα=mn,即权重的大小根据正负样本的分布进行设置。
Focal loss
即权重的大小根据正负样本的分布进行设置。
focal loss也是针对样本不均衡问题,从loss角度提供的另外一种解决方法。
focal loss的具体形式为:
L f l = { − ( 1 − p ) γ l o g ( p ) , i f y = 1 − p γ l o g ( 1 − p ) , i f y = 0 (2) L_{fl}=\begin{cases} -(1-p)^{\gamma}log(p), \; if \;y=1 \\ -p^{\gamma} log(1-p), \; if \; y=0\end{cases} \tag{2} Lfl={−(1−p)γlog(p),ify=1−pγlog(1−p),ify=0(2)
令 p t = { p , i f y = 1 1 − p , o t h e r w i s e p_t=\begin{cases} p , \; if \; y=1 \\ 1-p, \; otherwise \end{cases} pt={p,ify=11−p,otherwise
将focal loss表达式(2)统一为一个表达式:
L f l = − ( 1 − p t ) γ l o g ( p t ) (3) L_{fl}=-(1-p_t)^{\gamma}log(p_t) \tag{3} Lfl=−(1−pt)γlog(pt)(3)
同理可将交叉熵表达式(1)统一为一个表达式:
L c e = − l o g ( p t ) (4) L_{ce}=-log(p_t) \tag{4} Lce=−log(pt)(4)
p t p_t pt反映了与ground truth即类别y的接近程度, p t p_t pt 越大说明越接近类别y,即分类越准确。 γ > 0 \gamma > 0 γ>0 为可调节因子。
对比表达式(3)和(4), focal loss相比交叉熵多了一个modulating factor即 ( 1 − p t ) γ (1-p_t)^{\gamma} (1−pt)γ 。对于分类准确的样本 p t p_t pt →1 ,modulating factor趋近于0。对于分类不准确的样本 1− p t p_t pt →1 ,modulating factor趋近于1。即相比交叉熵损失,focal loss对于分类不准确的样本,损失没有改变,对于分类准确的样本,损失会变小。 整体而言,相当于增加了分类不准确样本在损失函数中的权重。
p t p_t pt 也反应了分类的难易程度 , p t p_t pt 越大,说明分类的置信度越高,代表样本越易分; p t p_t pt 越小,分类的置信度越低,代表样本越难分。因此focal loss相当于增加了难分样本在损失函数的权重,使得损失函数倾向于难分的样本,有助于提高难分样本的准确度。focal loss与交叉熵的对比,可见下图:
focal loss vs balanced cross entropy
focal loss相比balanced cross entropy而言,二者都是试图解决样本不平衡带来的模型训练问题,后者从样本分布角度对损失函数添加权重因子,前者从样本分类难易程度出发,使loss聚焦于难分样本。
focal loss为什么有效
focal loss从样本难易分类角度出发,解决样本非平衡带来的模型训练问题。
相信很多人会在这里有一个疑问,样本难易分类角度怎么能够解决样本非平衡的问题,直觉上来讲样本非平衡造成的问题就是样本数少的类别分类难度较高。因此从样本难易分类角度出发,使得loss聚焦于难分样本,解决了样本少的类别分类准确率不高的问题,当然难分样本不限于样本少的类别,也就是focal loss不仅仅解决了样本非平衡的问题,同样有助于模型的整体性能提高。
要想使模型训练过程中聚焦难分类样本,仅仅使得Loss倾向于难分类样本还不够,因为训练过程中模型参数更新取决于Loss的梯度。
KaTeX parse error: \tag works only in display equations
如果Loss中难分类样本权重较高,但是难分类样本的Loss的梯度为0,难分类样本不会影响模型学习过程。
对于梯度问题,在focal loss的文章中,也给出了答案。如下图所示为focal loss的梯度示意图。其中 x t = y x x_t=yx xt=yx ,其中 y∈{−1,1} 为类别, p t = σ ( x t ) p_t=\sigma(x_t) pt=σ(xt) ,对于易分样本, x t x_t xt >0 ,即 p t p_t pt >0.5 。由图中可以看出,对于focal loss而言,在 x t x_t xt >0 时 导数很小,趋近于0。因此对于focal loss,导数中易是难分类样本占主导,因此学习过程更加聚焦正在难分类样本。
一点小思考
难分类样本与易分类样本其实是一个动态概念,也就是说 p t p_t pt 会随着训练过程而变化。原先易分类样本即 p t p_t pt 大的样本,可能随着训练过程变化为难训练样本即 p t p_t pt 小的样本。
上面讲到,由于Loss梯度中,难训练样本起主导作用,即参数的变化主要是朝着优化难训练样本的方向改变。当参数变化后,可能会使原先易训练的样本 p t p_t pt 发生变化,即可能变为难训练样本。当这种情况发生时,可能会造成模型收敛速度慢,正如苏剑林在他的文章中提到的那样。
为了防止难易样本的频繁变化,应当选取小的学习率。防止学习率过大,造成 w w w 变化较大从而引起 p t p_t pt 的巨大变化,造成难易样本的改变。