概
在训练的时候对权重加扰动能增强泛化性.
主要内容
如上图所示, 一般的训练方法虽然能够收敛到一个不错的局部最优点, 但是往往这个局部最优点附近是非常不光滑的, 即对权重 w w w添加微小的扰动 w + ϵ w+\epsilon w+ϵ 可能就会导致不好的结果, 作者认为这与模型的泛化性有很大关系(实际上已有别的文章提出这一观点).
作者给出如下的理论分析:
在满足一定条件下有
L
D
(
w
)
≤
max
∥
ϵ
∥
2
≤
ρ
L
S
(
w
+
ϵ
)
+
h
(
∥
w
∥
2
2
/
ρ
2
)
.
L_{\mathscr{D}} (w) \le \max_{\|\epsilon \|_2 \le \rho} L_{\mathcal{S}} (w + \epsilon) + h(\|w\|_2^2/\rho^2).
LD(w)≤∥ϵ∥2≤ρmaxLS(w+ϵ)+h(∥w∥22/ρ2).
其中
h
h
h是一个严格单调递增函数,
L
S
L_{\mathcal{S}}
LS是在训练集
S
\mathcal{S}
S上的损失,
L
D
(
w
)
=
E
(
x
,
y
)
∼
D
[
l
(
x
,
y
;
w
)
]
.
L_{\mathscr{D}}(w) = \mathbb{E}_{(x, y) \sim \mathscr{D}} [l(x, y;w)].
LD(w)=E(x,y)∼D[l(x,y;w)].
如果把
h
(
∥
w
∥
2
2
/
ρ
2
)
h(\|w\|_2^2/\rho^2)
h(∥w∥22/ρ2)看成
λ
∥
w
∥
2
2
\lambda \|w\|_2^2
λ∥w∥22(即常用的weight decay), 我们的目标函数可以认为是
min
w
L
S
S
A
M
(
w
)
+
λ
∥
w
∥
2
2
,
\min_w L_{\mathcal{S}}^{SAM} (w) + \lambda \|w\|_2^2,
wminLSSAM(w)+λ∥w∥22,
L
S
S
A
M
(
w
)
:
=
max
∥
ϵ
∥
p
≤
ρ
L
S
(
w
+
ϵ
)
,
L_{\mathcal{S}}^{SAM}(w) := \max_{\|\epsilon \|_p \le \rho} L_{\mathcal{S}} (w + \epsilon),
LSSAM(w):=∥ϵ∥p≤ρmaxLS(w+ϵ),
注: 这里
∥
⋅
∥
p
\|\cdot \|_p
∥⋅∥p而并不仅限于
∥
⋅
∥
2
\|\cdot \|_2
∥⋅∥2.
采用近似的方法求解上面的问题(就和对抗样本一样):
ϵ
∗
(
w
)
:
=
arg
max
∥
ϵ
∥
p
≤
ρ
L
S
(
w
+
ϵ
)
≈
arg
max
∥
ϵ
∥
p
≤
ρ
L
S
(
w
)
+
ϵ
T
∇
w
L
S
(
w
)
=
arg
max
∥
ϵ
∥
p
≤
ρ
ϵ
T
∇
w
L
S
(
w
)
.
\epsilon^* (w) := \mathop{\arg \max} \limits_{\|\epsilon\|_p\le \rho} L_{\mathcal{S}}(w + \epsilon) \approx \mathop{\arg \max} \limits_{\|\epsilon\|_p\le \rho} L_{\mathcal{S}}(w) + \epsilon^T \nabla_w L_{\mathcal{S}}(w) = \mathop{\arg \max} \limits_{\|\epsilon\|_p\le \rho} \epsilon^T \nabla_w L_{\mathcal{S}}(w).
ϵ∗(w):=∥ϵ∥p≤ρargmaxLS(w+ϵ)≈∥ϵ∥p≤ρargmaxLS(w)+ϵT∇wLS(w)=∥ϵ∥p≤ρargmaxϵT∇wLS(w).
就是一个对偶范数的问题.
虽然
ϵ
∗
(
w
)
\epsilon^*(w)
ϵ∗(w)实际上是和
w
w
w有关的, 但是在实际中只是当初普通的量带入, 这样就不用计算二阶导数了, 即
∇
w
L
S
S
A
M
(
w
)
≈
∇
w
L
S
(
w
)
∣
w
+
ϵ
^
(
w
)
.
\nabla_w L_{\mathcal{S}}^{SAM}(w) \approx \nabla_w L_{\mathcal{S}}(w) |_{w + \hat{\epsilon}(w)}.
∇wLSSAM(w)≈∇wLS(w)∣w+ϵ^(w).
实验结果非常好, 不仅能够提高普通的正确率, 在标签受到污染的情况下也能有很好的鲁棒性.