在人工智能算法的实际应用场景中,不可避免地会出现训练数据误标现象,即训练数据集上存在标签噪声。这会降低所训模型的泛化能力。尤其是对于深度神经网络这种描述能力极强的模型,标签噪声对推理精度的影响甚或是灾难性的。论文证实了一个简单的两层网络就能记住所有随机分配的标签。
本篇博文将介绍一种对标签噪声鲁棒的损失函数,即General Cross Entropy
(GCE
)。这种损失函数在2018年的NIPS
会议论文中被提出,其集成了Mean Absolute Error
(MAE
)损失函数的噪声鲁棒性,以及传统的Cross Entropy
损失函数的训练高效性。
博文主要内容如下:第一部分将论述MAE
为何是噪声鲁棒的。第二部分将介绍GCE
的具体表达;其被称为广义交叉熵的原因也将在这一部分揭示。第三部分给出了复现结果。最后将给出相关的参考链接。
一、MAE
为何是噪声鲁棒的
一句话解释,因为MAE
是一种Symmetric Loss
。那么原来的一个问题就变成了两个问题,(1) 什么是Symmetric Loss
;(2) 为什么Symmetric Loss
是噪声鲁棒的。
为了讨论第一个问题,我们考虑一个十分简单的情形:只包含一个训练样本的二分类问题。假设训练样本为
{
x
,
y
}
\{x, y\}
{x,y},其中
y
∈
{
0
,
1
}
y\in\{0, 1\}
y∈{0,1};
f
θ
(
x
)
f_\theta(x)
fθ(x)为模型输出,
θ
\theta
θ为待优化的参数,损失函数为
l
l
l。没有标签噪声的情况下,待优化的目标为
l
[
f
θ
(
x
)
,
y
]
l[f_\theta(x),y]
l[fθ(x),y]。考虑存在标签噪声的情形,那么样本
x
x
x有一定概率
ρ
\rho
ρ被误标为
1
−
y
1-y
1−y,那么实际上的优化目标是
(
1
−
ρ
)
⋅
l
[
f
θ
(
x
)
,
y
]
+
ρ
⋅
l
[
f
θ
(
x
)
,
1
−
y
]
(1-\rho)\cdot l[f_\theta(x),y]+\rho \cdot l[f_\theta(x),1-y]
(1−ρ)⋅l[fθ(x),y]+ρ⋅l[fθ(x),1−y]。
如果
arg min θ l [ f θ ( x ) , y ] = arg min θ { ( 1 − ρ ) ⋅ l [ f θ ( x ) , y ] + ρ ⋅ l [ f θ ( x ) , 1 − y ] } ( 1 ) \argmin_\theta l[f_\theta(x),y]=\argmin_\theta \{(1-\rho)\cdot l[f_\theta(x),y]+\rho \cdot l[f_\theta(x),1-y]\}\qquad\qquad (1) θargminl[fθ(x),y]=θargmin{(1−ρ)⋅l[fθ(x),y]+ρ⋅l[fθ(x),1−y]}(1)
那么意味着无论有无噪声,该优化问题都会得到同样的解。这时候损失函数 l l l就是噪声鲁棒的。
( 1 − ρ ) ⋅ l [ f θ ( x ) , y ] + ρ ⋅ l [ f θ ( x ) , 1 − y ] = ( 1 − 2 ρ ) ⋅ l [ f θ ( x ) , y ] + ρ ⋅ { l [ f θ ( x ) , y ] + l [ f θ ( x ) , 1 − y ] } ( 2 ) (1-\rho)\cdot l[f_\theta(x),y]+\rho \cdot l[f_\theta(x),1-y]=(1-2\rho)\cdot l[f_\theta(x),y]+\rho\cdot\{l[f_\theta(x),y]+l[f_\theta(x),1-y]\}\qquad\qquad(2) (1−ρ)⋅l[fθ(x),y]+ρ⋅l[fθ(x),1−y]=(1−2ρ)⋅l[fθ(x),y]+ρ⋅{l[fθ(x),y]+l[fθ(x),1−y]}(2)
整理后的第一项是无噪声情况下的优化目标的一个固定倍数。而第二项是当样本标签等概率取遍所有可能值时,所产生的损失值。
敲黑板,如果
l [ f θ ( x ) , y ] + l [ f θ ( x ) , 1 − y ] = C ( 3 ) l[f_\theta(x),y]+l[f_\theta(x),1-y]=C\qquad\qquad (3) l[fθ(x),y]+l[fθ(x),1−y]=C(3)
其中
C
C
C是常数时,那么
l
l
l就是Symmetric Loss
。这时候,(2)式右端的第二项对
θ
\theta
θ的优化不造成影响,所以(1)式就会成立,所以Symmetric Loss
是对噪声鲁棒的。这里的Symmetric
是一种轮换对称的含义,就是指当样本等概率取遍所有可能标签时,产生的损失值是常值。
对MAE
,考虑到
f
θ
(
x
)
∈
[
0
,
1
]
f_\theta(x)\in[0,1]
fθ(x)∈[0,1],那么
∣
y
−
f
θ
(
x
)
∣
+
∣
1
−
y
−
f
θ
(
x
)
∣
=
∣
0
−
f
θ
(
x
)
∣
+
∣
1
−
f
θ
(
x
)
∣
=
1
|y-f_\theta(x)|+|1-y-f_\theta(x)|=|0-f_\theta(x)|+|1-f_\theta(x)|=1
∣y−fθ(x)∣+∣1−y−fθ(x)∣=∣0−fθ(x)∣+∣1−fθ(x)∣=1
所以MAE
是Symmetric Loss
,所以是噪声鲁棒的。
更严谨的论述,可以参考论文Making Risk Minimization Tolerant to Label Noise。
二、GCE
在明了MAE
为什么对噪声鲁棒的机理后,我们再来看看Cross Entropy Loss
为什么是不鲁棒的。如果还是用一句话来解释,就是因为CE
是无界的。考虑(3)
式所表达的鲁棒性条件,各个损失项非负,且其和为定值,那么各个损失项必然是有界的。但在CE
中,假设
f
θ
(
x
)
f_\theta(x)
fθ(x)表示样本
x
x
x属于类别1
的概率,那么损失值是
−
log
f
θ
(
x
)
-\log f_\theta(x)
−logfθ(x)或
−
log
(
1
−
f
θ
(
x
)
)
-\log(1-f_\theta(x))
−log(1−fθ(x))。当
f
θ
(
x
)
f_\theta(x)
fθ(x)接近于0
或者1
时,损失值会非常大。这意味着模型会花费更多的功夫在这些样本上。
这一特性是个双刃剑。如果我们所有的训练样本的标注都是正确的,那么这一特性使得模型关注于那些容易被错判的样本,所以训练过程会十分高效。但如果训练样本存在标签噪声,这一特性将使得模型过度关注于误标的样本,以致最终会得到一个过拟合到误标样本的模型。
好,那我们直接使用MAE
不就行了么。Hmmm…实际上MAE
实战效果并不好。因为MAE
是平等地对待各个样本的,所以其收敛速度比较慢 (具体的解释可以参考GCE
论文)。
那么一个自然的想法就是能不能在MAE
的噪声鲁棒性和CE
的快速收敛性之间做一个折中。这就是GCE
的基础想法。具体做法是将CE
中的
−
log
(
f
θ
(
x
)
)
-\log(f_\theta(x))
−log(fθ(x))项,替换成一个指数项
1
−
f
θ
q
(
x
)
q
\frac{1-f_\theta^q(x)}{q}
q1−fθq(x),其中幂次
q
q
q为一个超参数,取值范围为
(
0
,
1
]
(0,1]
(0,1]。当
q
=
1
q=1
q=1时,该指数项就蜕变为MAE
的形式;当
q
→
0
q\rightarrow 0
q→0时,由洛必达法则,该指数项将蜕变为CE
的形式。所以
q
q
q控制着在MAE
和CE
之间的折中程度。原论文中仅给出了
q
q
q的经验值,并未给出具体的选取方法。
为进一步抑制误标数据的影响,原论文还给出了一种加强版的损失函数,该函数融合了样本选择的思想。其核心想法是,当训练到一定程度后,模型已经获取了关于数据的正确分类的模式信息,此时,如果模型在样本上输出的概率值比较小,那么这些样本很可能是误标样本。因此,若将这些样本从训练数据中剔除,那么接下来的训练可能会更准确。具体做法是:假设在训练过程中validation
数据集上准确率最好的模型为
M
b
e
s
t
\mathcal{M}_{best}
Mbest,则在模型训练一定次数后,每隔一定数目的epoch
,将
M
b
e
s
t
\mathcal{M}_{best}
Mbest预测的概率值小于
k
k
k的样本丢弃,而在其余的数据上进行训练。这里又引入了一个新的超参数
k
k
k。同样的,原论文只给出了一些数据集上的经验值,并未给明确的设置方法。
三、Results
以下为在CIFAR-10
数据集上,传统的交叉熵损失与广义交叉熵损失的结果对比。所用参数与论文中一致:training
/validation
/testing
数据集大小分别为45000
/5000
/10000
;模型选用ResNet34
,注意其中第一个卷积层的参数为kernel_size
=3, stride
=1,而用于ImageNet
数据集时kernel_size
=7, stride
=2,这是因为CIFAR-10
的图片尺寸比较小,另外去除了第一个pooling
层;
q
=
0.7
,
k
=
0.5
q=0.7, k=0.5
q=0.7,k=0.5;总共训练120个epochs
,初始学习速率为1E-2
,第40
、第80
个epoch
学习速率递降10
倍;从第40
个epoch
开始,每10
个epoch
筛选一遍数据。每组实验运行了5次,结果表示为分类准确率的
μ
±
σ
\mu\pm\sigma
μ±σ的形式。复现结果与论文所示大致相当,甚至还要略优于论文中的结果。其中可以看到,存在pair wise
类型的标签噪声时,准确率的提升不是特别明显。

参考
- Generalized Cross Entropy:
GCE
的论文。 - pytorch-Truncated-Loss.git: 基于
pytorch
实现的GCE
。