模型训练-Tricks-02:SWA(Stochastic Weight Averaging)【随机权重平均】

论文链接:https://arxiv.org/abs/1803.05407.pdf

 

官方代码:https://github.com/timgaripov/swa

SWA简单来说就是对训练过程中的多个checkpoints进行平均,以提升模型的泛化性能。

from torch.optim.swa_utils import AveragedModel, SWALR
# 采用SGD优化器
optimizer = torch.optim.SGD(model.parameters(),lr=1e-4, weight_decay=1e-3, momentum=0.9)
# 随机权重平均SWA,实现更好的泛化
swa_model = AveragedModel(model).to(device)
# SWA调整学习率
swa_scheduler = SWALR(optimizer, swa_lr=1e-6)
for epoch in range(1, epoch + 1):
    for batch_idx, (data, target) in enumerate(train_loader):   
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        # 在反向传播前要手动将梯度清零
        optimizer.zero_grad()
        output = model(data)
        #计算losss
        loss = train_criterion(output, targets)
        # 反向传播求解梯度
        loss.backward()
        optimizer.step()
        lr = optimizer.state_dict()['param_groups'][0]['lr']   
    swa_model.update_parameters(model)
    swa_scheduler.step()
# 最后更新BN层参数
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
# 保存结果
torch.save(swa_model.state_dict(), "last.pt")

上面的代码展示了SWA的主要代码,实现的步骤:

1、定义SGD优化器。

2、定义SWA。

3、定义SWALR,调整模型的学习率。

4、开始训练,等待训练完成。

5、在每个epoch中更新模型的参数,更新学习率。

6、等待训练完成后,更新BN层的参数。




SWA实战:使用SWA进行微调,提高模型的泛化_swa真能提升性能吗_AI浩的博客-CSDN博客

【读】领域泛化 - SWA - 知乎

模型泛化技巧“随机权重平均(Stochastic Weight Averaging, SWA)”介绍与Pytorch Lightning的SWA实现讲解_模型权重平均_iioSnail的博客-CSDN博客

常用训练tricks,提升你模型的鲁棒性

AI简报-模型集成 SAM 和SWA_深度学习_AIWeker_InfoQ写作社区

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值