pytorch训练之EMA使用

原理

在深度学习中用于创建模型的指数移动平均(Exponential Moving Average,EMA)的副本。通常,指数移动平均是用来平滑模型的参数,以提高模型的泛化能力。

在这段代码中,model 是原始模型,deepcopy 函数用于创建模型的深层副本,避免共享内存。

在训练过程中,通常会使用 EMA 模型来获得更稳定的预测结果,而不是直接使用训练过程中的模型参数。这样可以减少模型在训练数据上的过拟合,并提高模型的泛化能力。

使用逻辑

 @torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())

    for name, param in model_params.items():
        name = name.replace("module.", "")
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)


def requires_grad(model, flag=True):
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag


def main(args):
	    model = model.to(device)
	    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
	    requires_grad(ema, False)
	    ...
    	# Prepare models for training:
	    update_ema(ema, model, decay=0)  # Ensure EMA is initialized with synced weights
	    model.train()  # important! This enables embedding dropout for classifier-free guidance
	    ema.eval()  # EMA model should always be in eval mode
    	...
	    for epoch in range(args.epochs):
	        if accelerator.is_main_process:
	            logger.info(f"Beginning epoch {epoch}...")
	        for x, y in loader:
	        		...
	                opt.step()
            		update_ema(ema, model)
            		...
            		if save:
	            		checkpoint = {
	                        "model": model.module.state_dict(),
	                        "ema": ema.state_dict(),
	                        "opt": opt.state_dict(),
	                        "args": args
	                    }
	                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
	                    torch.save(checkpoint, checkpoint_path)

权重平均(SWA 和 EMA)

torch.optim.swa_utils实现了随机权重平均(SWA)和指数移动平均(EMA)。特别是,torch.optim.swa_utils.AveragedModel类实现了 SWA 和 EMA 模型,torch.optim.swa_utils.SWALR实现了 SWA 学习率调度程序,torch.optim.swa_utils.update_bn()是一个实用函数,用于在训练结束时更新 SWA/EMA 批归一化统计数据。

SWA 已经在Averaging Weights Leads to Wider Optima and Better Generalization中提出。

EMA 是一种广泛知晓的技术,通过减少所需的权重更新次数来减少训练时间。它是Polyak 平均的一种变体,但是使用指数权重而不是在迭代中使用相等权重。

构建平均模型

AveragedModel 类用于计算 SWA 或 EMA 模型的权重。

您可以通过运行以下命令创建一个 SWA 平均模型:

>>> averaged_model = AveragedModel(model) 

通过指定multi_avg_fn参数来构建 EMA 模型,如下所示:

>>> decay = 0.999
>>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay)) 

衰减是一个介于 0 和 1 之间的参数,控制平均参数衰减的速度。如果未提供给get_ema_multi_avg_fn,则默认值为 0.999。

get_ema_multi_avg_fn返回一个函数,该函数将以下 EMA 方程应用于权重:

W t + 1 EMA = α W t EMA + ( 1 − α ) W t model W^\textrm{EMA}_{t+1} = \alpha W^\textrm{EMA}_{t} + (1 - \alpha) W^\textrm{model}_t Wt+1EMA=αWtEMA+(1α)Wtmodel

其中 alpha 是 EMA 衰减。

在这里,模型model可以是任意torch.nn.Module对象。averaged_model将跟踪model的参数的运行平均值。要更新这些平均值,您应该在 optimizer.step()之后使用update_parameters()函数:

>>> averaged_model.update_parameters(model) 

对于 SWA 和 EMA,这个调用通常在 optimizer step()之后立即执行。在 SWA 的情况下,通常在训练开始时跳过一些步骤。

自定义平均策略

默认情况下,torch.optim.swa_utils.AveragedModel计算您提供的参数的运行平均值,但您也可以使用avg_fnmulti_avg_fn参数使用自定义平均函数:

  • avg_fn允许定义一个操作在每个参数元组(平均参数,模型参数)上,并应返回新的平均参数。

  • multi_avg_fn允许定义更高效的操作,同时作用于参数列表的元组(平均参数列表,模型参数列表),例如使用torch._foreach*函数。此函数必须原地更新平均参数。

在以下示例中,ema_model使用avg_fn参数计算指数移动平均值:

>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>         0.9 * averaged_model_parameter + 0.1 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg) 

在以下示例中,ema_model使用更高效的multi_avg_fn参数计算指数移动平均值:

>>> ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.9)) 

SWA 学习率调度

通常,在 SWA 中,学习率设置为一个较高的恒定值。SWALR是一个学习率调度程序,它将学习率退火到一个固定值,然后保持恒定。例如,以下代码创建一个调度程序,它在每个参数组内将学习率从初始值线性退火到 0.05,共 5 个时期:

>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>>         anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05) 

您还可以通过设置anneal_strategy="cos"来使用余弦退火到固定值,而不是线性退火。

处理批归一化

update_bn()是一个实用函数,允许在训练结束时计算给定数据加载器loader上 SWA 模型的批归一化统计信息:

>>> torch.optim.swa_utils.update_bn(loader, swa_model) 

update_bn()swa_model应用于数据加载器中的每个元素,并计算模型中每个批归一化层的激活统计信息。

警告

update_bn()假设数据加载器loader中的每个批次都是张量或张量列表,其中第一个元素是应用于网络swa_model的张量。如果您的数据加载器具有不同的结构,您可以通过在数据集的每个元素上使用swa_model进行前向传递来更新swa_model的批归一化统计信息。

SWA示例

在下面的示例中,swa_model是累积权重平均值的 SWA 模型。我们总共训练模型 300 个时期,并切换到 SWA 学习率计划,并开始在第 160 个时期收集参数的 SWA 平均值:

>>> loader, optimizer, model, loss_fn = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
>>> swa_start = 160
>>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>       if epoch > swa_start:
>>>           swa_model.update_parameters(model)
>>>           swa_scheduler.step()
>>>       else:
>>>           scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input) 

EMA示例

在下面的示例中,ema_model是 EMA 模型,它累积权重的指数衰减平均值,衰减率为 0.999。我们总共训练模型 300 个时期,并立即开始收集 EMA 平均值。

>>> loader, optimizer, model, loss_fn = ...
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, \
>>>             multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>           ema_model.update_parameters(model)
>>>
>>> # Update bn statistics for the ema_model at the end
>>> torch.optim.swa_utils.update_bn(loader, ema_model)
>>> # Use ema_model to make predictions on test data
>>> preds = ema_model(test_input) 

参考

  1. https://pytorch.org/docs/stable/optim.html#weight-averaging-swa-and-ema
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值