虚拟对抗训练(VAT)原理和代码解析

本文介绍了微软在ACL20发布的Adversarial Training for Large Neural Language Models,该研究首次在大规模语料上进行对抗式训练,提出ALUM算法,解决了预训练模型的泛化性和鲁棒性问题。ALUM在预训练和下游任务中都表现出色,其原理包括对抗式学习、DL散度Loss等。代码实现中详细展示了虚拟对抗训练过程,并给出了使用示例,证明了VAT的有效性。
摘要由CSDN通过智能技术生成

虚拟对抗训练(VAT)原理和代码解析

微软在ACL20发表了一篇Adversarial Training for Large Neural Language Models,对应的代码ALUM,这是一篇首次在大规模语料做对抗式训练的语言模型研究,提出了ALUM通用的对抗式训练的算法,并且在当前预训练模型上取得SOTA。此研究目的是解决当前的预训练模型(文中用BERT和ROBERT)泛化性和鲁棒性不足的,并且当前对抗训练虽然可以增强鲁棒性,但会损害泛化性的问题。作者还指出ALUM可以在预训练和下游任务都可以使用。

预备知识

此模型是一种半监督学习的模型,相比于其他对抗式学习不同之处,例如FGSM、FGM、PGD等,对于ALUM是加入了无标签数据去优化模型参数。所以了解其他的对抗学习之后,再看看论文发现原理不会很难,以下列出几点需要提前掌握的知识点。

对抗式学习(FGSM、FGM、PGD等)

强烈推荐看【炼丹技巧】功守道:NLP中的对抗训练 + PyTorch实现这篇blog。简单说对抗式训练是做防御和攻击的训练过程,即在输入 x x x上加入一个扰动 r r r r r r是利用模型Loss对于 x x x的梯度加上正则化得到,然后利用加入扰动的x进入模型,再进行一次训练。

DL散度Loss

DL散度是量化两种概率分布P和Q之间差异的方式 :
D ( p ∣ ∣ q ) = ∑ p ( x i ) ∗ ( l o g ( p ( x i ) − l o g q ( x i ) ) D(p||q)=\sum p(x_i)*(log(p(x_i)-logq(x_i)) D(pq)=p(xi)(log(p(xi)logq(xi))
在论文中 p p p是实际样本输入预训练模型输出 l o g i t s logits logits q q q是指对抗样本输入预训练模型后输出 a d v _ l o g i t s adv\_logits adv_logits,所以这里得到模型其中的一部分Loss。

p = torch.tensor([[0.7, 0.2, 0.1], [0.2,0.2, 0.6], [0.3, 0.2, 0.5]])
q = torch.tensor([[0.6, 0.3, 0.1], [0.2,0.2, 0.6], [0.3, 0.1, 0.6]])
torch.nn.functional.kl_div(q.log_softmax(dim=-1), p.softmax(dim=-1)

模型过程

论文中给出了具体的算法过程,如下:
在这里插入图片描述
大概说一下具体的参数和步骤,首先先说参数:

参数描述取值
Tepoch-
K要做多少次扰动更新理论越多效果越好,但是就expensive,论文中 K=1
∏ \prod 正则化方法L0,L1,L2中选择
α \alpha α增强对抗学习的比例预训练为10,下游任务为1
η \eta η扰动的学习率 1 × 1 0 − 3 1 × 10^{−3} 1×103
τ \tau τ全局学习率 1 × 1 0 − 5 1 × 10^{−5} 1×105
θ \theta θ模型参数-

对于模型算法过程确实是不复杂,所以打算按照图片中的行号一行行说明:

  1. 循环epoch
  2. 循环数据集,每次产生一个batch_size大小的数据
  3. 生成一个扰动 δ \delta δ, δ \delta δ服从均作为0,方差为1
  4. 循环K次,理论K越大效果越好,实际使用K=1,减少计算量
  5. 计算实际输入的输出和对抗样本的实际输入的DL散度Loss,并计算梯度
  6. 扰动正则化
  7. 循环K次结束
  8. 计算模型的Loss(带标签数据losss+虚拟对抗Loss)计算梯度更新参数,α是增强对抗学习的比例,预训练设置为10,下游任务设置为1。

下图展示了虚拟对抗训练的过程,我们可以看出对抗样本是由扰动加到输入的Embed空间上得到,然后原始输入和对抗样本分别计算得到两个输出,原始输出与标签计算得到Loss,对抗样本需要和原始输出计算得到Adv Loss,最后我们需要的是最小化Loss,最大化Adv Loss,最后我们的目标是:
m i n θ E ( x , y ) D [ l ( f ( x ; θ ) , y ) + α m a x δ l ( f ( x + δ ; θ ) , f ( x ; θ ) ) ] min_{\theta}E_{(x,y) D}[l(f(x;\theta),y)+αmax_{δ}l(f(x+δ;\theta),f(x; θ))] minθE(x,y)D[l(f(x;θ),y)+αmaxδl(f(x+δ;θ),f(x;θ))]
虚拟对抗训练的模型流程

代码干货

代码已经开源,项目是以robert进行了实验,我们只需要关心 adv_masked_lm.py 和 adv_masked_lm_task.py 这两个文件。

adv_masked_lm.py:虚拟对抗训练代码
adv_masked_lm_task.py:训练mlm模型,其中包括超参数的设置
在这里插入图片描述

虚拟对抗训练代码

本人使用中是剥离出adv_masked_lm.py,方便能在torch中使用。

import torch
import torch.nn.functional as F


def kl(inputs, targets, reduction="sum"):
	"""
	计算kl散度
	inputs:tensor,logits
	targets:tensor,logits
	"""
    loss = F.kl_div(F.log_softmax(inputs, dim=-1),
                    F.softmax(targets, dim=-1),
                    reduction=reduction)
    return loss


def adv_project(grad, norm_type='inf', eps=1e-6):
	"""
	L0,L1,L2正则,对于扰动计算
	"""
    if norm_type == 'l2':
        direction = grad / (torch.norm(grad, dim=-1, keepdim=True) + eps)
    elif norm_type == 'l1':
        direction = grad.sign()
    else:
        direction = grad / (grad.abs().max(-1, keepdim=True)[0] + eps)
    return direction


def virtual_adversarial_training(model, hidden_status, token_type_ids, attention_mask, logits):
	"""
	虚拟对抗式训练
	model: nn.Module, 模型
	hidden_status:tensor,input的embedded表示
	token_type_ids:tensor,bert中的token_type_ids,A B 句子
	attention_mask:tensor,bert中的attention_mask,对paddding mask
	logits:tensor,input的输出
	"""
    embed = hidden_status
    # 初始扰动 r
    noise = embed.data.new(embed.size()).normal_(0, 1) * 1e-5
    noise.requires_grad_()
    # x + r
    new_embed = embed.data.detach() + noise
    adv_output = model(inputs_embeds=new_embed,
                       token_type_ids=token_type_ids,
                       attention_mask=attention_mask)
    adv_logits = adv_output[0]
    adv_loss = kl(adv_logits, logits.detach(), reduction="batchmean")
    delta_grad, = torch.autograd.grad(adv_loss, noise, only_inputs=True)
    norm = delta_grad.norm()

	# 梯度消失,退出
    if torch.isnan(norm) or torch.isinf(norm):
        return None

    # line 6 inner sum
    noise = noise + delta_grad * 1e-3
    # line 6 projection
    noise = adv_project(noise, norm_type='l2', eps=1e-6)
    new_embed = embed.data.detach() + noise
    new_embed = new_embed.detach()
    # 在进行一次训练
    adv_output = model(inputs_embeds=new_embed,
                       token_type_ids=token_type_ids,
                       attention_mask=attention_mask)
    adv_logits = adv_output[0]
    adv_loss_f = kl(adv_logits, logits.detach())
    adv_loss_b = kl(logits, adv_logits.detach())
    # 在预训练时设置为10,下游任务设置为1
    adv_loss = (adv_loss_f + adv_loss_b) * 1

    return adv_loss

使用方法

以下是使用nezha-bert训练的调用代码:

for input_ids, token_type_ids, attention_mask, output_ids, _ in tqdm(train_loader):
    step += 1
    input_ids = input_ids.long().to(device)
    token_type_ids = token_type_ids.long().to(device)
    attention_mask = attention_mask.long().to(device)
    output_ids = output_ids.long().to(device)
    optimizer.zero_grad()
    # 混合精度计算,训练速度接近提高了1/2
    with autocast():
        output = model(input_ids,
                       token_type_ids=token_type_ids,
                       attention_mask=attention_mask,
                       labels=output_ids)
        loss = output[0]
        if args.use_adv == 'vat':
            logits = output[1]
            hidden_status = output[2][0]
            adv_loss = virtual_adversarial_training(model, hidden_status, token_type_ids, attention_mask, logits)
            if adv_loss:
                train_adv_loss += adv_loss
                loss = adv_loss * 10 + loss

    train_loss += loss
    loss.backward()
    optimizer.step()

实验和结论

本人再使用nezha和vat的训练相似度计算模型auc能达到97.2%,acc达到91%,对比了没用使用vat,最终auc没能上97%,acc在89%左右,且训练过程中loss波动较大,所以证明了vat在训练过程中是有效的。

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值