SWA(随机权重平均) for Pytorch

pytorch1.6中加入了随机权重平均(SWA)的api,使用起来更加方便了。

一.什么是Stochastic Weight Averaging(SWA)

SWA是使用修正后的学习率策略对SGD(或任何随机优化器)遍历的权重进行平均,从而可以得到更好的收敛效果。

在这里插入图片描述
随机梯度下降(SGD)在测试集上,趋向于收敛至损失相对低的地方,但却很难收敛至最低点,如上述左图中,经过几个epoch的训练,得到了W1,W2,W3三个权重,但无法收敛至最低点。如果使用SWA可以将三个权重加权平均,从而可能收敛至相对SGD更小的损失。

二.SWA与SGD的对比

从上面图中,可以发现,SGD在训练集收敛得比较好,但是在测试集效果并不如SWA。而SWA虽然在训练集收敛得不如SGD,但是在测试集上表现得更加好。下面得这张曲线图也可以看出两者的差异。
在这里插入图片描述

三.SWA大致的使用流程(pytorch)

在这里插入图片描述
上图是一种SWA的例子。先使用恒定学习率进行训练,接着线性衰减学习率,最后在恒定学习率上,累加它们的权重(SWA)。在使用SWA之前,可以配合任意的优化器使用,如SGD、Adam等,直到训练到一定周期,开始记录训练的权重,当训练完成后,再将记录的权重进行加权平均。注意:在训练的过程中是不进行预测的(下面的代码可以看到),直到最后训练完后,再加权,然后才开始预测。

from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR

loader, optimizer, model, loss_fn = ...  # 定义数据加载器,优化器,模型,损失
swa_model = AveragedModel(model)  
scheduler = CosineAnnealingLR(optimizer, T_max=100) # 使用学习率策略(余弦退火)
swa_start = 5  # 设置SWA开始的周期,当epoch到该值的时候才开始记录模型的权重
swa_scheduler = SWALR(optimizer, swa_lr=0.05) # 当SWA开始的时候,使用的学习率策略

for epoch in range(100):
      for input, target in loader:
          optimizer.zero_grad()
          loss_fn(model(input), target).backward()
          optimizer.step()
      if epoch > swa_start:
          swa_model.update_parameters(model)
          swa_scheduler.step()
      else:
          scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data 
preds = swa_model(test_input)

可以看到 使用了分为两个阶段的学习率策略,可以自由调整,SWALR中可以加入学习率策略的比如线性,余弦退火等。

torch.optim.swa_utils.update_bn(loader, swa_model)这一步的目的:

  • BN层没有在训练结束时计算激活统计信息。我们可以通过使用SWA模型对这些数据进行一次向前传递来计算这些统计数据。

四.Pytorch上使用swa的一些问题:

pytorch - swa_model模型保存的问题





参考链接:
https://blog.csdn.net/leviopku/article/details/84037946
https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

菊头蝙蝠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值