概
也算是一种对抗训练吧, 有区别的是构造对抗样本的方式, 以及用的是惩罚项而非仅用对抗样本训练.
主要内容
考虑干净样本
x
x
x和扰动
v
v
v, 则我们自然希望
min
θ
max
∥
v
∥
≤
ϵ
ℓ
(
x
+
v
;
θ
)
−
ℓ
(
x
)
,
(*)
\tag{*} \min_{\theta} \max_{\|v\|\le \epsilon} \ell(x+v;\theta)- \ell(x),
θmin∥v∥≤ϵmaxℓ(x+v;θ)−ℓ(x),(*)
其中
ℓ
\ell
ℓ是分类损失. 注意到右端项的二阶近似为
Q
(
v
;
x
)
:
=
∇
x
ℓ
(
x
)
⋅
v
+
1
2
v
T
∇
x
ℓ
(
x
)
v
.
(3)
\tag{3} Q(v;x):= \nabla_x \ell(x) \cdot v + \frac{1}{2} v^T \nabla_x \ell(x) v.
Q(v;x):=∇xℓ(x)⋅v+21vT∇xℓ(x)v.(3)
故我们可以转而优化此近似项. 当然, 一般的AT方法是用project gradient去逼近右端项, 假设前者
v
Q
=
arg
max
∥
v
∥
p
≤
ϵ
Q
(
v
;
x
)
,
(4)
\tag{4} v_Q = \arg \max_{\|v\|_p \le \epsilon} Q(v;x),
vQ=arg∥v∥p≤ϵmaxQ(v;x),(4)
后者
v
A
=
arg
max
∥
v
∥
p
≤
ϵ
ℓ
(
x
+
v
)
.
v_A = \arg \max_{\|v\|_p \le \epsilon} \ell (x+v).
vA=arg∥v∥p≤ϵmaxℓ(x+v).
那么二者的差距有下面的定理保证
说实话, 这个定理没多大意义.
如果单纯优化(*)没法带来精度, 所以构造一个正则化项
min
θ
E
x
∼
D
[
ℓ
(
x
)
+
r
⋅
ℓ
Q
(
x
)
]
,
\min_{\theta} \mathbb{E}_{x\sim \mathcal{D}} [\ell(x)+r \cdot \ell_Q(x)],
θminEx∼D[ℓ(x)+r⋅ℓQ(x)],
其中
ℓ
Q
(
x
)
=
ℓ
(
x
+
v
Q
)
−
ℓ
(
x
)
\ell_Q(x)=\ell(x+v_Q)-\ell(x)
ℓQ(x)=ℓ(x+vQ)−ℓ(x).
注: 有一个疑问, 按照道理 r ∈ ( 0 , 1 ) r \in (0, 1) r∈(0,1), 可是论文的实验是 ( 0.5 , 1.5 ) (0.5, 1.5) (0.5,1.5), 而且有几个实验挑了的确大于1, 这不就意味着需要 min ( 1 − r ) ℓ ( x ) \min (1-r)\ell(x) min(1−r)ℓ(x), 这不就让分类变差了?
(4)式的求解
作者利用Frank-Wofle (FW) 去求解(4)式, 即
{
s
k
:
=
arg
max
∥
s
∥
p
≤
ϵ
s
⋅
∇
v
Q
(
v
k
)
v
k
+
1
:
=
(
1
−
γ
k
)
v
k
+
γ
k
s
k
,
(7)
\tag{7} \left \{ \begin{array}{l} s^k := \arg \max_{\|s\|_p\le \epsilon} \: s \cdot \nabla_v Q(v^k)\\ v^{k+1} := (1-\gamma^k) v^k + \gamma^k s^k, \end{array} \right.
{sk:=argmax∥s∥p≤ϵs⋅∇vQ(vk)vk+1:=(1−γk)vk+γksk,(7)
其中
v
k
=
2
k
+
2
v^k=\frac{2}{k+2}
vk=k+22,
v
0
=
ϵ
g
/
∥
g
∥
p
,
g
=
∇
x
ℓ
(
x
)
v^0=\epsilon g/\|g\|_p, \: g=\nabla_x \ell(x)
v0=ϵg/∥g∥p,g=∇xℓ(x). (7)式的第一步式可以显示求解的
s
k
=
P
F
W
(
v
k
;
p
)
=
α
⋅
s
g
n
(
∇
v
Q
(
v
k
)
i
)
∣
∇
v
Q
(
v
k
)
i
∣
p
/
q
,
(8)
\tag{8} s^k=P_{FW}(v^k;p)=\alpha \cdot \mathrm{sgn} (\nabla_v Q(v^k)_i) |\nabla_v Q(v^k)_i|^{p/q},
sk=PFW(vk;p)=α⋅sgn(∇vQ(vk)i)∣∇vQ(vk)i∣p/q,(8)
其中
α
\alpha
α使得
∥
s
k
∥
p
=
ϵ
\|s^k\|_p=\epsilon
∥sk∥p=ϵ,
∣
x
∣
m
|x|^m
∣x∣m是逐项的.
因为
∇
x
Q
(
v
)
=
∇
x
ℓ
(
x
)
+
∇
x
2
ℓ
(
x
)
v
,
(9)
\tag{9} \nabla_x Q(v) = \nabla_x \ell(x) + \nabla^2_x \ell(x)v,
∇xQ(v)=∇xℓ(x)+∇x2ℓ(x)v,(9)
而计算hessian矩阵需要大量的计算, 故采用差分逼近
FE:
∇
x
2
ℓ
(
x
)
v
≈
∇
x
ℓ
(
x
+
h
v
)
−
∇
x
ℓ
(
x
)
]
h
,
(10)
\tag{10} \nabla_x^2 \ell(x)v \approx \frac{\nabla_x \ell(x+hv)-\nabla_x \ell(x)]}{h},
∇x2ℓ(x)v≈h∇xℓ(x+hv)−∇xℓ(x)],(10)
CD:
∇
x
2
ℓ
(
x
)
v
≈
∇
x
ℓ
(
x
+
h
v
)
−
∇
x
ℓ
(
x
−
h
)
]
2
h
.
(11)
\tag{11} \nabla_x^2 \ell(x)v \approx \frac{\nabla_x \ell(x+hv)-\nabla_x \ell(x-h)]}{2h}.
∇x2ℓ(x)v≈2h∇xℓ(x+hv)−∇xℓ(x−h)].(11)
超参数
(
h
,
r
)
(h, r)
(h,r).
CIFAR10:
L
2
L_2
L2: FE(3): (1.15, 1.05) , CD(3): (0.95, 0.999);
L
∞
L_{\infty}
L∞: FE(3): (1.05, 1.05), CD(3): (0.95, 1.15).