Pytorch中的Exponential Moving Average(EMA)

EMA(指数移动平均)是一种在模型训练中用于提升鲁棒性的技术,它结合了历史模型权重信息。在PyTorch中,ModelEMA类用于实现这一机制,通过对模型参数进行加权平均更新,特别是在每次迭代后,权重逐渐倾向于之前的值。这有助于稳定模型并在某些训练策略中提高性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

EMA介绍

EMA,指数移动平均,常用于更新模型参数、梯度等。

EMA的优点是能提升模型的鲁棒性(融合了之前的模型权重信息)

代码示例

下面以yolov7/utils/torch_utils.py代码为例:

class ModelEMA:
    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        # Create EMA
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1. - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)

ModelEMA类的__init__ 函数介绍

__init__ 函数的输入参数介绍

  • model:需要使用EMA策略更新参数的模型
  • decay:加权权重,默认为0.9999
  • updates:模型参数更新/迭代次数

__init__ 函数的初始化介绍

首先深拷贝一份模型

"""
创建EMA模型

model.eval()的作用:
1. 保证BN层使用的是训练数据的均值(running_mean)和方差(running_val), 否则一旦test的batch_size过小, 很容易就会被BN层影响结果
2. 保证Dropout不随机舍弃神经元
3. 模型不会计算梯度,从而减少内存消耗和计算时间

is_parallel()的作用:
如果模型是并行训练(DP/DDP)的, 则深拷贝model.module,否则就深拷贝model

"""
self.ema = deepcopy(model.module if is_parallel(model) else model).eval()

接着,初始化updates次数,若是从头开始训练,则该参数为0

self.updates = updates

最后,定义加权权重decay的计算公式(这里呈指数型变化),

self.decay = lambda x: decay * (1 - math.exp(-x / 2000))

ModelEMA类的update()函数介绍

如果调用该函数,则更新updates以及decay,

self.updates += 1
## d随着updates的增加而逐渐增大, 意味着随着模型迭代次数的增加, EMA模型的权重会越来越偏向于之前的权重
d = self.decay(self.updates)

取出当前模型的参数,为更新EMA模型的参数做准备,

msd = model.module.state_dict() if is_parallel(model) else model.state_dict()

对EMA模型参数以及当前模型参数进行加权求和,作为EMA模型的新参数,

for k, v in self.ema.state_dict().items():
    if v.dtype.is_floating_point:
        v *= d
        v += (1. - d) * msd[k].detach()

参考文章

【代码解读】在pytorch中使用EMA - 知乎

【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现 - 知乎

以史为鉴!EMA在机器学习中的应用 - 知乎

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

chen_znn

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值