pytorch : Stochastic Weight Averaging理解和用法

SWA has been proposed in Averaging Weights Leads to Wider Optima and Better Generalization.

SGD倾向于收敛到loss的平坦的区域,由于权重空间的维度比较高,平坦区域的大部分都处于边界,SGD通常只会走到这些平稳区域的边界。SWA通过平均多个SGD的权重参数,使其能够达到平坦区域的中心,从而得到更优的解。就相当于在一个最优解的附近,梯度都很小了,迭代可能一之在最优解附近震荡,通过平均最优解周围的参数,可以更接近最优解,利用平均值中和误差;

第一种实现方法:就是最后几次的平均

swa_model = AveragedModel(model)
swa_model.update_parameters(model) # 最后来一次就行,平均一下就可

第二种:一边平均,一边学习率变化

>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>>         anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)

注意事项:

权重更新完了之后,需要来一下这个,对batchnormalization

torch.optim.swa_utils.update_bn(loader, swa_model

BN层训练过程中计算激活神经元的统计信息,而SWA平均的权重在训练过程中是不会用来预测的,所以当权重更新之后,BN层相对应的统计信息仍然是之前权重的。为了计算激活值的统计信息,只需要在训练结束之后对训练数据前向传播一次即可。

自定义平均:

默认情况下,torch.optim.swa_utils.AveragedModel计算参数的运行平均值,但也可以将自定义平均值函数与avg_fn参数一起使用。在以下示例中,ema_模型计算指数移动平均数。

>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>         0.1 * averaged_model_parameter + 0.9 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

参考文章:

1、https://liumin.blog.csdn.net/article/details/113128343 (随机梯度平均)

2、https://zhuanlan.zhihu.com/p/122504469

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值