Pytorch 中的 torch.optim.swa_utils.AverageModel() 及其原理总结

1 背景知识

在了解 torch.optim.swa_utils.AverageModel() 前, 我们先了解以下 SWA(随机加权平均)

1.1 SWA

SWA 全称 : Stochastic Weight Averaging,

  • SWA是使用修正后的学习率策略对SGD(或任何随机优化器)遍历的权重进行平均,从而可以得到更好的收敛效果

  • 随机梯度下降(SGD)在测试集上,趋向于收敛至损失相对低的地方,但却很难收敛至最低点, 经过几个epoch的训练,得到了W1,W2,W3三个权重,但无法收敛至最低点。如果使用SWA可以将三个权重加权平均,从而可能收敛至相对SGD更小的损失

  • SGD在训练集收敛得比较好,但是在测试集效果并不如SWA。而SWA虽然在训练集收敛得不如SGD,但是在测试集上表现得更加好

2 AverageModel() 介绍

AveragedModel 类用于计算SWA模型的权重。可以通过运行以下命令创建一个averaged model:

from torch.optim.swa_utils import AverageModel
swa_model = AverageModel(model)

这里的模型Model可以是任意的torch.nn.Module对象。swa_model将跟踪模型参数的运行平均值。要更新这些平均值,你可以使用update_parameters()函数:

swa_model.update_parameters(model)

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值