Rethinking Data Augmentation: Self-Supervision and Self-Distillation
Abstract
对进行了数据增强(翻转,裁剪等操作)的增强数据任然使用原始标签时,如果增强数据的分布与原始数据有较大的差距,就会降低网络的准确率。为了解决这个问题,作者提出了简单有效的方法:学习新样本的原始标签和自监督标签的联合分布。为了提高训练速度,又引入了知识传播技术——自蒸馏。
Self-supervised Data Augmentation
x ∈ R d x\in R^{d} x∈Rd: input
y ∈ { 1 , 2 , . . . , N } y\in \{1,2,...,N\} y∈{1,2,...,N}: label
L C E L_{CE} LCE: Cross Entropy loss
σ
(
⋅
;
u
)
\sigma(\cdot;u)
σ(⋅;u): softmax classifier
σ
i
(
z
;
u
)
=
e
u
i
T
z
∑
k
(
e
u
k
T
z
)
\sigma_i(z;u)=\frac{e^{u_i^Tz}}{\sum_k(e^{u_k^Tz})}
σi(z;u)=∑k(eukTz)euiTz
z
=
f
(
x
;
θ
)
z=f(x;\theta)
z=f(x;θ): embedding vector of
x
x
x, and f is a neural network with parameters
θ
\theta
θ;
1. Data Augmentation and Self-Supervision
Data Augmentation
在有监督的情况下,传统的数据增强的目的是提高目标神经网络 f f f的泛化能力.
写出其目标损失函数:
L
D
A
(
x
,
y
;
θ
,
u
)
=
E
t
∼
T
[
L
C
E
(
σ
(
f
(
x
^
;
θ
)
;
u
)
,
y
)
]
(1)
L_{DA}(x,y;\theta,u)=E_{t\sim T}[L_{CE}(\sigma(f(\hat{x};\theta);u),y)]\tag{1}
LDA(x,y;θ,u)=Et∼T[LCE(σ(f(x^;θ);u),y)](1)
T
T
T 是数据增强后的数据分布;
Self-Supervision
最近的自我监督学习文献已经表明,可以通过预测从输入信号中获得的标签来学习高级语义表示,且无需任何人工注释。
在自监督模型中,用 x ^ = t ( x ) \hat{x}=t(x) x^=t(x)表示对 x x x做了 t t t类型的数据增强。
利用自监督标签的常用方法是优化原任务和自监督任务的两个损失,同时共享它们之间的特征空间,也就是一个multi-task learning work
L
M
T
(
x
,
y
;
θ
,
u
,
v
)
=
1
M
∑
j
=
1
M
L
C
E
(
σ
(
f
(
x
^
j
;
θ
)
;
u
)
,
y
)
+
L
C
E
(
σ
(
f
(
x
^
j
;
θ
)
;
v
)
,
j
)
(2)
L_{MT}(x,y;\theta,u,v) = \frac{1}{M}\sum_{j=1}^{M}L_{CE}(\sigma(f(\hat{x}_j;\theta);u),y)+L_{CE}(\sigma(f(\hat{x}_j;\theta);v),j)\tag{2}
LMT(x,y;θ,u,v)=M1j=1∑MLCE(σ(f(x^j;θ);u),y)+LCE(σ(f(x^j;θ);v),j)(2)
{
t
j
}
j
=
1
M
\{t_j\}_{j=1}^{M}
{tj}j=1M 是一系列的预定义的数据增强方式,
M
M
M 是自监督标签的数量,
σ
(
⋅
;
v
)
\sigma(\cdot;v)
σ(⋅;v) 是自监督分类器,且
x
^
=
t
j
(
x
)
\hat{x}=t_j(x)
x^=tj(x) .
2. Eliminating invariance via joint-label classifier
作者的目的是移除式子(1)和(2)分类器不必要的标签不变性
为了达到目的,作者使用了一个joint softmax classifier ρ ( ⋅ ; w ) \rho(\cdot;w) ρ(⋅;w) 来表现joint probability P ( i , j ∣ x ^ ) = ρ i j ( z ^ ; w ) = e w i j T z ^ ∑ k , l ( e w k l T z ^ ) P(i,j|\hat{x})=\rho_{ij}(\hat{z};w)=\frac{e^{w_{ij}^T\hat{z}}}{\sum_{k,l}(e^{w_{kl}^T\hat{z}})} P(i,j∣x^)=ρij(z^;w)=∑k,l(ewklTz^)ewijTz^
因此目标函数可以写为:
L
S
D
A
(
x
,
y
;
θ
,
w
)
=
1
M
∑
j
=
1
M
L
C
E
(
ρ
(
f
(
x
^
j
;
θ
)
;
w
)
,
(
y
,
j
)
)
(3)
L_{SDA}(x,y;\theta,w)=\frac{1}{M}\sum_{j=1}^{M}L_{CE}(\rho(f(\hat{x}_j;\theta);w),(y,j))\tag{3}
LSDA(x,y;θ,w)=M1j=1∑MLCE(ρ(f(x^j;θ);w),(y,j))(3)
其中
L
C
E
(
ρ
(
z
^
;
w
)
,
(
i
,
j
)
)
=
−
log
ρ
i
j
(
z
^
;
w
)
L_{CE}(\rho(\hat{z};w),(i,j))=-\log\rho_{ij(\hat{z};w)}
LCE(ρ(z^;w),(i,j))=−logρij(z^;w) .
当所有
w
i
j
=
u
i
w_{ij} = u_i
wij=ui 时(3)就退化成了(1),当
w
i
j
=
u
i
+
v
j
w_{ij}=u_i+v_j
wij=ui+vj 时就变成了(2);
Aggregated inference
因为使用什么数据增强 方法是已知的,所以预测时不必考虑所有的
N
×
M
N\times M
N×M 个标签,只需要使用一个条件概率即可:
P
(
i
∣
x
^
j
,
j
)
=
e
w
i
j
T
z
^
j
∑
k
(
e
w
k
j
T
z
^
j
)
P(i|\hat{x}_j,j)=\frac{e^{w_{ij}^T\hat{z}_j}}{\sum_k(e^{w_{kj}^T\hat{z}_j})}
P(i∣x^j,j)=∑k(ewkjTz^j)ewijTz^j
where
z
^
j
=
f
(
x
^
j
;
θ
)
\hat{z}_j=f(\hat{x}_j;\theta)
z^j=f(x^j;θ)
针对所有的数据增强方式,作者将其条件概率聚合来增强分类器的准确率
P
a
g
g
r
e
g
e
t
e
d
(
i
∣
x
)
=
e
s
i
∑
k
=
1
N
e
s
k
P_{aggregeted}(i|x)=\frac{e^{s_i}}{\sum_{k=1}^Ne^{s_k}}
Paggregeted(i∣x)=∑k=1Neskesi
where
s
i
=
1
M
∑
j
=
1
M
w
i
j
T
z
^
j
s_i = \frac{1}{M}\sum_{j=1}^Mw_{ij}^T\hat{z}_j
si=M1j=1∑MwijTz^j
Self-distillation from aggregation
尽管这种聚合训练的方式效果很好,但其需要单网络 M M M倍的训练时间,为了解决这个问题,引入自蒸馏来将聚合网络蒸馏为一个单网络结构的网络
因此最后的目标函数为:
L
S
D
A
+
S
D
(
x
,
y
;
θ
,
w
,
u
)
=
L
S
D
A
(
x
,
y
;
θ
,
w
)
+
D
K
L
(
P
a
g
g
r
e
g
a
t
e
d
(
⋅
∣
x
)
∣
∣
σ
(
f
(
x
;
θ
)
;
u
)
)
+
β
L
C
E
(
σ
(
f
(
x
;
θ
)
;
u
)
,
y
)
(4)
L_{SDA+SD}(x,y;\theta,w,u)=L_{SDA}(x,y;\theta,w)+D_{KL}(P_{aggregated}(\cdot|x)||\sigma(f(x;\theta);u))+\beta L_{CE}(\sigma(f(x;\theta);u),y)\tag{4}
LSDA+SD(x,y;θ,w,u)=LSDA(x,y;θ,w)+DKL(Paggregated(⋅∣x)∣∣σ(f(x;θ);u))+βLCE(σ(f(x;θ);u),y)(4)
式中的
K
L
KL
KL散度就是用来自蒸馏的;