ema_mnist_blog

使用ModelEmaV2优化MNIST分类模型

在深度学习模型的训练过程中,参数波动可能会导致模型在测试集上的性能不稳定。为了解决这个问题,可以使用指数移动平均(EMA)技术来平滑参数的更新,从而获得更稳定的模型。本文将介绍如何在MNIST数据集上使用ModelEmaV2来优化分类模型,并分析其效果。

实验背景

MNIST数据集是一个经典的手写数字识别数据集,包含60,000张训练图像和10,000张测试图像。我们的目标是训练一个简单的神经网络模型来分类这些手写数字,并使用EMA技术来优化模型参数。

模型定义与EMA实现

首先,我们定义一个简单的全连接神经网络模型,并实现ModelEmaV2来进行EMA参数更新。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import copy

# 定义简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.fc(x)
        return x

# 初始化模型、损失函数和优化器
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 定义 EMA 模型
class ModelEmaV2(nn.Module):
    def __init__(self, model, decay=0.99, device='cpu'):
        super(ModelEmaV2, self).__init__()
        self.ema_model = copy.deepcopy(model).to(device)
        self.ema_model.eval()
        self.decay = decay
        self.device = device

    def update(self, model):
        with torch.no_grad():
            model_params = dict(model.named_parameters())
            ema_params = dict(self.ema_model.named_parameters())
            for k in model_params.keys():
                ema_params[k].mul_(self.decay).add_(model_params[k], alpha=1 - self.decay)

    def forward(self, x):
        return self.ema_model(x)

数据加载与预处理

我们使用torchvision库来加载和预处理MNIST数据集。

# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('../data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

训练与评估

我们进行4个epoch的训练,并在每个epoch结束后评估模型和EMA模型的准确率。

# 训练和评估
num_epochs = 4
results = []

for epoch in range(num_epochs):
    model.train()
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        # 更新 EMA 模型
        ema_model.update(model)
    
    # 计算每个epoch的准确率
    model.eval()
    ema_model.eval()
    correct = 0
    total = 0
    ema_correct = 0
    ema_total = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
            
            # 测试 EMA 模型
            ema_outputs = ema_model(inputs)
            _, ema_predicted = torch.max(ema_outputs.data, 1)
            ema_total += targets.size(0)
            ema_correct += (ema_predicted == targets).sum().item()
    
    normal_accuracy = 100 * correct / total
    ema_accuracy = 100 * ema_correct / ema_total
    lag = normal_accuracy - ema_accuracy
    
    results.append({
        'epoch': epoch + 1,
        'normal_accuracy': normal_accuracy,
        'ema_accuracy': ema_accuracy,
        'lag': lag
    })

results

实验结果分析

实验结果如下表所示:

EpochNormal Model AccuracyEMA Model AccuracyLag
191.0990.970.12
292.5492.460.08
393.5393.500.03
494.0394.13-0.10

从结果可以看出,在训练的前几轮,EMA模型的准确率稍微滞后于正常模型,但随着训练的进行,两者的准确率逐渐接近,甚至在第四轮时,EMA模型的准确率略高于正常模型。

结论

通过实验可以看出,EMA技术在一定程度上平滑了模型参数的波动,使得模型在测试集上的表现更加稳定。尽管在训练的初期EMA模型的准确率稍有滞后,但随着训练的进行,EMA模型的表现逐渐赶上并超过了正常模型。这表明EMA技术对于提高模型的稳定性和性能具有重要作用。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值