import time

import torch
from torch.cuda.amp import autocast, GradScaler

# CFG (config), AverageMeter and timeSince are assumed to be defined elsewhere in the post.


class AWP:
    def __init__(
        self,
        model,
        optimizer,
        adv_param='weight',
        adv_lr=1,
        adv_eps=0.2,
        start_step=0,
        adv_step=1,
        scaler=None
    ):
        self.model = model            # model whose weights are perturbed
        self.optimizer = optimizer    # optimizer
        self.adv_param = adv_param    # which parameters to attack (name substring)
        self.adv_lr = adv_lr          # AWP learning rate
        self.adv_eps = adv_eps        # AWP perturbation size
        self.start_step = start_step  # epoch at which AWP starts
        self.adv_step = adv_step      # number of AWP inner steps
        self.backup = {}              # backup of the original parameters
        self.backup_eps = {}          # backup of the allowed perturbation range per parameter
        self.scaler = scaler          # AMP gradient scaler
    def attack_backward(self, batch, epoch):
        if (self.adv_lr == 0) or (epoch < self.start_step):
            return None
        self._save()                                # back up the original parameters
        for i in range(self.adv_step):              # AWP inner steps
            self._attack_step()                     # apply the adversarial perturbation
            with autocast(enabled=CFG.apex):
                input_ids = batch[0].to(CFG.device)
                attention_mask = batch[1].to(CFG.device)
                labels = batch[2].to(CFG.device)
                output = self.model(input_ids, attention_mask, labels=labels)  # forward pass on the perturbed weights
                adv_loss = output.loss              # mean adversarial loss
            self.optimizer.zero_grad()              # clear the gradients before the adversarial backward pass
            self.scaler.scale(adv_loss).backward()  # backpropagate the adversarial loss
        self._restore()                             # restore the original parameters
    def _attack_step(self):
        e = 1e-6
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None and self.adv_param in name:
                # Unlike FGM, the AWP perturbation is scaled by both the gradient norm and the
                # parameter norm: r = adv_lr * grad / (||grad|| + e) * (||w|| + e)
                norm1 = torch.norm(param.grad)           # gradient norm
                norm2 = torch.norm(param.data.detach())  # parameter norm
                if norm1 != 0 and not torch.isnan(norm1):
                    r_at = self.adv_lr * param.grad / (norm1 + e) * (norm2 + e)
                    param.data.add_(r_at)
                    # keep the parameters inside the allowed perturbation range
                    param.data = torch.min(
                        torch.max(param.data, self.backup_eps[name][0]),  # lower bound
                        self.backup_eps[name][1]                          # upper bound
                    )
    def _save(self):
        '''
        1. Back up the original parameters so they can be restored later.
        2. Record the allowed perturbation range for each parameter.
        '''
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None and self.adv_param in name:
                if name not in self.backup:
                    self.backup[name] = param.data.clone()          # back up the original parameters
                    grad_eps = self.adv_eps * param.abs().detach()  # perturbation range
                    self.backup_eps[name] = (                       # lower / upper bound of the perturbation
                        self.backup[name] - grad_eps,
                        self.backup[name] + grad_eps,
                    )
    def _restore(self):
        '''
        Restore the original (unperturbed) parameters.
        '''
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}
        self.backup_eps = {}


def train_fn_awp(train_loader, model, optimizer, epoch, scheduler, device):
    model.train()
    scaler = GradScaler(enabled=CFG.apex)
    losses = AverageMeter()
    start = end = time.time()
    global_step = 0
    awp = AWP(model, optimizer, adv_lr=CFG.adv_lr, adv_eps=CFG.adv_eps,
              start_step=CFG.start_awp_epoch, scaler=scaler)
    for step, batch in enumerate(train_loader):
        label = batch[2].to(device)
        mask = batch[1].to(device)
        input_ids = batch[0].to(device)
        batch_size = label.size(0)
        with autocast(enabled=CFG.apex):  # mixed-precision (fp16) forward pass; optional
            output = model(input_ids, mask, labels=label)
            loss = output.loss
        losses.update(loss.item(), batch_size)
        scaler.scale(loss).backward()
        if epoch >= CFG.start_awp_epoch:
            awp.attack_backward(batch, epoch)  # AWP adversarial training
        # Gradient step: update the parameters (when AWP is active, attack_backward zeroes the
        # normal-loss gradients and replaces them with the adversarial-loss gradients before this step).
        scaler.step(optimizer)   # optimizer step through the scaler
        scaler.update()          # update the scaler
        scheduler.step()
        optimizer.zero_grad()    # clear the gradients for the next batch
        global_step += 1
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  'LR: {lr:.8f} '
                  .format(epoch + 1, step, len(train_loader),
                          remain=timeSince(start, float(step + 1) / len(train_loader)),
                          loss=losses,
                          lr=scheduler.get_last_lr()[0]))
    return losses.avg
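
For reference, here is a minimal sketch of how train_fn_awp might be driven end to end. Everything below is illustrative rather than part of the original code: it assumes a HuggingFace classification model (whose forward call returns an object with a .loss field), toy tensors standing in for a real tokenized dataset, and a small CFG with made-up values mirroring the fields the code above reads; AverageMeter and timeSince are assumed to be defined elsewhere in the post.

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup

class CFG:  # illustrative config; the post defines its own CFG elsewhere
    apex = torch.cuda.is_available()  # use mixed precision only when CUDA is available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    adv_lr = 1.0
    adv_eps = 0.01
    start_awp_epoch = 1
    print_freq = 50
    epochs = 3

# Toy tensors standing in for a tokenized dataset: batch[0]=input_ids, batch[1]=mask, batch[2]=labels.
input_ids = torch.randint(0, 30522, (64, 128))
attention_mask = torch.ones_like(input_ids)
labels = torch.randint(0, 2, (64,))
train_loader = DataLoader(TensorDataset(input_ids, attention_mask, labels), batch_size=8, shuffle=True)

model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2).to(CFG.device)
optimizer = AdamW(model.parameters(), lr=2e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=len(train_loader) * CFG.epochs)

for epoch in range(CFG.epochs):
    avg_loss = train_fn_awp(train_loader, model, optimizer, epoch, scheduler, CFG.device)
    print(f'epoch {epoch + 1}: average train loss {avg_loss:.4f}')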