1 背景知识
在了解 torch.optim.swa_utils.AverageModel() 前, 我们先了解以下 SWA(随机加权平均)
1.1 SWA
SWA 全称 : Stochastic Weight Averaging,
SWA是使用修正后的学习率策略对SGD(或任何随机优化器)遍历的权重进行平均,从而可以得到更好的收敛效果
随机梯度下降(SGD)在测试集上,趋向于收敛至损失相对低的地方,但却很难收敛至最低点, 经过几个epoch的训练,得到了W1,W2,W3三个权重,但无法收敛至最低点。如果使用SWA可以将三个权重加权平均,从而可能收敛至相对SGD更小的损失
SGD在训练集收敛得比较好,但是在测试集效果并不如SWA。而SWA虽然在训练集收敛得不如SGD,但是在测试集上表现得更加好
2 AverageModel() 介绍
AveragedModel 类用于计算SWA模型的权重。可以通过运行以下命令创建一个averaged model:
from torch.optim.swa_utils import AverageModel
swa_model = AverageModel(model)
这里的模型Model可以是任意的torch.nn.Module对象。swa_model将跟踪模型参数的运行平均值。要更新这些平均值,你可以使用update_parameters()函数:
swa_model.update_parameters(model)