《python深度学习》笔记(十四):指数移动平均值EMA

定义

指数移动平均(Exponential Moving Average)也叫权重移动平均(Weighted Moving Average),是一种给予近期数据更高权重平均的方法。

 

作用

给W和b使用EMA,就是防止训练过程遇到异常数据或者随机跳跃(毕竟是随机批量,数据不确定)影响训练效果的,让W和b维持相对稳定

普通的参数权重相当于一直累积更新整个训练过程的梯度,使用EMA的参数权重相当于使用训练过程梯度的加权平均(刚开始的梯度权值很小)。由于刚开始训练不稳定,得到的梯度给更小的权值更为合理,所以EMA会有效。

啥时使用

EMA数据量小或者数据不稳定或者batch_size小的情况下尤其有用

比如回归问题的波士顿房价数据集,还有使用预训练的CNN中batch_size=20比较小。

或者看曲线,那种含噪声,波动很大,或者纵轴范围较大,数据方差较大的图像,为了使曲线变得平滑,更具可读性,所以使用EMA

 代码实现

class EMA():
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}
 
    def register(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()
 
    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()
 
    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]
 
    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
 
# 初始化
ema = EMA(model, 0.999)
ema.register()
 
# 训练过程中,更新完参数后,同步update shadow weights
def train():
    optimizer.step()
    ema.update()
 
# eval前,apply shadow weights;eval之后,恢复原来模型的参数
def evaluate():
    ema.apply_shadow()
    # evaluate
ema.restore()
import matplotlib.pyplot as plt

points = [1, 5, 3, 9, 4]
def smooth_curve(points, factor=0.9):
    smoothed_points =[] # 数据点,权重系数
    for point in points:  # 遍历所有的数据点
        if smoothed_points:  # 如果列表中有数据,则执行下面步骤
            previous = smoothed_points[-1]
            smoothed_points.append(previous * factor + point * (1 - factor))
            #  指数移动平均值EMA,前一个数据点*加权系数+当前数据点*(1-加权系数)
        else:
            smoothed_points.append(point)  # append添加到列表中最后面
    return smoothed_points


results = smooth_curve(points)
print(results)
plt.plot(range(1, len(points) + 1), results)
plt.show()

[1, 1.4, 1.56, 2.304, 2.4736]

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

程序先锋

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值