Virtual Adversarial Training文章解读+算法流程+核心代码详解

Virtual Adversarial Training

本博客仅做算法流程疏导,具体细节请参见原文

原文

查看原文请点这里

Github代码

Github代码请点这里

解读

对比Adversarial Training和VAT

VAT(Virtual Adversarial Training)和adversarial training类似。对原始训练样本添加一个比较小的扰动,会大概率使分类器分类出现错误,而我们一般希望分类器将原始样本和添加一个较小扰动的样本(加噪版本)分为同一类别,所以将扰动版本的数据也作为训练样本添加进训练,这样就增加了分类器的泛化能力。

传统的adversarial training 的扰动方向一般通过损失函数确定,即取损失函数上升的方向添加一个扰动。无标记样本没有标签,就无法算损失函数,故传统方法不适用,所以一般的adversarial training仅在监督学习中使用较多,而virtual adversarial training的创新在于能在无标记样本上实现扰动的计算,因为没用使用标签进行运算,而是用模型预测的结果替代标签,类似于persudo label,这就是virtual的含义

Adversarial Training

adversarial training的数学表达如下,其中样本及标记 ( x , y ) (x,y) (x,y),当前epoch模型的参数 θ \theta θ:
损失函数 J ( θ ) = 1 N ∑ i = 1 N L ( x , θ ) J(\theta)=\frac{1}{N}\sum^{N}_{i=1}L(x,\theta) J(θ)=N1i=1NL(x,θ)
其中,单项损失计算表达式为: L ( x , θ ) = D ( y , p ( y ∣ x + r , θ ) ) L(x,\theta)=D(y,p(y|x+r,\theta)) L(x,θ)=D(y,p(yx+r,θ))
扰动方向 r = a r g m a x ∣ r ∣ < ξ D ( y , p ( y ∣ x + r , θ ) ) r=argmax_{|r|<\xi}D(y,p(y|x+r,\theta)) r=argmaxr<ξD(y,p(yx+r,θ))

简单叙述为:找到一个扰动 r r r​,且 r r r​的大小受限,即 ∣ r ∣ < ξ |r|<\xi r<ξ​,使其损失函数 L ( x , θ ) = D ( y , p ( y ∣ x + r , θ ) ) L(x,\theta)=D(y,p(y|x+r,\theta)) L(x,θ)=D(y,p(yx+r,θ))​​取最大值​,即在此 r r r​下上升最多。

VAT

同样形式的,virtual adversarial training 的数学表达式如下,其中其中样本及标记 ( x , y ) (x,y) (x,y),当前epoch模型的参数 θ \theta θ,前一个epoch的模型参数为 θ ^ \hat{\theta} θ^
损失函数同上形式: J ( θ ) = 1 N ∑ i = 1 N L ( x , θ ) J(\theta)=\frac{1}{N}\sum^N_{i=1}L(x,\theta) J(θ)=N1i=1NL(x,θ)
单项损失表达式不同(LDS称为局部平滑度): L ( x , θ ) = D ( p ( y ∣ x , θ ^ ) , p ( y ∣ x + r , θ ) ) = L D S ( x , θ ) L(x,\theta)=D(p(y|x,\hat\theta),p(y|x+r,\theta))=LDS(x,\theta) L(x,θ)=D(p(yx,θ^),p(yx+r,θ))=LDS(x,θ)
扰动方向 r = a r g m a x ∣ r ∣ < ξ D ( p ( y ∣ x , θ ) , p ( y ∣ x + r , θ ) ) r=argmax_{|r|<\xi}D(p(y|x,\theta),p(y|x+r,\theta)) r=argmaxr<ξD(p(yx,θ),p(yx+r,θ))

简单叙述为:找到一个扰动 r r r​,且 r r r​​的大小受限,​即 ∣ r ∣ < ξ |r|<\xi r<ξ​​,使其损失函数 L D S ( x , θ ) LDS(x,\theta) LDS(x,θ)​取的最大值,即在此 r r r​下上升最多。

代码详解

代码核心就一个VAT_Loss的计算。整个框架的Loss=Classfier_Loss + VAT_Loss。其中Classfier_Loss损失函数为一般的监督网络的损失函数。VAT_Loss计算如下:

def vat_loss(model, ul_x, ul_y, xi=1e-6, eps=2.5, num_iters=1):
    # find r_adv
    d = torch.Tensor(ul_x.size()).normal_()
    for i in range(num_iters):
        d = xi *_l2_normalize(d)
        d = Variable(d.cuda(), requires_grad=True)
        y_hat = model(ul_x + d)
        delta_kl = kl_div_with_logit(ul_y.detach(), y_hat)
        delta_kl.backward()
        d = d.clone().cpu()
        model.zero_grad()
    d = _l2_normalize(d)
    d = Variable(d.cuda())
    r_adv = eps * d
    # compute lds
    y_hat = model(ul_x + r_adv.detach())
    delta_kl = kl_div_with_logit(ul_y.detach(), y_hat)
    return delta_kl

其中对r_adv的计算采用的是一种快速计算方法。具体理论请查阅原文

v_loss = vat_loss(model, inputs_All, logits_All, eps=args.epsilon)
loss = v_loss+ce_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()

完整损失函数Loss=Classfier_Loss + VAT_Loss反向梯度传播更新网络即可。

  • 4
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值