模型训练-Tricks-提升鲁棒性(2):SWA(随机权重平均/Stochastic Weight Averaging)、EMA(指数移动平均/Exponential Moving Average)

kaggle比赛中,不管是目标检测任务、语义分割任务中,经常能看到SWA(Stochastic Weight Averaging)和EMA(Exponential Moving Average)的身影。

一、SWA(随机权重平均)

SWA随机权重平均:在优化的末期取k个优化轨迹上的checkpoints,平均他们的权重,得到最终的网络权重,这样就会使得最终的权重位于flat曲面更中心的位置,缓解权重震荡问题,获得一个更加平滑的解,相比于传统训练有更泛化的解。

SWA简单来说就是对训练过程中的多个checkpoints进行平均,以提升模型的泛化性能。

from torch.optim.swa_utils import AveragedModel, SWALR
# 采用SGD优化器
optimizer = torch.optim.SGD(model.parameters(),lr=1e-4, weight_decay=1e-3, momentum=0.9)
# 随机权重平均SWA,实现更好的泛化
swa_model = AveragedModel(model).to(device)
# SWA调整学习率
swa_scheduler = SWALR(optimizer, swa_lr=1e-6)
for epoch in range(1, epoch + 1):
    for batch_idx, (data, target) in enumerate(train_loader):   
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        # 在反向传播前要手动将梯度清零
        optimizer.zero_grad()
        output = model(data)
        #计算losss
        loss = train_criterion(output, targets)
        # 反向传播求解梯度
        loss.backward()
        optimizer.step()
        lr = optimizer.state_dict()['param_groups'][0]['lr']   
    swa_model.update_parameters(model)
    swa_scheduler.step()
# 最后更新BN层参数
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
# 保存结果
torch.save(swa_model.state_dict(), "last.pt")

上面的代码展示了SWA的主要代码,实现的步骤:

1、定义SGD优化器。

2、定义SWA。

3、定义SWALR,调整模型的学习率。

4、开始训练,等待训练完成。

5、在每个epoch中更新模型的参数,更新学习率。

6、等待训练完成后,更新BN层的参数。


论文链接:https://arxiv.org/abs/1803.05407.pdf

官方代码:https://github.com/timgaripov/swa


1、步骤

1.给定超参数:

  • 循环周期c,代表训练c步就使用SWA进行一次权重平均
  • 学习率 α1,α2 ,即周期学习率的上界和下界,论文的实验使用的周期性学习率如下图

2.然后,按照正常的SGD标准流程进行训练,每训练c步,就平均一次权重

3.最后,使用平均的权重 wSWA 权重进行推理。

2、代码

import torch
import torch.nn as nn
 
 
def apply_swa(model: nn.Module,
              checkpoint_list: list,
              weight_list: list,
              strict: bool = True):
    """
    :param model:
    :param checkpoint_list: 要进行swa的模型路径列表
    :param weight_list: 每个模型对应的权重
    :param strict: 输入模型权重与checkpoint是否需要完全匹配
    :return:
    """
 
    checkpoint_tensor_list = [torch.load(f, map_location='cpu') for f in checkpoint_list]
 
    for name, param in model.named_parameters():
        try:
            param.data = sum([ckpt['model'][name] * w for ckpt, w in zip(checkpoint_tensor_list, weight_list)])
        except KeyError:
            if strict:
                raise KeyError(f"Can't match '{name}' from checkpoint")
            else:
                print(f"Can't match '{name}' from checkpoint")
 
    return model

二、EMA(Exponential Moving Average)【指数移动平均/移动权重平均】【提高测试指标并增加模型鲁棒】

指数移动平均(Exponential Moving Average)也叫权重移动平均(Weighted Moving Average),是一种给予近期数据更高权重的平均方法。

一句话总结权重滑动平均/指数滑动平均(Exponential Moving Average)就是:

Copy一份模型所有权重(记为Weights)的备份(记为EMA_weights),训练过程中每次更新权重时同时也对EMA_weights进行滑动平均更新,训练阶段结束后用EMA_weights替换模型权重进行预测。

具体地,EMA的超参decay一般设为接近1的数,从而保证每次EMA_weights的更新都很稳定。每batch更新流程为:

Weights=Weights+LR*Grad; (模型正常的梯度下降)

EMA_weights=EMA_weights*decay+(1-decay)*Weights; (根据新weight更新EMA_weights)

需要知道训练阶段无关EMA_weights,它只在测试阶段时导入进行预测。


EMA为什么有效


PyTorch实现

class EMA():
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

    def register(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()    # 保存ema参数

    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}

# 初始化
ema = EMA(model, 0.999)
ema.register()

# 训练过程中,更新完参数后,同步update shadow weights
def train():
    optimizer.step()
    ema.update()

# eval前,apply shadow weights;eval之后,恢复原来模型的参数
def evaluate():
    ema.apply_shadow()
    # evaluate
    ema.restore()



SWA实战:使用SWA进行微调,提高模型的泛化_swa真能提升性能吗_AI浩的博客-CSDN博客

【读】领域泛化 - SWA - 知乎

模型泛化技巧“随机权重平均(Stochastic Weight Averaging, SWA)”介绍与Pytorch Lightning的SWA实现讲解_模型权重平均_iioSnail的博客-CSDN博客

常用训练tricks,提升你模型的鲁棒性

AI简报-模型集成 SAM 和SWA_深度学习_AIWeker_InfoQ写作社区




【提分trick】SWA(随机权重平均)和EMA(指数移动平均)_ema swa_zy_destiny的博客-CSDN博客

指数移动平均(EMA)的原理及PyTorch实现_ema 移动平均值数学原理_枫林扬的博客-CSDN博客

【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现 - 知乎

机器学习模型性能提升技巧:指数加权平均(EMA)_mikelkl的博客-CSDN博客

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值