Virtual Adversarial Training
本博客仅做算法流程疏导,具体细节请参见原文
原文
Github代码
解读
对比Adversarial Training和VAT
VAT(Virtual Adversarial Training)和adversarial training类似。对原始训练样本添加一个比较小的扰动,会大概率使分类器分类出现错误,而我们一般希望分类器将原始样本和添加一个较小扰动的样本(加噪版本)分为同一类别,所以将扰动版本的数据也作为训练样本添加进训练,这样就增加了分类器的泛化能力。
传统的adversarial training 的扰动方向一般通过损失函数确定,即取损失函数上升的方向添加一个扰动。无标记样本没有标签,就无法算损失函数,故传统方法不适用,所以一般的adversarial training仅在监督学习中使用较多,而virtual adversarial training的创新在于能在无标记样本上实现扰动的计算,因为没用使用标签进行运算,而是用模型预测的结果替代标签,类似于persudo label,这就是virtual的含义
Adversarial Training
adversarial training的数学表达如下,其中样本及标记
(
x
,
y
)
(x,y)
(x,y),当前epoch模型的参数
θ
\theta
θ:
损失函数:
J
(
θ
)
=
1
N
∑
i
=
1
N
L
(
x
,
θ
)
J(\theta)=\frac{1}{N}\sum^{N}_{i=1}L(x,\theta)
J(θ)=N1∑i=1NL(x,θ)
其中,单项损失计算表达式为:
L
(
x
,
θ
)
=
D
(
y
,
p
(
y
∣
x
+
r
,
θ
)
)
L(x,\theta)=D(y,p(y|x+r,\theta))
L(x,θ)=D(y,p(y∣x+r,θ))
扰动方向:
r
=
a
r
g
m
a
x
∣
r
∣
<
ξ
D
(
y
,
p
(
y
∣
x
+
r
,
θ
)
)
r=argmax_{|r|<\xi}D(y,p(y|x+r,\theta))
r=argmax∣r∣<ξD(y,p(y∣x+r,θ))
简单叙述为:找到一个扰动 r r r,且 r r r的大小受限,即 ∣ r ∣ < ξ |r|<\xi ∣r∣<ξ,使其损失函数 L ( x , θ ) = D ( y , p ( y ∣ x + r , θ ) ) L(x,\theta)=D(y,p(y|x+r,\theta)) L(x,θ)=D(y,p(y∣x+r,θ))取最大值,即在此 r r r下上升最多。
VAT
同样形式的,virtual adversarial training 的数学表达式如下,其中其中样本及标记
(
x
,
y
)
(x,y)
(x,y),当前epoch模型的参数
θ
\theta
θ,前一个epoch的模型参数为
θ
^
\hat{\theta}
θ^:
损失函数同上形式:
J
(
θ
)
=
1
N
∑
i
=
1
N
L
(
x
,
θ
)
J(\theta)=\frac{1}{N}\sum^N_{i=1}L(x,\theta)
J(θ)=N1∑i=1NL(x,θ)
单项损失表达式不同(LDS称为局部平滑度):
L
(
x
,
θ
)
=
D
(
p
(
y
∣
x
,
θ
^
)
,
p
(
y
∣
x
+
r
,
θ
)
)
=
L
D
S
(
x
,
θ
)
L(x,\theta)=D(p(y|x,\hat\theta),p(y|x+r,\theta))=LDS(x,\theta)
L(x,θ)=D(p(y∣x,θ^),p(y∣x+r,θ))=LDS(x,θ)
扰动方向:
r
=
a
r
g
m
a
x
∣
r
∣
<
ξ
D
(
p
(
y
∣
x
,
θ
)
,
p
(
y
∣
x
+
r
,
θ
)
)
r=argmax_{|r|<\xi}D(p(y|x,\theta),p(y|x+r,\theta))
r=argmax∣r∣<ξD(p(y∣x,θ),p(y∣x+r,θ))
简单叙述为:找到一个扰动
r
r
r,且
r
r
r的大小受限,即
∣
r
∣
<
ξ
|r|<\xi
∣r∣<ξ,使其损失函数
L
D
S
(
x
,
θ
)
LDS(x,\theta)
LDS(x,θ)取的最大值,即在此
r
r
r下上升最多。
代码详解
代码核心就一个VAT_Loss的计算。整个框架的Loss=Classfier_Loss + VAT_Loss。其中Classfier_Loss损失函数为一般的监督网络的损失函数。VAT_Loss计算如下:
def vat_loss(model, ul_x, ul_y, xi=1e-6, eps=2.5, num_iters=1):
# find r_adv
d = torch.Tensor(ul_x.size()).normal_()
for i in range(num_iters):
d = xi *_l2_normalize(d)
d = Variable(d.cuda(), requires_grad=True)
y_hat = model(ul_x + d)
delta_kl = kl_div_with_logit(ul_y.detach(), y_hat)
delta_kl.backward()
d = d.clone().cpu()
model.zero_grad()
d = _l2_normalize(d)
d = Variable(d.cuda())
r_adv = eps * d
# compute lds
y_hat = model(ul_x + r_adv.detach())
delta_kl = kl_div_with_logit(ul_y.detach(), y_hat)
return delta_kl
其中对r_adv的计算采用的是一种快速计算方法。具体理论请查阅原文
v_loss = vat_loss(model, inputs_All, logits_All, eps=args.epsilon)
loss = v_loss+ce_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
完整损失函数Loss=Classfier_Loss + VAT_Loss反向梯度传播更新网络即可。