目录
研究背景
论文链接:Meta-Weight-Net: Learning an Explicit Mapping For Sample Weighting.
数据不平衡问题在现实世界中非常普遍。对于真实数据,不同类别的数据量一般不会是理想的uniform分布,而往往会是不平衡的;如果按照不同类别数据出现的频率从高到低排序,就会发现数据分布出现一个“长尾巴”,也即我们所称的长尾效应。大型数据集经常表现出这样的长尾标签分布:
一、为什么存在类别不平衡现象?
为什么会存在不平衡的现象?其实很好理解,一个通用的解释就是特定类别的数据是很难收集的。拿Species分类来说,特定种类(如猫,狗等)非常常见,但是有的种类(如高山兀鹫,随便举的例子…)就非常稀有。再比如对自动驾驶,正常行驶的数据会占大多数,而真正发生异常情况/存在车祸危险的数据却极少。再比如对医疗诊断,患有特定疾病的人群数相比正常人群也是极度不平衡的。对于healthcare data来说另一个可能原因是和privacy issue有关,特定病人可能都很难采集数据。
那么,不平衡或长尾数据会有什么问题?简单来说,如果直接把类别不平衡的样本丢给模型用ERM学习,显然模型会在major classes的样本上的学习效果更好,而在minor classes上泛化效果差,因为其看到的major classes的样本远远多于minor classes。
那么,对于不平衡学习问题有哪些解决方法?这里总结了几种方法。
- 重采样(re-sampling):更具体可分为对少样本的过采样1,或是对多样本的欠采样2。但因过采样容易overfit到minor class,无法学到更鲁棒易泛化的特征,往往在非常不平衡数据上表现会更差;而欠采样则会造成major class严重的信息损失,导致欠拟合发生。
- 数据合成(synthetic samples):即生成和少样本相似的“新”数据。经典方法SMOTE3,思路简单来讲是对任意选取的少类样本,用K近邻选取其相似样本,通过对样本线性插值得到新样本。这里会想到和mixup4很相似,于是也有imbalance的mixup版本出现5。
- 数据增强(Easy Data Augmentation):提出并验证了几种加噪的 text augmentation 技巧,分别是同义词替换(SR: Synonyms Replace)、随机插入(RI: Randomly Insert)、随机交换(RS: Randomly Swap)、随机删除(RD: Randomly Delete)用来增加长尾类别数据量6。
- 主动学习(Active Learning):通过模型获取到那些比较“难”分类的样本数据,让人工再次确认和审核,然后将人工标注得到的数据再次使用有监督学习模型或者半监督学习模型进行训练,逐步提升模型的效果,将人工经验融入模型中7。
- 重加权(re-weighting):对不同类别(甚至不同样本)分配不同权重。注意这里的权重可以是自适应的。此类方法的变种有很多,有最简单的按照类别数目的倒数来做加权8,按照“有效”样本数加权9,根据样本数优化分类间距的loss加权10,等等。
- 迁移学习(transfer learning):这类方法的基本思路是对多类样本和少类样本分别建模,将学到的多类样本的信息/表示/知识迁移给少类别使用。代表性文章有1112。
- 度量学习(metric learning):本质上是希望能够学到更好的embedding,对少类附近的boundary/margin更好的建模。有兴趣的同学可以看看1314。
- 元学习/域自适应(meta learning/domain adaptation):分别对头部和尾部的数据进行不同处理,可以去自适应的学习如何重加权15,或是formulate成域自适应问题16。
- 解耦特征和分类器(decoupling representation & classifier):最近的研究发现将特征学习和分类器学习解耦,把不平衡学习分为两个阶段,在特征学习阶段正常采样,在分类器学习阶段平衡采样,可以带来更好的长尾学习结果1718。这也是目前的最优长尾分类算法。
二、 Meta-Weight-Net[NIPS’2019]
本文主要目的是为了介绍用元学习的方法来同时优化噪声标签与类别不平衡的问题。
这篇论文的主要关注点在于解决如何对Loss进行重加权(re-weighting)的问题,在传统机器学习分类任务中,对于有偏置的数据,即对含有incorrect label数据跟长尾类别的数据进行训练时,模型可能会关注到损失较大的数据能否正确分类。
优化这类问题的最朴素的思想包括两个方面:对于误标数据,我们希望淡化其对训练模型时的负面作用;而对于长尾类别来说,我们需要模型预测的更准的话,就需要增强长尾类别对模型训练的作用;
这在直观上的表现就是对Loss进行重加权,对噪声标签,我们希望给一个小的权重,尽量淡化其对模型的副作用,对长尾类别我们希望给一个大的权重,使模型能够更好的预测这部分数据量较小的样本。
1.Focal Loss
由这种思想产生了两种主流的方法来设计weight-function:第一种是样本loss越大,表明其越可能是分类边界上的不确定的比较hard的样本,其weight应该越大,这主要适用于类别不均衡问题,因为它可以prioritize那些有更大loss的minority class;这里面比较具有代表性的就是Focal Loss19。
Focal Loss 主要思想是通过减少易分类样本的权重,使得模型在训练时更专注于难分类的样本。
F
L
(
p
t
)
=
−
α
t
(
1
−
p
t
)
γ
l
o
g
(
p
t
)
(
1
)
.
FL(p_t) = -\alpha_t(1-p_t)^{\gamma}log(p_t) \qquad \qquad\qquad(1) .
FL(pt)=−αt(1−pt)γlog(pt)(1).
对于上面的公式,其baseline为一个简单的交叉熵损失,通过设定
α
\alpha
α的值来控制正负样本对总的loss的共享权重。
α
\alpha
α取比较小的值来降低负样本(多的那类样本)的权重。这里的
γ
\gamma
γ称作focusing parameter,
γ
>
=
0
\gamma>=0
γ>=0。
(
1
−
p
t
)
γ
(1-p_t)^\gamma
(1−pt)γ称为调制系数(modulating factor)
也就是说在一个well-trained模型中,当一个样本被分错的时候,
p
t
p_t
pt是很小的,那么调制因子
1
−
p
t
1-p_t
1−pt接近1,损失不被影响;当
p
t
p_t
pt→1,因子(1-Pt)接近0,那么分的比较好的(well-classified)样本的权值就被调低了。因此调制系数就趋于1,也就是说相比原来的loss是没有什么大的改变的。当
p
t
p_t
pt趋于1的时候(此时分类正确而且是易分类样本),调制系数趋于0,也就是对于总的loss的贡献很小。
当
γ
\gamma
γ=0的时候,Focal loss就是传统的二分类交叉熵损失,即
C
E
(
p
,
y
)
=
−
y
l
o
g
(
p
)
−
(
1
−
y
)
l
o
g
(
1
−
p
)
,
其
中
y
=
1
的
概
率
为
p
,
y
=
0
的
概
率
为
1
−
p
CE(p,y)=-ylog(p)-(1-y)log(1-p),其中y=1的概率为p,y=0的概率为1-p
CE(p,y)=−ylog(p)−(1−y)log(1−p),其中y=1的概率为p,y=0的概率为1−p,令
p
t
=
{
p
,
y
=
1
;
1
−
p
,
o
t
h
e
r
w
i
s
e
;
p_t = \left\{ \begin{matrix} p, y=1; \\ 1-p, otherwise; \end{matrix}\right.
pt={p,y=1;1−p,otherwise;
为了表示简便,我们用
p
t
p_t
pt表示样本属于ground truth的概率,则
C
E
(
p
,
y
)
=
−
l
o
g
(
p
t
)
CE(p,y)=-log(p_t)
CE(p,y)=−log(pt)
当
γ
\gamma
γ增加的时候,调制系数也会增加。 专注参数
γ
\gamma
γ平滑地调节了易分样本调低权值的比例。
γ
\gamma
γ增大能增强调制因子的影响,paper实验发现
γ
\gamma
γ取2最好。直觉上来说,调制因子减少了易分样本的损失贡献,拓宽了样例接收到低损失的范围。当γ一定的时候,比如等于2,一样easy example(
p
t
p_t
pt=0.9)的loss要比标准的交叉熵loss小100+倍,当
p
t
=
0.968
p_t=0.968
pt=0.968时,要小1000+倍,但是对于hard example(
p
t
<
0.5
p_t< 0.5
pt<0.5),loss最多小了4倍。这样的话hard example的权重相对就提升了很多。这样就增加了那些误分类的重要性。
2.self-pacd learning
这里介绍第二种方法对噪声标签影响进行优化,样本loss越小,表明其本身的标注结果越clean,其weight应该越大,这主要适用于有很多noisy labels的情况,因为它可以削弱那些loss很大的样本(很可能是incorrect labels)的影响。这里主要的方法为自步学习(self-pacd learning),Koller教授在2010年NIPS上给出了自步学习的数学表达。在自步学习中,引入表示样本是简单还是困难学习的变量 v i v_i vi。网络的权值,以及样本的难易程度就要通过下面的式子来更新:
( W t + 1 , V t + 1 ) = a r g m i n W ∈ R d , V ∈ { 0 , 1 } n ( r ( W ) + ∑ i = 1 n v i f ( x i , y i ; W ) − 1 k ∑ i = 1 n v i ) (W_{t+1},V_{t+1}) =argmin_{ W\in R^d,V\in \{ 0, 1\}^n }(r(W)+\sum_{i=1}^{n}v_if(x_i,y_i;W) - \frac{1}{k}\sum_{i=1}^{n}v_i) (Wt+1,Vt+1)=argminW∈Rd,V∈{0,1}n(r(W)+i=1∑nvif(xi,yi;W)−k1i=1∑nvi)
这里 v i v_i vi取值是 { 0 , 1 } \{0,1\} {0,1} ,当 v i = 0 v_i= 0 vi=0 时,表示简单样本;当 v i = 1 v_i= 1 vi=1 时,表示难样本。用 k k k 控制样本数量,当 k k k较小时, 1 k \frac{1}{k} k1 较大,允许非零的 v i v_i vi 数量较多,即允许模型中难的样本数量增加;反之,当 k k k 较大时, 1 k \frac{1}{k} k1 较小,允许非零的 v i v_i vi数量较少,即允许模型中难的样本数减少。
Focal loss在无噪声标签的数据集上面表现的比较好,但是在实际业务数据集中,并不是只存在这一种情况,多数情况下由于标签类别体系不明确,标签界限模糊,会导致标注人员误标漏标,一般大型数据集上的误标率在10%左右。因此我们需要考虑到噪声标签跟长尾类别同时对模型带来的影响。
3.Meta-Weight-Net
在这篇论文中,作者先对噪声标签与长尾类别分别做了对比实验,由图可以看出:
在focal loss中loss权重随着loss的不断增大而增大,而在SPL中loss权重随着loss的变大而减小,这就是一种矛盾的情况,对于难学习的imbalanced class我们需要给大的权重,而对于noisy label我们需要给一个极小的权重来淡化对模型的影响。
因此作者提出了一种元学习的方法来对损失进行重加权,让神经网络“学习”如何给不同的data loss进行重加权。
上图中,作者设计的元学习网络对于imbalanced factor 100 和 40% uniform noise的数据集进行拟合,得到了如图f所示的函数曲线。
首先看一下网络结构部分,用了一个一层的多层感知机(MLP)来做这个weight function(理论上,一层的MLP就可以近似任何函数了)。其中
V
(
.
;
θ
)
V(.;\theta)
V(.;θ) 代表上述的多层感知机网络,输入一个loss,输出一个weight,
θ
\theta
θ 是网络的参数;
L
i
t
r
a
i
n
(
w
)
{L_i}^{train}(w)
Litrain(w) 表示一个样本的loss,
L
t
r
a
i
n
(
w
;
θ
)
=
1
N
∑
i
=
1
N
V
(
L
i
t
r
a
i
n
(
w
)
;
θ
)
L
i
t
r
a
i
n
(
w
)
{L}^{train}(w;\theta)=\frac{1}{N}\sum_{i=1}^{N}V({L_i}^{train}(w);\theta){L_i}^{train}(w)
Ltrain(w;θ)=N1∑i=1NV(Litrain(w);θ)Litrain(w) 表示所有样本加权的损失,如果
θ
\theta
θ确定,那么
w
w
w 的最优值
w
∗
w^*
w∗ 确定,
w
∗
(
θ
)
=
a
r
g
m
i
n
w
L
t
r
a
i
n
(
w
;
θ
)
w^*(\theta)=argmin_w{L}^{train}(w;\theta)
w∗(θ)=argminwLtrain(w;θ) ,
w
∗
(
θ
)
w^*(\theta)
w∗(θ)是关于
θ
\theta
θ 的函数,因此当训练数据进入模型之后,先固定住
θ
\theta
θ 求
w
^
\hat{w}
w^,这里文章中的
w
∗
w^*
w∗ 并不好求,因此作者用MaML的思想
w
^
\hat{w}
w^ 来近似代替
w
∗
w^*
w∗ ,其中
w
^
\hat{w}
w^为t时刻一次梯度下降的结果。
w
^
m
e
t
a
−
m
o
d
e
l
(
t
)
(
θ
)
=
w
(
t
)
−
α
1
N
∑
i
=
1
N
V
(
L
i
t
r
a
i
n
(
w
)
;
θ
)
∇
w
L
i
t
r
a
i
n
(
w
)
∣
w
(
t
)
(
3
)
\hat{w}^{(t)}_{meta-model}(\theta) = w^{(t)} - \alpha\frac{1}{N}\sum_{i=1}^{N}V({L_i}^{train}(w);\theta)\nabla_w{L_i}^{train}(w)|_{w^{(t)}} \qquad\qquad(3)
w^meta−model(t)(θ)=w(t)−αN1i=1∑NV(Litrain(w);θ)∇wLitrain(w)∣w(t)(3)
这里只对模型参数进行更新而不对weight-function进行更新,目的是为了固定
θ
\theta
θ。
得到了一个batch的近似最优权重之后,作者通过固定住“最优”权重
w
∗
w^*
w∗反过来更新
θ
\theta
θ,即
θ
∗
=
a
r
g
m
i
n
θ
L
(
w
∗
(
θ
)
)
\theta^*=argmin_\theta{L}(w^*(\theta))
θ∗=argminθL(w∗(θ)) 这里除了训练集
{
x
i
,
y
i
}
i
=
1
N
\{x_i,y_i\}^N_{i=1}
{xi,yi}i=1N ,再定义一个meta-data
{
x
i
(
m
e
t
a
)
,
y
i
(
m
e
t
a
)
}
i
=
1
M
\{x^{(meta)}_i,y^{(meta)}_i\}^M_{i=1}
{xi(meta),yi(meta)}i=1M ,假定meta-data具有clean labels并且有balanced data distribution,能够represent the meta-knowledge of ground-truth sample-label distribution,且
M
M
M 是远小于
N
N
N 的,实验中是从valid set sample出一些来作为meta-data。由于假设中认为meta-data是unbiased,所以对meta-data求loss时不需要加权,
L
m
e
t
a
(
w
∗
(
θ
)
)
{L}^{meta}(w^*(\theta))
Lmeta(w∗(θ)) 表示每个meta-data样本的loss(因为如果
θ
\theta
θ 确定,一定是用最优的
w
∗
w^*
w∗ 作为分类器参数),
L
m
e
t
a
(
w
∗
(
θ
)
)
=
1
M
∑
i
=
1
M
L
i
m
e
t
a
(
w
∗
(
θ
)
)
{L}^{meta}(w^*(\theta))=\frac{1}{M}\sum_{i=1}^{M}{L_i}^{meta}(w^*(\theta))
Lmeta(w∗(θ))=M1∑i=1MLimeta(w∗(θ)) 表示所有meta-data的损失。这里
θ
∗
=
a
r
g
m
i
n
θ
L
m
e
t
a
(
w
∗
(
θ
)
)
\theta^*=argmin_\theta{L^{meta}}(w^*(\theta))
θ∗=argminθLmeta(w∗(θ))
目标是为了更新
θ
\theta
θ 求出 t+1 时刻的
θ
∗
\theta^*
θ∗,从而可以根据 t+1 时刻的最优
θ
∗
\theta^*
θ∗ 得到t+1时刻的
w
w
w更新。
θ
(
t
+
1
)
=
θ
(
t
)
−
β
1
M
∑
i
=
1
M
∇
θ
L
i
m
e
t
a
(
w
^
(
t
)
(
θ
)
)
∣
θ
(
t
)
(
9
)
\theta^{(t+1)} = \theta^{(t)} - \beta\frac{1}{M}\sum_{i=1}^{M}\nabla_\theta{L_i}^{meta}(\hat{w}^{(t)}(\theta))|_{\theta^{(t)}} \qquad\qquad(9)
θ(t+1)=θ(t)−βM1i=1∑M∇θLimeta(w^(t)(θ))∣θ(t)(9)
有了当前更新后的参数
θ
t
+
1
\theta^{t+1}
θt+1 ,就可以对函数
L
t
r
a
i
n
(
w
;
θ
)
{L}^{train}(w;\theta)
Ltrain(w;θ) 进行梯度下降,求出更新后的
w
t
+
1
w^{t+1}
wt+1 。
w
m
o
d
e
l
(
t
+
1
)
=
w
m
o
d
e
l
(
t
)
−
α
1
N
∑
i
=
1
N
V
(
L
i
t
r
a
i
n
(
w
(
t
)
)
;
θ
(
t
+
1
)
)
∇
w
L
i
t
r
a
i
n
(
w
)
∣
w
(
t
)
(
5
)
{w}^{(t+1)}_{model} = w^{(t)}_{model} - \alpha\frac{1}{N}\sum_{i=1}^{N}V({L_i}^{train}(w^{(t)});\theta^{(t+1)})\nabla_w{L_i}^{train}(w)|_{w^{(t)}} \qquad\qquad(5)
wmodel(t+1)=wmodel(t)−αN1i=1∑NV(Litrain(w(t));θ(t+1))∇wLitrain(w)∣w(t)(5)
为了方便理解,这里对meta-model跟model做一个解释:meta-model为元学习模型,model为训练模型。
1.每个batch开始前,元学习模型从训练模型中load权重参数,先对一个batch的train-data进行训练,只更新meta-model的权重而不更新model和weight function的权重。
2.再对一个batch的meta-data过meta-model模型根据固定好的
w
w
w 更新weight function的参数,
3.再将train-data数据过model与weight function得到Loss更新model的权重(注意:这里不需要更新weight function的参数,每次只更新一部分参数,其他的部分保持不变)
整个Meta-Weight-Net的伪代码描述:
这里我的个人理解是假设train-data跟meta-data的batch_size一样,由
M
<
<
N
M<<N
M<<N 可知一个epoch中,train-data训练
a
=
N
b
a
t
c
h
s
i
z
e
a = \frac{N}{batchsize}
a=batchsizeN轮,meta-data训练
b
=
M
b
a
t
c
h
s
i
z
e
b = \frac{M}{batchsize}
b=batchsizeM轮,这里b远小于a,因此当b轮完之后,meta-data将会重新投入到这个epoch中进行训练。
try:
inputs_val, targets_val = next(train_meta_loader_iter)
except StopIteration:
train_meta_loader_iter = iter(train_meta_loader)
inputs_val, targets_val = next(train_meta_loader_iter)
inputs_val, targets_val = inputs_val.to(device), targets_val.to(device)
y_g_hat = meta_model(inputs_val)
l_g_meta = F.cross_entropy(y_g_hat, targets_val)
prec_meta = accuracy(y_g_hat.data, targets_val.data, topk=(1,))[0]
这里一个epoch近似于MAML中的
T
=
a
b
T = \frac{a}{b}
T=ba 个task训练,从而narrow the generation gap between train-data&meta-data。
上图为整个meta-weight-net的示意图,其中step5、step6、step7对应伪代码中的第五、六、七行。
下图为公式9的反向传播过程的求导计算
这里遵循的是链式求导法则,Loss先对
w
^
\hat{w}
w^进行求导,再由
w
^
\hat{w}
w^对
w
(
t
)
求
导
最
后
对
w^{(t)}求导最后对
w(t)求导最后对meta-weight-net的权重进行求导得到导数,如果忽略
1
M
∑
i
=
1
M
G
i
j
\frac{1}{M}\sum_{i=1}^{M}G_{ij}
M1∑i=1MGij 这一项,那么可以看到上式(12)中每一项的方向都是
V
(
L
i
t
r
a
i
n
(
w
)
;
θ
)
V({L_i}^{train}(w);\theta)
V(Litrain(w);θ) 梯度上升的方向,说明权重
θ
\theta
θ 朝着Loss增大(即model_loss权重增大)的方向学习。同时
1
M
∑
i
=
1
M
G
i
j
\frac{1}{M}\sum_{i=1}^{M}G_{ij}
M1∑i=1MGij 这一项可以理解成第
j
j
j 个训练样本的梯度和mini-batch中所有meta-data样本梯度的相似度的均值,该值越大的话,说明这个学习到的train sample的梯度和meta samples的梯度越相似,ModelLoss也会变大,说明模型会朝着跟meta-data数据分布一致的方向进行学习。
最后看一下paper的实验结果:
可以看到:对比Focal Loss和SPL在CIFAR-10和CIFAR-100上的表现可以看出,除了在imbalance=1的情况下,Focal Loss好了0.02个点之外,其他情况Meta-weight-net效果均好于Focal loss。同样只有在noise = 0%的时候,Focal loss会优于Meta-weight-net,其他情况下都是Meta-weight-net优于SPL和Focal loss。
从混淆矩阵这边来看,跟Baseline相比,除了class 1下降了0.9%,其他类别都有明显的提升。Figure4展示了在噪声比例为0时,Meta-weight-net表现不如Baseline外,其他情况都是优于对比模型的,而且随着噪声比例不断上升,效果也下降得最慢。
最后看一下权重分布情况
可以看出,几乎所有的clean sample的权重都比较大,而noise sample的权重小于clean sample,这说明经过训练的Meta-weight-net能够区分clean和noise的图像。
总结
参考论文与资料:
Samira Pouyanfar, et al. Dynamic sampling in convolutional neural networks for imbalanced data classification. ↩︎
He, H. and Garcia, E. A. Learning from imbalanced data. TKDE, 2008. ↩︎
Chawla, N. V., et al. SMOTE: synthetic minority oversampling technique. JAIR, 2002. ↩︎
mixup: Beyond Empirical Risk Minimization. ICLR 2018. ↩︎
H. Chou et al. Remix: Rebalanced Mixup. 2020. ↩︎
Wei J, Zou K. Eda: Easy data augmentation techniques for boosting performance on text classification tasks[J]. arXiv preprint arXiv:1901.11196, 2019. ↩︎
Aggarwal, Charu C., et al. "Active learning: A survey."Data Classification: Algorithms and Applications. CRC Press, 2014. 571-605. ↩︎
Deep Imbalanced Learning for Face Recognition and Attribute Prediction. TPAMI, 2019. ↩︎
Yin Cui, Menglin Jia, Tsung-Yi Lin, Yang Song, and Serge Belongie. Class-balanced loss based on effective number of samples. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 9268–9277, 2019. ↩︎
Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss. NeurIPS, 2019. ↩︎
Large-scale long-tailed recognition in an open world. CVPR, 2019. ↩︎
Feature transfer learning for face recognition with under-represented data. CVPR, 2019. ↩︎
Range Loss for Deep Face Recognition with Long-Tail. CVPR, 2017. ↩︎
Learning Deep Representation for Imbalanced Classification. CVPR, 2016. ↩︎
Meta-Weight-Net: Learning an Explicit Mapping For Sample Weighting. NeurIPS, 2019. ↩︎
Rethinking Class-Balanced Methods for Long-Tailed Recognition from a Domain Adaptation Perspective. CVPR, 2020. ↩︎
BBN: Bilateral-Branch Network with Cumulative Learning for Long-Tailed Visual Recognition. CVPR, 2020. ↩︎
Decoupling representation and classifier for long-tailed recognition. ICLR, 2020. ↩︎
Lin T Y, Goyal P, Girshick R, et al. Focal loss for dense object detection[C]//Proceedings of the IEEE international conference on computer vision. 2017: 2980-2988.
[20]. Paper Reading: Meta-Weight-Net[NIPS’2019]. ↩︎