【模型提分tricks】Adversarial Weight Perturbation(AWP)对抗训练

在如今AI遍及各行各业的情况下,现在不论是搞科研还是做比赛,一个非常重要的问题就是提升模型的robust,让训练出来的模型能更好的泛化到一个从未见过的测试集上,以此减小线上和线下的gap。

对抗训练 Adversarial training

我们知道模型训练是一个ERM(经验风险最小化,Empirical Risk Minimization)的过程,而对抗训练就是为了增强模型的抗干扰能力。

实现

经典训练

for step, batch in enumerate(train_loader):
    inputs, labels = batch
    
    # 将模型的参数梯度初始化为0
    optimizer.zero_grad()
    
    # forward + backward + optimize
    predicts = model(inputs)          # 前向传播计算预测值
    loss = loss_fn(predicts, labels)  # 计算当前损失
    loss.backward()       # 反向传播计算梯度
	loss.backward()
    optimizer.step()                  # 更新所有参数 

加入AWP训练

AWP


class AWP:
    """
    Implements weighted adverserial perturbation
    adapted from: https://www.kaggle.com/code/wht1996/feedback-nn-train/notebook
    """

    def __init__(self, model, optimizer, adv_param="weight", adv_lr=1, adv_eps=0.0001):
        self.model = model
        self.optimizer = optimizer
        self.adv_param = adv_param
        self.adv_lr = adv_lr
        self.adv_eps = adv_eps
        self.backup = {}
        self.backup_eps = {}

    def attack_backward(self, inputs, labels):
        if self.adv_lr == 0:
            return
        self._save()
        self._attack_step()

        y_preds = self.model(inputs)

        adv_loss = self.criterion(y_preds, labels)
        self.optimizer.zero_grad()
        return adv_loss

    def _attack_step(self):
        e = 1e-6
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None and self.adv_param in name:
                norm1 = torch.norm(param.grad)
                norm2 = torch.norm(param.data.detach())
                if norm1 != 0 and not torch.isnan(norm1):
                    # 在损失函数之前获得梯度
                    r_at = self.adv_lr * param.grad / (norm1 + e) * (norm2 + e)
                    param.data.add_(r_at)
                    param.data = torch.min(
                        torch.max(param.data, self.backup_eps[name][0]), self.backup_eps[name][1]
                    )

    def _save(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None and self.adv_param in name:
                if name not in self.backup:
                    self.backup[name] = param.data.clone()
                    grad_eps = self.adv_eps * param.abs().detach()
                    self.backup_eps[name] = (
                        self.backup[name] - grad_eps,
                        self.backup[name] + grad_eps,
                    )

    def _restore(self,):
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}
        self.backup_eps = {}

参数Args:

  • adv_param (str): 要攻击的layer name,一般攻击第一层 或者全部weight参数效果较好

  • adv_lr (float): 攻击步长,这个参数相对难调节,如果只攻击第一层embedding,一般用1比较好,全部参数用0.1比较好。

  • adv_eps (float): 参数扰动最大幅度限制,范围(0~ +∞),一般设置(0,1)之间相对合理一点。

  • start_epoch (int): (0~ +∞)什么时候开始扰动,默认是0,如果效果不好可以调节值模型收敛一半的时候再开始攻击。

"""
    使用AWP的训练过程
"""
# 初始化AWP
awp = AWP(model, loss_fn, optimizer, adv_lr=awp_lr, adv_eps=awp_eps)

for step, batch in enumerate(train_loader):
    inputs, labels = batch
    
    # 将模型的参数梯度初始化为0
    optimizer.zero_grad()
    
    # forward + backward + optimize
    predicts = model(inputs)          # 前向传播计算预测值
    loss = loss_fn(predicts, labels)  # 计算当前损失
    loss.backward()       # 反向传播计算梯度
    # 指定从第几个epoch开启awp,一般先让模型学习到一定程度之后
    if awp_start >= epoch:
        loss = awp.attack_backward(inputs, labels)
        loss.backward()
        awp._restore()                    # 恢复到awp之前的model
    optimizer.step()                  # 更新所有参数 

注意:使用AWP训练时间大概是原来的两倍

参考

  • 4
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

落难Coder

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值