import torch


class FGM():
    def __init__(self, model):
        self.model = model
        self.backup = {}

    def attack(self, epsilon=0.5, emb_name='bert.embeddings.word_embeddings.weight'):
        # Change emb_name to the name of the embedding parameter in your model
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                self.backup[name] = param.data.clone()  # back up the clean embeddings
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    # FGM perturbation: r_adv = epsilon * grad / ||grad||
                    r_at = epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, emb_name='bert.embeddings.word_embeddings.weight'):
        # Change emb_name to the name of the embedding parameter in your model
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
1. Note that emb_name must be changed to your model's embedding parameter name in attack, and the same change must be made in restore. If you forget to update emb_name in restore, the perturbed embeddings are never rolled back and training quality can collapse. A quick way to find the right name is sketched right after these notes.
2. Note that epsilon needs tuning. Sometimes epsilon has to be set larger so the perturbation is strong enough to have an effect; pick the value per task.
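Before wiring FGM in, it helps to confirm what the embedding parameter is actually called. The snippet below is a minimal sketch that just lists parameter names; bert-base-chinese is an illustrative checkpoint, and the name gains a bert. prefix when the encoder is wrapped in a task head such as BertForSequenceClassification:

from transformers import BertModel

model = BertModel.from_pretrained('bert-base-chinese')
# Print every embedding-related parameter name; pass the word-embedding
# weight's name (or a unique substring of it) to FGM as emb_name
for name, param in model.named_parameters():
    if 'embeddings' in name:
        print(name, tuple(param.shape))
# A bare BertModel prints 'embeddings.word_embeddings.weight';
# wrapped models prefix it, e.g. 'bert.embeddings.word_embeddings.weight'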
When using RoBERTa for adversarial training:
class FGM():
    def __init__(self, model):
        self.model = model
        self.backup = {}

    def attack(self, epsilon=0.3, emb_name='roberta.embeddings.word_embeddings.weight'):
        # Change emb_name to the name of the embedding parameter in your model
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                print('fgm attack')
                # The print above is a sanity check that the adversarial
                # perturbation is really being applied
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, emb_name='roberta.embeddings.word_embeddings.weight'):
        # Change emb_name to the name of the embedding parameter in your model
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                print('fgm restore')
                # The print above is a sanity check that the parameters
                # are really being restored
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
Adversarial training loop:

fgm = FGM(model)
for batch_input, batch_label in data:
    # Normal training: backward pass produces the clean gradient
    loss = model(batch_input, batch_label)
    loss.backward()
    # Adversarial training: perturb the embeddings along the gradient,
    # then accumulate the adversarial gradient on top of the clean one
    fgm.attack()
    loss_adv = model(batch_input, batch_label)
    loss_adv.backward()
    fgm.restore()  # restore the original embedding parameters
    optimizer.step()  # gradient descent on the combined gradient
    model.zero_grad()
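For concreteness, here is a minimal end-to-end sketch of the same loop on a HuggingFace sequence-classification model. The checkpoint, dummy batch, and learning rate are illustrative assumptions, not part of the original recipe:

import torch
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
fgm = FGM(model)  # the FGM class defined above

# Dummy batch for illustration only
texts = ['今天天气不错', '这部电影很差']
labels = torch.tensor([1, 0])
batch = tokenizer(texts, padding=True, return_tensors='pt')

model.train()
loss = model(**batch, labels=labels).loss  # clean forward pass
loss.backward()                            # clean gradient
fgm.attack(epsilon=0.5, emb_name='bert.embeddings.word_embeddings.weight')
loss_adv = model(**batch, labels=labels).loss  # forward on perturbed embeddings
loss_adv.backward()                            # accumulate adversarial gradient
fgm.restore(emb_name='bert.embeddings.word_embeddings.weight')
optimizer.step()   # one update with the combined gradient
model.zero_grad()

Note that restore must run before optimizer.step(), so the weight update is applied to the clean embeddings rather than the perturbed ones.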