Adversarial Training: FGM and AWP
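
FGM (Fast Gradient Method) perturbs only the embedding weights with a single gradient-normalized step, while AWP (Adversarial Weight Perturbation) perturbs every matched weight tensor and additionally scales the step by the parameter norm, clamping the result to a budget around the original weights. For comparison with the AWP class below, here is a minimal sketch of the commonly used FGM recipe (not part of the original post; `emb_name='word_embeddings'` is an assumption that must match your model's embedding parameter name):

import torch

class FGM:
    def __init__(self, model, eps=1.0):
        self.model = model
        self.eps = eps
        self.backup = {}

    def attack(self, emb_name='word_embeddings'):
        # Perturb the embedding weights along the gradient direction: r = eps * g / ||g||
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name and param.grad is not None:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(self.eps * param.grad / norm)

    def restore(self, emb_name='word_embeddings'):
        # Put the original embedding weights back
        for name, param in self.model.named_parameters():
            if name in self.backup and emb_name in name:
                param.data = self.backup[name]
        self.backup = {}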

import time
import torch
from torch.cuda.amp import autocast, GradScaler

class AWP:
    '''Adversarial Weight Perturbation: perturb the model weights (not the
    inputs) along the gradient direction, scaled by the parameter norm.'''
    def __init__(
        self,
        model,
        optimizer,
        adv_param='weight',
        adv_lr=1,
        adv_eps=0.2,
        start_step=0,
        adv_step=1,
        scaler=None
    ):
        self.model = model          # model to perturb
        self.optimizer = optimizer  # optimizer
        self.adv_param = adv_param  # substring selecting which parameters to attack
        self.adv_lr = adv_lr        # AWP step size
        self.adv_eps = adv_eps      # AWP perturbation budget (relative to |w|)
        self.start_step = start_step  # epoch at which AWP starts
        self.adv_step = adv_step    # number of attack iterations per batch
        self.backup = {}            # backup of the original parameters
        self.backup_eps = {}        # per-parameter perturbation bounds
        self.scaler = scaler        # AMP gradient scaler

    def attack_backward(self, batch, epoch):
        if (self.adv_lr == 0) or (epoch < self.start_step):
            return None

        self._save()                    # back up the original parameters
        for i in range(self.adv_step):  # run adv_step attack iterations
            self._attack_step()         # perturb the weights
            with autocast(enabled=CFG.apex):
                input_ids = batch[0].to(CFG.device)
                attention_mask = batch[1].to(CFG.device)
                labels = batch[2].to(CFG.device)
                output = self.model(input_ids, attention_mask, labels=labels)  # forward pass on perturbed weights
                adv_loss = output.loss  # mean loss
            # Zero out the previously accumulated gradient so the optimizer
            # step later uses only the adversarial gradient.
            self.optimizer.zero_grad()
            self.scaler.scale(adv_loss).backward()  # backprop the adversarial loss

        self._restore()  # restore the original parameters

    def _attack_step(self):
        e = 1e-6
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None and self.adv_param in name:
                # Unlike FGM, the size of the AWP perturbation accounts for
                # both the gradient norm and the parameter norm.
                norm1 = torch.norm(param.grad)           # gradient norm
                norm2 = torch.norm(param.data.detach())  # parameter norm
                if norm1 != 0 and not torch.isnan(norm1):
                    # r = adv_lr * g / (||g|| + e) * (||w|| + e)
                    r_at = self.adv_lr * param.grad / (norm1 + e) * (norm2 + e)
                    param.data.add_(r_at)
                    # Clamp the perturbed weights back into the allowed range
                    param.data = torch.min(
                        torch.max(param.data, self.backup_eps[name][0]),  # lower bound
                        self.backup_eps[name][1]                          # upper bound
                    )
                    
    def _save(self):
        '''
        Back up the original parameters (so they can be restored later)
        and record the allowed perturbation range for each of them.
        '''
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None and self.adv_param in name:
                if name not in self.backup:
                    self.backup[name] = param.data.clone()          # back up the original parameter
                    grad_eps = self.adv_eps * param.abs().detach()  # perturbation budget
                    self.backup_eps[name] = (                       # lower / upper bounds
                        self.backup[name] - grad_eps,
                        self.backup[name] + grad_eps,
                    )

    def _restore(self):
        '''
        Restore the original parameters.
        '''
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}
        self.backup_eps = {}
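
To make the perturbation budget concrete: with adv_eps = 0.2, each selected weight w may move at most 0.2 * |w| away from its original value. A quick toy check (illustrative, not from the original post):

import torch

w = torch.tensor([1.0, -2.0, 0.5])  # original weights
adv_eps = 0.2
grad_eps = adv_eps * w.abs()        # same formula as in _save
lower, upper = w - grad_eps, w + grad_eps
print(lower)  # tensor([ 0.8000, -2.4000,  0.4000])
print(upper)  # tensor([ 1.2000, -1.6000,  0.6000])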

def train_fn_awp(train_loader, model, optimizer, epoch, scheduler, device):
    model.train()
    scaler = GradScaler(enabled=CFG.apex)
    losses = AverageMeter()
    start = end = time.time()
    global_step = 0
    awp = AWP(model, optimizer, adv_lr=CFG.adv_lr, adv_eps=CFG.adv_eps,
              start_step=CFG.start_awp_epoch, scaler=scaler)
    for step, batch in enumerate(train_loader):
        label = batch[2].to(device)
        mask = batch[1].to(device)
        input_ids = batch[0].to(device)
        batch_size = label.size(0)
        with autocast(enabled=CFG.apex):  # fp16 mixed precision; optional
            output = model(input_ids, mask, labels=label)
        loss = output.loss
        losses.update(loss.item(), batch_size)
        scaler.scale(loss).backward()
        if epoch >= CFG.start_awp_epoch:
            awp.attack_backward(batch, epoch)  # AWP adversarial training
        # Update the parameters. Note that when AWP runs, attack_backward
        # zeroes the clean-loss gradient first, so this step is driven by
        # the adversarial gradient alone.
        scaler.step(optimizer)  # optimizer step
        scaler.update()         # update the scaler
        scheduler.step()
        optimizer.zero_grad()   # clear gradients so they don't accumulate across steps
        global_step += 1

        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  'LR: {lr:.8f}  '
                  .format(epoch + 1, step, len(train_loader),
                          remain=timeSince(start, float(step + 1) / len(train_loader)),
                          loss=losses,
                          lr=scheduler.get_last_lr()[0]))
    return losses.avg
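
train_fn_awp relies on a few helpers that are not defined in this post: CFG, AverageMeter, and timeSince. A minimal sketch of what they might look like (all names and values here are assumptions, included only so the function above can run; the CFG fields mirror how the code uses them):

import math
import time

class CFG:                   # assumed config object; values are illustrative
    apex = True              # enable fp16 mixed precision
    device = 'cuda'
    adv_lr = 1.0             # AWP step size
    adv_eps = 0.01           # AWP perturbation budget
    start_awp_epoch = 1      # epoch at which AWP kicks in
    print_freq = 50

class AverageMeter:          # running average of the loss, Kaggle-style
    def __init__(self):
        self.val = self.avg = self.sum = self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def timeSince(since, percent):  # elapsed / estimated remaining time
    s = time.time() - since
    es = s / percent            # estimated total time
    def as_minutes(t):
        m = math.floor(t / 60)
        return f'{m:.0f}m {t - m * 60:.0f}s'
    return f'{as_minutes(s)} (remain {as_minutes(es - s)})'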

