在kaggle比赛中,不管是目标检测任务、语义分割任务中,经常能看到SWA(Stochastic Weight Averaging)和EMA(Exponential Moving Average)的身影。
一、SWA(随机权重平均)
SWA随机权重平均:在优化的末期取k个优化轨迹上的checkpoints,平均他们的权重,得到最终的网络权重,这样就会使得最终的权重位于flat曲面更中心的位置,缓解权重震荡问题,获得一个更加平滑的解,相比于传统训练有更泛化的解。
SWA简单来说就是对训练过程中的多个checkpoints进行平均,以提升模型的泛化性能。
from torch.optim.swa_utils import AveragedModel, SWALR
# 采用SGD优化器
optimizer = torch.optim.SGD(model.parameters(),lr=1e-4, weight_decay=1e-3, momentum=0.9)
# 随机权重平均SWA,实现更好的泛化
swa_model = AveragedModel(model).to(device)
# SWA调整学习率
swa_scheduler = SWALR(optimizer, swa_lr=1e-6)
for epoch in range(1, epoch + 1):
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
# 在反向传播前要手动将梯度清零
optimizer.zero_grad()
output = model(data)
#计算losss
loss = train_criterion(output, targets)
# 反向传播求解梯度
loss.backward()
optimizer.step()
lr = optimizer.state_dict()['param_groups'][0]['lr']
swa_model.update_parameters(model)
swa_scheduler.step()
# 最后更新BN层参数
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
# 保存结果
torch.save(swa_model.state_dict(), "last.pt")
上面的代码展示了SWA的主要代码,实现的步骤:
1、定义SGD优化器。
2、定义SWA。
3、定义SWALR,调整模型的学习率。
4、开始训练,等待训练完成。
5、在每个epoch中更新模型的参数,更新学习率。
6、等待训练完成后,更新BN层的参数。
论文链接:https://arxiv.org/abs/1803.05407.pdf
官方代码:https://github.com/timgaripov/swa
1、步骤
1.给定超参数:
- 循环周期c,代表训练c步就使用SWA进行一次权重平均
- 学习率 α1,α2 ,即周期学习率的上界和下界,论文的实验使用的周期性学习率如下图
2.然后,按照正常的SGD标准流程进行训练,每训练c步,就平均一次权重
3.最后,使用平均的权重 wSWA 权重进行推理。
2、代码
import torch
import torch.nn as nn
def apply_swa(model: nn.Module,
checkpoint_list: list,
weight_list: list,
strict: bool = True):
"""
:param model:
:param checkpoint_list: 要进行swa的模型路径列表
:param weight_list: 每个模型对应的权重
:param strict: 输入模型权重与checkpoint是否需要完全匹配
:return:
"""
checkpoint_tensor_list = [torch.load(f, map_location='cpu') for f in checkpoint_list]
for name, param in model.named_parameters():
try:
param.data = sum([ckpt['model'][name] * w for ckpt, w in zip(checkpoint_tensor_list, weight_list)])
except KeyError:
if strict:
raise KeyError(f"Can't match '{name}' from checkpoint")
else:
print(f"Can't match '{name}' from checkpoint")
return model
二、EMA(Exponential Moving Average)【指数移动平均/移动权重平均】【提高测试指标并增加模型鲁棒】
指数移动平均(Exponential Moving Average)也叫权重移动平均(Weighted Moving Average),是一种给予近期数据更高权重的平均方法。
一句话总结权重滑动平均/指数滑动平均(Exponential Moving Average)就是:
Copy一份模型所有权重(记为Weights)的备份(记为EMA_weights),训练过程中每次更新权重时同时也对EMA_weights进行滑动平均更新,训练阶段结束后用EMA_weights替换模型权重进行预测。
具体地,EMA的超参decay一般设为接近1的数,从而保证每次EMA_weights的更新都很稳定。每batch更新流程为:
Weights=Weights+LR*Grad; (模型正常的梯度下降)
EMA_weights=EMA_weights*decay+(1-decay)*Weights; (根据新weight更新EMA_weights)
需要知道训练阶段无关EMA_weights,它只在测试阶段时导入进行预测。
EMA为什么有效
PyTorch实现
class EMA():
def __init__(self, model, decay):
self.model = model
self.decay = decay
self.shadow = {}
self.backup = {}
def register(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
self.shadow[name] = param.data.clone()
def update(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.shadow
new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
self.shadow[name] = new_average.clone() # 保存ema参数
def apply_shadow(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.shadow
self.backup[name] = param.data
param.data = self.shadow[name]
def restore(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.backup
param.data = self.backup[name]
self.backup = {}
# 初始化
ema = EMA(model, 0.999)
ema.register()
# 训练过程中,更新完参数后,同步update shadow weights
def train():
optimizer.step()
ema.update()
# eval前,apply shadow weights;eval之后,恢复原来模型的参数
def evaluate():
ema.apply_shadow()
# evaluate
ema.restore()
SWA实战:使用SWA进行微调,提高模型的泛化_swa真能提升性能吗_AI浩的博客-CSDN博客
AI简报-模型集成 SAM 和SWA_深度学习_AIWeker_InfoQ写作社区
【提分trick】SWA(随机权重平均)和EMA(指数移动平均)_ema swa_zy_destiny的博客-CSDN博客
指数移动平均(EMA)的原理及PyTorch实现_ema 移动平均值数学原理_枫林扬的博客-CSDN博客