EMA滑动平均训练方式

1. EMA 介绍

首先该类实现, 使用timm ==0.6.11 版本;

Exponential Moving Average (EMA) for models in PyTorch.
目的:它旨在维护模型状态字典的移动平均值,包括参数和缓冲区。该技术通常用于训练方案,其中权重的平滑版本对于最佳性能至关重要。

1.1 v1 版本


class ModelEma:
    """ Model Exponential Moving Average (DEPRECATED)

    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This version is deprecated, it does not work with scripted models. Will be removed eventually.

    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    A smoothed version of the weights is necessary for some training schemes to perform well.
    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
    smoothing of weights to match results. Pay attention to the decay constant you are using
    relative to your update count per epoch.

    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
    disable validation of the EMA weights. Validation will have to be done manually in a separate
    process, or after the training stops converging.

    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """
    def __init__(self, model, decay=0.9999, device='', resume=''):
        # make a copy of the model for accumulating moving average of weights
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device:
            self.ema.to(device=device)
        self.ema_has_module = hasattr(self.ema, 'module')
        if resume:
            self._load_checkpoint(resume)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        assert isinstance(checkpoint, dict)
        if 'state_dict_ema' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                # ema model may have been wrapped by DataParallel, and need module prefix
                if self.ema_has_module:
                    name = 'module.' + k if not k.startswith('module') else k
                else:
                    name = k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)
            _logger.info("Loaded state_dict_ema")
        else:
            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    def update(self, model):
        # correct a mismatch in state dict keys
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        with torch.no_grad():
            msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if needs_module:
                    k = 'module.' + k
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)


Methods:方法:

__init__:通过创建所提供模型的副本、设置衰减率和设备放置来初始化 EMA 模型。模型设置为评估模式,并且其梯度被禁用。

_load_checkpoint :加载 EMA 模型的检查点。它处理由 DataParallel 包装器引起的状态字典命名约定中的潜在差异。

update
通过计算原始模型参数和当前 EMA 参数的加权平均值来更新 EMA 参数。

Features:特征

  1. 可以为模型及其 EMA 对应项指定不同的设备。
  2. 处理由于 DataParallel 包装器导致的状态字典键不匹配。
  3. 由于与脚本模型不兼容v1版本被弃用

1.2 v2 版本

import logging
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)

class ModelEmaV2(nn.Module):
    """ Model Exponential Moving Average V2

    Keep a moving average of everything in the model state_dict (parameters and buffers).
    V2 of this module is simpler, it does not match params/buffers based on name but simply
    iterates in order. It works with torchscript (JIT of full model).

    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    A smoothed version of the weights is necessary for some training schemes to perform well.
    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
    smoothing of weights to match results. Pay attention to the decay constant you are using
    relative to your update count per epoch.

    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
    disable validation of the EMA weights. Validation will have to be done manually in a separate
    process, or after the training stops converging.

    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """
    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEmaV2, self).__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model): # 使用衰减率更新 EMA 参数
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):  # 直接将 EMA 参数设置为与提供的模型参数相同。
        self._update(model, update_fn=lambda e, m: m)

EmaV2版本:与 ModelEma 类似,但实现更简单。它还维护模型状态字典的移动平均值,并设计为与 torchscript(完整模型的 JIT)配合使用。

Methods:方法:

__init__:与 ModelEma 类似,但添加了对 super() 的调用来初始化 nn.Module 基类。

_update :更新 EMA 参数的辅助函数,以自定义更新函数作为参数。

update :使用衰减率更新 EMA 参数。

set :直接将 EMA 参数设置为与提供的模型参数相同。

Features:特征:

  1. 比 ModelEma 更简单、更直接的实现。
  2. 与torchscipt兼容。
  3. 根据参数的顺序而不是名称来匹配参数。
  • v1 版本与 v2版本之间的差异
    Differences差异:
  1. 设计复杂性: ModelEmaV2 更简单、更直接,避免了按名称匹配参数。

  2. 兼容性: ModelEmaV2 与 torchscript 兼容,与 ModelEma 不同。

  3. .参数匹配: ModelEma 按名称匹配参数和缓冲区,而 ModelEmaV2 根据参数和顺序进行匹配。

  4. 版本控制和用例: ModelEma 已被弃用,并且对于较新的训练方案(尤其是需要脚本的训练方案)而言不太受欢迎。

  5. 这两个类本质上用于相同的目的,但采用不同的方法,使得 ModelEmaV2 更适合利用脚本的现代 PyTorch 工作流程。

2. 使用方法

与 ModelEma 相比,在训练过程中使用 ModelEmaV2 涉及的方法略有不同。以下是有关如何将 ModelEmaV2 合并到训练循环中的指南,以及有关衰减参数的作用和预训练权重的使用的说明。

要在训练过程中使用 ModelEma V2 ,您应该将其集成到现有的训练循环中。以下是有关如何执行此操作的分步指南:

由于v1版本被弃用, 所以这里介绍使用 V2 版本;

2.1 使用步骤

2.1.1. 初始化ema 类

  • 初始化:
    先定义自己的模型后,
    在初始化(或者实例化) ModelEmaV2时,将模型作为参数传入, 根据自己的训练策略设置 decay 参数, 可以先设置0.9, 然后设置0.5 的方式, 来确定自己的训练策略应该使用0.9 还是0.1;
model = YourModel()  # Replace with your model
ema = ModelEmaV2(model, decay=0.9999)
  • 设备配置:如果使用 GPU 等特定设备,请确保您的模型和 EMA 模型都移至该设备。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
ema.module.to(device)

2.1.2 训练阶段

这里需要注意的是, 在训练阶段, 调用的模型仍是原始的自定的self.model

在模型完成损失反向传播, 以及参数更新之后, 才会将此时的模型传入到 ema 中,调用Ema 中的updata()函数,完成对参数的滑动平均更新,

即Ema在训练阶段的调用情况,是在模型完成反向传播,以及参数更新之后。

           for i, (spec,cof,label) in enumerate(tqdm(self.train_data_loader,  desc=' training process')):
                spec_data, cof_data, label = spec.cuda().float(), cof.cuda().float(), label.long().cuda()

                model_out = self.net(spec_data, cof_data)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                # note, 在模型完成反向传播之后使用, 这里更新ema 的模型
                self.ema.update(self.net)

在重声一遍吧,
这里需要注意到的是 ,需要在每个反向传播 更新之后,才去更新EMA 模型;

for epoch in range(num_epochs):
    for batch in dataloader:
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        ema.update(model)
        

2.1.3 推理阶段

即在获取EMA更新的权重之后,EMA 模型的参数权重, 真正使用他的地方是在推理阶段。
由于滑动平均后的权重参数,更适合预测阶段,所以真正使用 Ema更新的权重,是在推理阶段

  • 验证:使用EMA更新后的权重参数,进行验证。
ema.module.eval()  # Set EMA model to evaluation mode
with torch.no_grad():
    for batch in validation_dataloader:
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = ema.module(inputs)  # Use EMA model for predictions
        # Compute validation metrics

2.1.4 参数保存

  • 检查点:保存常规模型和 EMA 模型的状态字典。
torch.save({
    'model_state_dict': model.state_dict(),
    'ema_state_dict': ema.module.state_dict(),
    # ... other states like optimizer, epoch, etc.
}, 'checkpoint.pth')

  • 恢复训练:要从检查点恢复,请加载两个状态字典。
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
ema.module.load_state_dict(checkpoint['ema_state_dict'])
# Load other states

2.2 decay 参数的影响

ModelEmaV2 中的衰减参数起着至关重要的作用:

它确定移动平均线中当前模型参数相对于历史参数的权重。

  • 较高的衰减值(接近 1)赋予历史参数更大的权重,从而导致 EMA 模型权重的更新更平滑且更慢。
  • 较低的衰减值使 EMA 模型的权重对模型参数的近期变化更加敏感。

衰减值的选择取决于您的训练动态和训练步骤总数。常见的做法是从高衰减开始,然后随着时间的推移逐渐减少。

  • decay 参数;
    较高的衰减值(接近 1):当衰减参数设置为接近 1 时,EMA 模型会为较旧的(历史)参数赋予更多权重,而为最近更新的参数赋予较少权重。这使得 EMA 权重随着时间的推移变得更加平滑和更加稳定。平均权重响应新数据的变化更慢,这有利于减少噪声更新的影响。

较低的衰减值(远离 1):较低的衰减值导致 EMA 模型更加重视最近的模型更新。这使得 EMA 权重不太平滑,因为它们对模型参数的最新变化更加敏感。虽然这可以使 EMA 权重对数据的新趋势更加敏感,但也使它们更容易受到噪音和突然变化的影响。

总而言之,较高的衰减参数(接近 1)通过赋予历史数据更多权重来提高 EMA 模型权重的平滑度,从而导致权重更稳定但响应性较差。相反,较低的衰减值会降低平滑度,使权重对最近的变化更加敏感,但会牺牲稳定性。适当衰减值的选择取决于训练过程的具体要求和数据的性质。

使用 ModelEmaV2 时,在初始化 ModelEmaV2 之前将预训练的权重加载到原始模型中可能会很有帮助,特别是当您正在进行微调或有特定的起点时。

2.3 预训练权重

使用预先训练的权重:

  • 使用 ModelEmaV2 时,在初始化 ModelEmaV2 之前将预训练的权重加载到原始模型中可能会很有帮助,特别是当您正在进行微调或有特定的起点时。

  • 然后,EMA 模型将从这些权重的平滑版本开始,这可以导致更快的收敛和可能更好的最终性能,特别是在微调场景中。

  • 但是,如果您从头开始训练,则使用没有预训练权重的模型初始化 ModelEmaV2 也可以。 EMA 模型将随着训练的进展进行调整。

  • 总之, ModelEmaV2 用于维持模型权重的更平滑、更稳定的版本,这对于实现最佳性能至关重要,特别是在训练的后期阶段或微调场景中。衰减参数是控制应用平滑程度的关键。使用 ModelEmaV2 时,预训练权重可能很有用,但它们并不是绝对必要的,特别是在从头开始训练的场景中。

2.4 bug 问题

遇到的错误 RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment:

表示 ModelEmaV2 初始化中的 deepcopy 操作存在问题。当尝试在 PyTorch 中深度复制具有一定复杂性或特定类型的层或参数的模型时,通常会出现此问题。

  • 检查不可复制的层或参数:PyTorch 模型中的某些自定义层或参数可能不支持深度复制。如果您的模型包含此类层,请考虑修改模型以仅使用深度复制兼容的层。

  • 更新 PyTorch 版本:确保您使用的是最新版本的 PyTorch。有时,此类问题会在新版本中得到解决。

解决方法:自定义深度复制方法:此函数将手动将每个参数和缓冲区从原始模型复制到新模型。可以编写自定义函数来创建模型的副本,而不是使用 deepcopy 。即将原始的__init__() 初始化过程中, self.module 不使用 deepcopy()函数。

替换成如下方式拷贝:

def custom_deepcopy(model):
    model_copy = type(model)()  # Create a new instance of the model's class
    model_copy.load_state_dict(model.state_dict())  # Copy parameters and buffers
    return model_copy

self.ema = ModelEmaV2(custom_deepcopy(self.net), decay=0.9999)

并且需要将原始 __init__() 初始化过程中, self.module 不使用 deepcopy()函数,

    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEmaV2, self).__init__()
        # make a copy of the model for accumulating moving average of weights
        # self.module =    deepcopy(model)
        self.module = model
        self.module.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if self.device is not None:
            self.module.to(device=device)
<think>好的,我现在需要回答用户关于滑动平均的问题。用户之前提到了卡尔曼滤波在相位噪声补偿中的应用,现在转向滑动平均,可能需要结合数据处理和机器学习中的相关应用。 首先,滑动平均的定义是什么?滑动平均(Moving Average)是一种用来分析数据点的技术,通过创建一系列平均值来平滑短期波动,突出长期趋势。常见的类型有简单滑动平均(SMA)、加权滑动平均(WMA)和指数加权滑动平均EMA)。 接下来,用户想知道滑动平均在数据处理和机器学习中的应用。数据处理方面,滑动平均常用于时间序列分析,比如去除噪声,平滑数据,识别趋势。在股票价格分析中常用。机器学习中,可能用于数据预处理,比如平滑特征,或者在模型训练中作为优化技巧,比如EMA用于梯度下降,稳定训练过程。 用户提供的引用中,引用[3]提到了数据标准化,但滑动平均更多是平滑处理,可能和标准化不同。不过,可以提到数据预处理的不同方法。引用[2]涉及模型评估指标,可能可以联系到使用滑动平均来平滑评估指标,比如计算训练损失的平均值以便观察趋势。引用[1]关于数据增强,虽然滑动平均本身不是数据增强,但可能在生成平滑数据时辅助增强。 需要注意用户之前的讨论涉及卡尔曼滤波,滑动平均和卡尔曼滤波都是滤波技术,可以简要比较两者的区别,比如滑动平均是无状态或简单加权,而卡尔曼滤波是动态系统状态估计,更复杂。 还要考虑应用场景的具体例子,比如在深度学习中,EMA用于模型权重的平均,提升测试性能。在实时数据处理中,滑动平均用于实时平滑数据流。 可能需要检查是否有遗漏的应用点,比如在卷积神经网络中的滤波器是否类似滑动平均,但可能属于不同概念。需要准确区分。 最后,确保回答结构清晰,定义、应用分点,结合用户之前的兴趣点(如卡尔曼滤波),并生成相关问题。</think>### 滑动平均的定义及其应用详解 #### 1. **滑动平均的基本定义** **滑动平均(Moving Average)** 是一种通过计算数据序列中连续窗口内数据的平均值来平滑短期波动、突出长期趋势或周期性的统计方法。其核心思想是**用局部数据的均值代替原始值**,从而降低噪声干扰。根据权重分配方式不同,主要分为以下类型: - **简单滑动平均(SMA)**: 对窗口内所有数据赋予相同权重,计算公式为: $$ \text{SMA}_k = \frac{1}{n} \sum_{i=k-n+1}^{k} x_i $$ 其中$n$为窗口长度,$x_i$为第$i$个数据点。 - **加权滑动平均(WMA)**: 为窗口内不同位置的数据分配递减权重(如线性递减),公式为: $$ \text{WMA}_k = \frac{\sum_{i=1}^{n} w_i x_{k-n+i}}{\sum_{i=1}^{n} w_i} $$ 例如,窗口长度为3时,权重可设为$[0.5, 0.3, 0.2]$。 - **指数加权滑动平均EMA)**: 为历史数据赋予指数衰减权重,越新的数据权重越高,公式为: $$ \text{EMA}_k = \alpha x_k + (1-\alpha) \text{EMA}_{k-1} $$ 其中$\alpha \in (0,1)$为平滑因子,控制权重衰减速度(常用$\alpha=2/(n+1)$与窗口长度$n$关联)[^1]。 --- #### 2. **在数据处理中的应用场景** ##### **2.1 时间序列去噪与趋势分析** - **股票价格分析**:通过SMA或EMA平滑每日股价波动,识别长期趋势(如20日均线、60日均线)[^1]。 - **传感器数据处理**:对温度、加速度等传感器采集的噪声数据做平滑处理,提高信号可读性。 - **示例代码(Python)**: ```python import pandas as pd # 计算简单滑动平均(窗口长度=5) data = pd.Series([2, 4, 6, 8, 10, 12, 14, 16]) sma = data.rolling(window=5).mean() ``` ##### **2.2 实时数据流处理** - **实时监控系统**:对实时产生的时序数据(如服务器CPU使用率)计算滑动平均,避免瞬时峰值引发误报警。 - **动态阈值调整**:结合滑动均值和标准差,动态设定异常检测阈值(如$ \mu_{\text{MA}} \pm 3\sigma_{\text{MA}} $)。 --- #### 3. **在机器学习中的核心应用** ##### **3.1 数据预处理** - **特征平滑**:对输入特征(如时序特征)进行滑动平均处理,减少噪声对模型训练的影响。例如,在销量预测中,对历史销量数据平滑后再作为模型输入[^3]。 - **标签平滑**:对分类任务中的标签做滑动平均处理,缓解类别不平衡问题(需谨慎使用)。 ##### **3.2 模型训练优化** - **梯度平滑**:在随机梯度下降(SGD)中,对梯度计算滑动平均,稳定优化过程。公式为: $$ g_{\text{EMA},t} = \beta g_{\text{EMA},t-1} + (1-\beta) g_t $$ 其中$g_t$为第$t$步的梯度,$\beta$通常取0.9或0.99。 - **模型权重EMA**: 在训练末期对模型权重计算指数滑动平均,提升测试性能(如TensorFlow的`ExponentialMovingAverage`)。最终测试时使用EMA权重而非原始权重,可减少权重震荡影响[^2]。 ##### **3.3 评估指标平滑** - **训练曲线可视化**:对训练损失、准确率等指标计算滑动平均,更清晰地观察收敛趋势。例如,在PyTorch中: ```python loss_ema = 0.9 * loss_ema + 0.1 * current_loss ``` --- #### 4. **滑动平均与卡尔曼滤波的对比** | **特性** | **滑动平均** | **卡尔曼滤波** | |------------------|--------------------------------------|--------------------------------| | **计算复杂度** | $O(1)$(EMA)或$O(n)$(SMA) | $O(m^3)$($m$为状态维度) | | **动态适应性** | 固定窗口或衰减因子,无状态更新机制 | 动态调整增益,显式建模过程噪声 | | **适用场景** | 静态或缓慢变化的数据流 | 非平稳动态系统状态估计 | | **噪声抑制** | 依赖窗口长度或平滑因子经验选择 | 基于$Q/R$参数理论优化 | --- #### 5. **参数选择与调优建议** - **窗口长度$n$**: - 过小:平滑效果弱,残留高频噪声。 - 过大:滞后严重,丢失细节特征。 - **经验公式**:$n \approx \text{数据周期长度}/2$(若存在周期性)。 - **EMA平滑因子$\alpha$**: - $\alpha$越大,对新数据响应越快,但平滑效果越弱。 - **自动调整策略**:根据数据方差动态调整$\alpha$(如$\alpha = 1/\sigma_t$)。 --- ### 相关问题 1. 如何通过滑动平均处理非平稳时间序列数据? 2. 指数加权滑动平均EMA)在深度学习优化器中如何具体实现? 3. 滑动平均与低通滤波在信号处理中的区别与联系是什么?
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值