原文链接:https://arxiv.org/abs/1905.02249
本文提出的MixMatch方法结合了之前半监督学习中一系列的有效方法,在仅有少量标注的情况下,在很多数据集上都达到了可以媲美有监督学习的结果。
摘要
半监督学习已被证明是利用未标记数据减轻对大型标记数据集依赖的一个有效地方法。在这项工作中,我们统一了目前半监督学习的主要方法,并产生了一个新的算法—MixMatch。该方法主要通过猜测数据增强的无标数据的低熵标签,并使用MixUp混合有标和无标样例。我们展示了MixMatch在许多数据集和标记的数据量上获得了state of the art的结果。例如,在包含250个标签的CIFAR-10上,我们将错误率降低了4倍(从38%降低到11%),在STL-10上降低了2倍。我们还展示了MixMatch如何帮助实现对差异私有性中精确性和私有性权衡。最后,我们进行消融研究,得出哪些成分是MixMatch取得成功的关键。
介绍
1.现有大多数深度网络的成功依赖于大量的有标数据。对于很多任务收集标注数据很困难,而得到无标数据相对容易。
2.半监督学习(SSL)试图利用无标注的数据来减轻对有标数据的需求。很多SSL方法针对无标注数据增加损失项来使得模型很好的泛化到未见数据上。损失项可分为三类:熵最小化,一致性正则化和一般正则化。
本文介绍了MixMatch利用一个loss将这些方法应用到半监督学习中,有以下贡献:
1.在所有数据集上取得了state-of-the-art。
2.消融实验表明MixMatch效果好于每部分之和。
3.MixMatch对于隐私学习很有效。取得state-of-the-art的同时也保证了隐私性。
相关工作
1.一致性正则化
半监督学习中的一致性正则化利用了这样一个假设,分类器对于数据增强后的的数据的分类分布应该与之前的类别分布一样。损失可以写成下式:
∣
∣
P
m
o
d
e
l
(
y
∣
A
u
g
m
e
n
t
(
x
)
;
θ
)
−
P
m
o
d
e
l
(
y
∣
A
u
g
m
e
n
t
(
x
)
;
θ
)
∣
∣
2
2
||P_{model}(y|Augment(x);\theta)-P_{model}(y|Augment(x);\theta)||_2^2
∣∣Pmodel(y∣Augment(x);θ)−Pmodel(y∣Augment(x);θ)∣∣22
数据增强是随机的,所以希望不同数据增强后的同一张图片能尽可能分到同一类。但这类方法的一大缺点是它们只使用了域特定的数据增强方法。MixMatch通过使用图像的标准数据增强来利用一致性正则化(随机水平翻转和裁剪)
2.熵最小化
半监督学习中,一个基本的潜在假设是分类器的决策边界不应该穿过数据边缘分布的高密度区域。为达到这一目的,需要分类器对于无标注样例输出低熵的预测值。通过加入一个损失项来显式的最小化无标输入的熵: P m o d e l ( y ∣ x ; θ ) P_{model}(y|x;\theta) Pmodel(y∣x;θ)。伪标签的方法通过对无标样例高置信度的预测值打上硬标签,并进一步带入一个交叉熵损失训练来实现熵最小化。MixMatch也通过对未标记数据的目标分布使用“sharpening”函数来隐式地实现熵的最小化。
3.传统正则化
正则化是指对模型施加约束,使其更难记忆训练数据,从而有望更好地推广到未见数据的一般方法。一种常用方法是增加损失项来惩罚模型参数的L2范数。本文的优化方法为Adam算法,所以使用权值衰减来替代L2损失项。
最近,MixUp方法同时对输入和标签的凸组合训练一个模型,其要求模型对两个输入的凸组合的输出接近每个单独输入的输出的凸组合。我们也在MixMatch中使用了MixUp方法(用于有标数据)。
MixMatch
MixMatch集成了上述方法,给定有标数据集
X
\mathcal{X}
X和同等大小的无标数据集
U
\mathcal{U}
U,对有标数据和无标数据进行数据增强分别得到
X
’
\mathcal{X^’}
X’和
U
’
\mathcal{U^’}
U’。它们被分别用来计算有标和无标的损失项,最终Loss如下:
X
’
,
U
’
=
M
i
x
M
a
t
c
h
(
X
,
U
,
T
,
K
,
α
)
\mathcal{X^’},\mathcal{U^’}=MixMatch(\mathcal{X},\mathcal{U},T,K,\alpha)
X’,U’=MixMatch(X,U,T,K,α)
L
X
=
1
∣
X
’
∣
∑
x
,
p
∈
X
’
H
(
p
,
P
m
o
d
e
l
(
y
∣
x
;
θ
)
)
L_{\mathcal{X}}=\frac{1}{|\mathcal{X^’}|}\sum_{x,p \in \mathcal{X^’}}H(p,P_{model}(y|x;\theta))
LX=∣X’∣1x,p∈X’∑H(p,Pmodel(y∣x;θ))
L
U
=
1
L
∣
U
’
∣
∑
x
,
p
∈
U
’
∣
∣
q
−
P
m
o
d
e
l
(
y
∣
u
;
θ
)
∣
∣
2
2
L_{\mathcal{U}}=\frac{1}{L|\mathcal{U^’}|}\sum_{x,p \in \mathcal{U^’}}||q-P_{model}(y|u;\theta)||_2^2
LU=L∣U’∣1x,p∈U’∑∣∣q−Pmodel(y∣u;θ)∣∣22
L
=
L
X
+
λ
U
L
U
L=L_{\mathcal{X}}+\lambda_{\mathcal{U}}L_{\mathcal{U}}
L=LX+λULU
其中
H
(
p
,
q
)
H(p,q)
H(p,q)表示分布p和q之间的交叉熵损失,
T
T
T,
K
K
K,
α
\alpha
α,
λ
U
\lambda_{\mathcal{U}}
λU是超参数,整个算法流程如下表所示:
1.数据增强
如上文所说,数据增强是减轻缺少有标数据影响的一种方法。类似于大部分半监督学习方法,我们同时对有标和无标数据进行数据增强。对有标数据进行一次数据增强,无标数据进行K次数据增强。这些无标数据增强后得到的结果进行‘laebl guessing’获得 q b q_b qb
2.label guessing
对于单个无标样例,我们计算K次增强后类别预测分布的均值,这个得到的标签带入后续的无监督损失项中。
q
b
‾
=
1
K
∑
k
=
1
K
p
m
o
d
e
l
(
y
∣
u
b
,
k
^
;
θ
)
\overline {q_b}=\frac{1}{K}\sum_{k=1}^Kp_{model}(y|\hat{u_{b,k}};\theta)
qb=K1k=1∑Kpmodel(y∣ub,k^;θ)
这个方法在一致性正则化方法中很常见。
3.sharpening
得到了上述label guessing的结果后,使用sharpening方法进行熵最小化处理,如下式:
S
h
a
r
p
e
n
(
p
,
T
)
i
:
=
p
i
1
T
/
∑
j
=
1
L
p
j
1
T
Sharpen(p,T)_i := p_i^{\frac{1}{T}}/\sum_{j=1}^{L}p_j^{\frac{1}{T}}
Sharpen(p,T)i:=piT1/j=1∑LpjT1
其中p是类别分布(上述的增强后类别分布的均值),T是超参数。T越趋于0,sharpen的输出就趋向于one-hot。因为后续我们需要使用sharpen的输出作为模型预测的目标值,所以选择较低的T保证了模型可以产生低熵的预测。
4.MixUp
我们同时对有标数据和有label guessing结果的无标数据进行MixUp。我们一开始分别对有标数据和无标数据设置不同的loss,但是这会带来问题。对于一对样例, ( x 1 , p 1 ) (x_1,p_1) (x1,p1), ( x 2 , p 2 ) (x_2,p_2) (x2,p2)我们稍微修改了MixUp方法。通过下式计算得到 ( x ’ , p ’ ) (x^’,p_’) (x’,p’)
λ
∼
B
e
t
a
(
α
,
α
)
\lambda \sim {Beta(\alpha,\alpha)}
λ∼Beta(α,α)
λ
’
=
m
a
x
(
λ
,
1
−
λ
)
\lambda^’=max(\lambda,1-\lambda)
λ’=max(λ,1−λ)
x
’
=
λ
’
x
1
+
(
1
−
λ
’
)
x
2
x^’=\lambda^’x_1+(1-\lambda^’)x_2
x’=λ’x1+(1−λ’)x2
p
’
=
λ
’
p
1
+
(
1
−
λ
’
)
p
2
p^’=\lambda^’p_1+(1-\lambda^’)p_2
p’=λ’p1+(1−λ’)p2
传统的MixUp可以被看做省略了第二项,即
λ
=
λ
’
\lambda=\lambda^’
λ=λ’。收集所有的有标和无标和label guessing结果使用MixUp。
我们将两部分串联起来并shuffle形成MixUp所需的数据源,对第i个有标样例,计算 M i x U p ( X ^ i , W i ) MixUp(\hat \mathcal{X}_i,W_i) MixUp(X^i,Wi)并加入 X ′ \mathcal{X'} X′集合中。由于我们的修改,MixUp的结果应该更接近原始有标数据而不是插值的结果。用剩余的W来计算 U ′ \mathcal{U}' U′
据此,MixMatch将 X \mathcal{X} X转变为了 X ′ \mathcal{X}' X′,一个包含数据增强后的有标数据和与无标数据MixUp结果的集合。相应的, U \mathcal{U} U转变为了 U ′ \mathcal{U}' U′,一个对于每个无标样例进行多重数据增强并包含其label guessing的集合。
5.损失函数
获得了 X ′ \mathcal{X}' X′和 U ′ \mathcal{U}' U′之后,利用本节一开始的损失函数,对于有标数据,使用传统交叉熵损失,并加上对于 U ′ \mathcal{U}' U′中无标数据的标签预测值的平方L2损失。相较于交叉熵,平方L2损失对错分样例有着更低的敏感性。我们不通过猜测的标签传播梯度。
实验部分感兴趣的读者可以参考原文,这里不再赘述。