对标签噪声鲁棒的广义交叉熵损失 (Generalized Cross Entropy)

在人工智能算法的实际应用场景中,不可避免地会出现训练数据误标现象,即训练数据集上存在标签噪声。这会降低所训模型的泛化能力。尤其是对于深度神经网络这种描述能力极强的模型,标签噪声对推理精度的影响甚或是灾难性的。论文证实了一个简单的两层网络就能记住所有随机分配的标签。

本篇博文将介绍一种对标签噪声鲁棒的损失函数,即General Cross Entropy (GCE)。这种损失函数在2018年的NIPS会议论文中被提出,其集成了Mean Absolute Error (MAE)损失函数的噪声鲁棒性,以及传统的Cross Entropy损失函数的训练高效性。

博文主要内容如下:第一部分将论述MAE为何是噪声鲁棒的。第二部分将介绍GCE的具体表达;其被称为广义交叉熵的原因也将在这一部分揭示。第三部分给出了复现结果。最后将给出相关的参考链接。

一、MAE为何是噪声鲁棒的

一句话解释,因为MAE是一种Symmetric Loss。那么原来的一个问题就变成了两个问题,(1) 什么是Symmetric Loss;(2) 为什么Symmetric Loss是噪声鲁棒的。

为了讨论第一个问题,我们考虑一个十分简单的情形:只包含一个训练样本的二分类问题。假设训练样本为 { x , y } \{x, y\} {x,y},其中 y ∈ { 0 , 1 } y\in\{0, 1\} y{0,1} f θ ( x ) f_\theta(x) fθ(x)为模型输出, θ \theta θ为待优化的参数,损失函数为 l l l。没有标签噪声的情况下,待优化的目标为 l [ f θ ( x ) , y ] l[f_\theta(x),y] l[fθ(x),y]。考虑存在标签噪声的情形,那么样本 x x x有一定概率 ρ \rho ρ被误标为 1 − y 1-y 1y,那么实际上的优化目标是 ( 1 − ρ ) ⋅ l [ f θ ( x ) , y ] + ρ ⋅ l [ f θ ( x ) , 1 − y ] (1-\rho)\cdot l[f_\theta(x),y]+\rho \cdot l[f_\theta(x),1-y] (1ρ)l[fθ(x),y]+ρl[fθ(x),1y]
如果

arg min ⁡ θ l [ f θ ( x ) , y ] = arg min ⁡ θ { ( 1 − ρ ) ⋅ l [ f θ ( x ) , y ] + ρ ⋅ l [ f θ ( x ) , 1 − y ] } ( 1 ) \argmin_\theta l[f_\theta(x),y]=\argmin_\theta \{(1-\rho)\cdot l[f_\theta(x),y]+\rho \cdot l[f_\theta(x),1-y]\}\qquad\qquad (1) θargminl[fθ(x),y]=θargmin{(1ρ)l[fθ(x),y]+ρl[fθ(x),1y]}(1)

那么意味着无论有无噪声,该优化问题都会得到同样的解。这时候损失函数 l l l就是噪声鲁棒的。

( 1 − ρ ) ⋅ l [ f θ ( x ) , y ] + ρ ⋅ l [ f θ ( x ) , 1 − y ] = ( 1 − 2 ρ ) ⋅ l [ f θ ( x ) , y ] + ρ ⋅ { l [ f θ ( x ) , y ] + l [ f θ ( x ) , 1 − y ] } ( 2 ) (1-\rho)\cdot l[f_\theta(x),y]+\rho \cdot l[f_\theta(x),1-y]=(1-2\rho)\cdot l[f_\theta(x),y]+\rho\cdot\{l[f_\theta(x),y]+l[f_\theta(x),1-y]\}\qquad\qquad(2) (1ρ)l[fθ(x),y]+ρl[fθ(x),1y]=(12ρ)l[fθ(x),y]+ρ{l[fθ(x),y]+l[fθ(x),1y]}(2)

整理后的第一项是无噪声情况下的优化目标的一个固定倍数。而第二项是当样本标签等概率取遍所有可能值时,所产生的损失值。

敲黑板,如果

l [ f θ ( x ) , y ] + l [ f θ ( x ) , 1 − y ] = C ( 3 ) l[f_\theta(x),y]+l[f_\theta(x),1-y]=C\qquad\qquad (3) l[fθ(x),y]+l[fθ(x),1y]=C(3)

其中 C C C是常数时,那么 l l l就是Symmetric Loss。这时候,(2)式右端的第二项对 θ \theta θ的优化不造成影响,所以(1)式就会成立,所以Symmetric Loss是对噪声鲁棒的。这里的Symmetric是一种轮换对称的含义,就是指当样本等概率取遍所有可能标签时,产生的损失值是常值。

MAE,考虑到 f θ ( x ) ∈ [ 0 , 1 ] f_\theta(x)\in[0,1] fθ(x)[0,1],那么
∣ y − f θ ( x ) ∣ + ∣ 1 − y − f θ ( x ) ∣ = ∣ 0 − f θ ( x ) ∣ + ∣ 1 − f θ ( x ) ∣ = 1 |y-f_\theta(x)|+|1-y-f_\theta(x)|=|0-f_\theta(x)|+|1-f_\theta(x)|=1 yfθ(x)+1yfθ(x)=0fθ(x)+1fθ(x)=1
所以MAESymmetric Loss,所以是噪声鲁棒的。

更严谨的论述,可以参考论文Making Risk Minimization Tolerant to Label Noise

二、GCE

在明了MAE为什么对噪声鲁棒的机理后,我们再来看看Cross Entropy Loss为什么是不鲁棒的。如果还是用一句话来解释,就是因为CE是无界的。考虑(3)式所表达的鲁棒性条件,各个损失项非负,且其和为定值,那么各个损失项必然是有界的。但在CE中,假设 f θ ( x ) f_\theta(x) fθ(x)表示样本 x x x属于类别1的概率,那么损失值是 − log ⁡ f θ ( x ) -\log f_\theta(x) logfθ(x) − log ⁡ ( 1 − f θ ( x ) ) -\log(1-f_\theta(x)) log(1fθ(x))。当 f θ ( x ) f_\theta(x) fθ(x)接近于0或者1时,损失值会非常大。这意味着模型会花费更多的功夫在这些样本上。

这一特性是个双刃剑。如果我们所有的训练样本的标注都是正确的,那么这一特性使得模型关注于那些容易被错判的样本,所以训练过程会十分高效。但如果训练样本存在标签噪声,这一特性将使得模型过度关注于误标的样本,以致最终会得到一个过拟合到误标样本的模型。

好,那我们直接使用MAE不就行了么。Hmmm…实际上MAE实战效果并不好。因为MAE是平等地对待各个样本的,所以其收敛速度比较慢 (具体的解释可以参考GCE论文)。

那么一个自然的想法就是能不能在MAE的噪声鲁棒性和CE的快速收敛性之间做一个折中。这就是GCE的基础想法。具体做法是将CE中的 − log ⁡ ( f θ ( x ) ) -\log(f_\theta(x)) log(fθ(x))项,替换成一个指数项 1 − f θ q ( x ) q \frac{1-f_\theta^q(x)}{q} q1fθq(x),其中幂次 q q q为一个超参数,取值范围为 ( 0 , 1 ] (0,1] (0,1]。当 q = 1 q=1 q=1时,该指数项就蜕变为MAE的形式;当 q → 0 q\rightarrow 0 q0时,由洛必达法则,该指数项将蜕变为CE的形式。所以 q q q控制着在MAECE之间的折中程度。原论文中仅给出了 q q q的经验值,并未给出具体的选取方法。

为进一步抑制误标数据的影响,原论文还给出了一种加强版的损失函数,该函数融合了样本选择的思想。其核心想法是,当训练到一定程度后,模型已经获取了关于数据的正确分类的模式信息,此时,如果模型在样本上输出的概率值比较小,那么这些样本很可能是误标样本。因此,若将这些样本从训练数据中剔除,那么接下来的训练可能会更准确。具体做法是:假设在训练过程中validation数据集上准确率最好的模型为 M b e s t \mathcal{M}_{best} Mbest,则在模型训练一定次数后,每隔一定数目的epoch,将 M b e s t \mathcal{M}_{best} Mbest预测的概率值小于 k k k的样本丢弃,而在其余的数据上进行训练。这里又引入了一个新的超参数 k k k。同样的,原论文只给出了一些数据集上的经验值,并未给明确的设置方法。

三、Results

以下为在CIFAR-10数据集上,传统的交叉熵损失与广义交叉熵损失的结果对比。所用参数与论文中一致:training/validation/testing数据集大小分别为45000/5000/10000;模型选用ResNet34,注意其中第一个卷积层的参数为kernel_size=3, stride=1,而用于ImageNet数据集时kernel_size=7, stride=2,这是因为CIFAR-10的图片尺寸比较小,另外去除了第一个pooling层; q = 0.7 , k = 0.5 q=0.7, k=0.5 q=0.7,k=0.5;总共训练120个epochs,初始学习速率为1E-2,第40、第80epoch学习速率递降10倍;从第40epoch开始,每10epoch筛选一遍数据。每组实验运行了5次,结果表示为分类准确率的 μ ± σ \mu\pm\sigma μ±σ的形式。复现结果与论文所示大致相当,甚至还要略优于论文中的结果。其中可以看到,存在pair wise类型的标签噪声时,准确率的提升不是特别明显。

表一. 不同方法的对比

参考

  • 25
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值