虚拟对抗训练(VAT)原理和代码解析
微软在ACL20发表了一篇Adversarial Training for Large Neural Language Models,对应的代码ALUM,这是一篇首次在大规模语料做对抗式训练的语言模型研究,提出了ALUM通用的对抗式训练的算法,并且在当前预训练模型上取得SOTA。此研究目的是解决当前的预训练模型(文中用BERT和ROBERT)泛化性和鲁棒性不足的,并且当前对抗训练虽然可以增强鲁棒性,但会损害泛化性的问题。作者还指出ALUM可以在预训练和下游任务都可以使用。
预备知识
此模型是一种半监督学习的模型,相比于其他对抗式学习不同之处,例如FGSM、FGM、PGD等,对于ALUM是加入了无标签数据去优化模型参数。所以了解其他的对抗学习之后,再看看论文发现原理不会很难,以下列出几点需要提前掌握的知识点。
对抗式学习(FGSM、FGM、PGD等)
强烈推荐看【炼丹技巧】功守道:NLP中的对抗训练 + PyTorch实现这篇blog。简单说对抗式训练是做防御和攻击的训练过程,即在输入 x x x上加入一个扰动 r r r, r r r是利用模型Loss对于 x x x的梯度加上正则化得到,然后利用加入扰动的x进入模型,再进行一次训练。
DL散度Loss
DL散度是量化两种概率分布P和Q之间差异的方式 :
D
(
p
∣
∣
q
)
=
∑
p
(
x
i
)
∗
(
l
o
g
(
p
(
x
i
)
−
l
o
g
q
(
x
i
)
)
D(p||q)=\sum p(x_i)*(log(p(x_i)-logq(x_i))
D(p∣∣q)=∑p(xi)∗(log(p(xi)−logq(xi))
在论文中
p
p
p是实际样本输入预训练模型输出
l
o
g
i
t
s
logits
logits,
q
q
q是指对抗样本输入预训练模型后输出
a
d
v
_
l
o
g
i
t
s
adv\_logits
adv_logits,所以这里得到模型其中的一部分Loss。
p = torch.tensor([[0.7, 0.2, 0.1], [0.2,0.2, 0.6], [0.3, 0.2, 0.5]])
q = torch.tensor([[0.6, 0.3, 0.1], [0.2,0.2, 0.6], [0.3, 0.1, 0.6]])
torch.nn.functional.kl_div(q.log_softmax(dim=-1), p.softmax(dim=-1)
模型过程
论文中给出了具体的算法过程,如下:
大概说一下具体的参数和步骤,首先先说参数:
参数 | 描述 | 取值 |
---|---|---|
T | epoch | - |
K | 要做多少次扰动更新 | 理论越多效果越好,但是就expensive,论文中 K=1 |
∏ \prod ∏ | 正则化方法 | L0,L1,L2中选择 |
α \alpha α | 增强对抗学习的比例 | 预训练为10,下游任务为1 |
η \eta η | 扰动的学习率 | 1 × 1 0 − 3 1 × 10^{−3} 1×10−3 |
τ \tau τ | 全局学习率 | 1 × 1 0 − 5 1 × 10^{−5} 1×10−5 |
θ \theta θ | 模型参数 | - |
对于模型算法过程确实是不复杂,所以打算按照图片中的行号一行行说明:
- 循环epoch
- 循环数据集,每次产生一个batch_size大小的数据
- 生成一个扰动 δ \delta δ, δ \delta δ服从均作为0,方差为1
- 循环K次,理论K越大效果越好,实际使用K=1,减少计算量
- 计算实际输入的输出和对抗样本的实际输入的DL散度Loss,并计算梯度
- 扰动正则化
- 循环K次结束
- 计算模型的Loss(带标签数据losss+虚拟对抗Loss)计算梯度更新参数,α是增强对抗学习的比例,预训练设置为10,下游任务设置为1。
下图展示了虚拟对抗训练的过程,我们可以看出对抗样本是由扰动加到输入的Embed空间上得到,然后原始输入和对抗样本分别计算得到两个输出,原始输出与标签计算得到Loss,对抗样本需要和原始输出计算得到Adv Loss,最后我们需要的是最小化Loss,最大化Adv Loss,最后我们的目标是:
m
i
n
θ
E
(
x
,
y
)
D
[
l
(
f
(
x
;
θ
)
,
y
)
+
α
m
a
x
δ
l
(
f
(
x
+
δ
;
θ
)
,
f
(
x
;
θ
)
)
]
min_{\theta}E_{(x,y) D}[l(f(x;\theta),y)+αmax_{δ}l(f(x+δ;\theta),f(x; θ))]
minθE(x,y)D[l(f(x;θ),y)+αmaxδl(f(x+δ;θ),f(x;θ))]
代码干货
代码已经开源,项目是以robert进行了实验,我们只需要关心 adv_masked_lm.py 和 adv_masked_lm_task.py 这两个文件。
adv_masked_lm.py:虚拟对抗训练代码
adv_masked_lm_task.py:训练mlm模型,其中包括超参数的设置
虚拟对抗训练代码
本人使用中是剥离出adv_masked_lm.py,方便能在torch中使用。
import torch
import torch.nn.functional as F
def kl(inputs, targets, reduction="sum"):
"""
计算kl散度
inputs:tensor,logits
targets:tensor,logits
"""
loss = F.kl_div(F.log_softmax(inputs, dim=-1),
F.softmax(targets, dim=-1),
reduction=reduction)
return loss
def adv_project(grad, norm_type='inf', eps=1e-6):
"""
L0,L1,L2正则,对于扰动计算
"""
if norm_type == 'l2':
direction = grad / (torch.norm(grad, dim=-1, keepdim=True) + eps)
elif norm_type == 'l1':
direction = grad.sign()
else:
direction = grad / (grad.abs().max(-1, keepdim=True)[0] + eps)
return direction
def virtual_adversarial_training(model, hidden_status, token_type_ids, attention_mask, logits):
"""
虚拟对抗式训练
model: nn.Module, 模型
hidden_status:tensor,input的embedded表示
token_type_ids:tensor,bert中的token_type_ids,A B 句子
attention_mask:tensor,bert中的attention_mask,对paddding mask
logits:tensor,input的输出
"""
embed = hidden_status
# 初始扰动 r
noise = embed.data.new(embed.size()).normal_(0, 1) * 1e-5
noise.requires_grad_()
# x + r
new_embed = embed.data.detach() + noise
adv_output = model(inputs_embeds=new_embed,
token_type_ids=token_type_ids,
attention_mask=attention_mask)
adv_logits = adv_output[0]
adv_loss = kl(adv_logits, logits.detach(), reduction="batchmean")
delta_grad, = torch.autograd.grad(adv_loss, noise, only_inputs=True)
norm = delta_grad.norm()
# 梯度消失,退出
if torch.isnan(norm) or torch.isinf(norm):
return None
# line 6 inner sum
noise = noise + delta_grad * 1e-3
# line 6 projection
noise = adv_project(noise, norm_type='l2', eps=1e-6)
new_embed = embed.data.detach() + noise
new_embed = new_embed.detach()
# 在进行一次训练
adv_output = model(inputs_embeds=new_embed,
token_type_ids=token_type_ids,
attention_mask=attention_mask)
adv_logits = adv_output[0]
adv_loss_f = kl(adv_logits, logits.detach())
adv_loss_b = kl(logits, adv_logits.detach())
# 在预训练时设置为10,下游任务设置为1
adv_loss = (adv_loss_f + adv_loss_b) * 1
return adv_loss
使用方法
以下是使用nezha-bert训练的调用代码:
for input_ids, token_type_ids, attention_mask, output_ids, _ in tqdm(train_loader):
step += 1
input_ids = input_ids.long().to(device)
token_type_ids = token_type_ids.long().to(device)
attention_mask = attention_mask.long().to(device)
output_ids = output_ids.long().to(device)
optimizer.zero_grad()
# 混合精度计算,训练速度接近提高了1/2
with autocast():
output = model(input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
labels=output_ids)
loss = output[0]
if args.use_adv == 'vat':
logits = output[1]
hidden_status = output[2][0]
adv_loss = virtual_adversarial_training(model, hidden_status, token_type_ids, attention_mask, logits)
if adv_loss:
train_adv_loss += adv_loss
loss = adv_loss * 10 + loss
train_loss += loss
loss.backward()
optimizer.step()
实验和结论
本人再使用nezha和vat的训练相似度计算模型auc能达到97.2%,acc达到91%,对比了没用使用vat,最终auc没能上97%,acc在89%左右,且训练过程中loss波动较大,所以证明了vat在训练过程中是有效的。