使用ModelEmaV2优化MNIST分类模型
在深度学习模型的训练过程中,参数波动可能会导致模型在测试集上的性能不稳定。为了解决这个问题,可以使用指数移动平均(EMA)技术来平滑参数的更新,从而获得更稳定的模型。本文将介绍如何在MNIST数据集上使用ModelEmaV2来优化分类模型,并分析其效果。
实验背景
MNIST数据集是一个经典的手写数字识别数据集,包含60,000张训练图像和10,000张测试图像。我们的目标是训练一个简单的神经网络模型来分类这些手写数字,并使用EMA技术来优化模型参数。
模型定义与EMA实现
首先,我们定义一个简单的全连接神经网络模型,并实现ModelEmaV2来进行EMA参数更新。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import copy
# 定义简单的模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(28*28, 10)
def forward(self, x):
x = x.view(-1, 28*28)
x = self.fc(x)
return x
# 初始化模型、损失函数和优化器
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 定义 EMA 模型
class ModelEmaV2(nn.Module):
def __init__(self, model, decay=0.99, device='cpu'):
super(ModelEmaV2, self).__init__()
self.ema_model = copy.deepcopy(model).to(device)
self.ema_model.eval()
self.decay = decay
self.device = device
def update(self, model):
with torch.no_grad():
model_params = dict(model.named_parameters())
ema_params = dict(self.ema_model.named_parameters())
for k in model_params.keys():
ema_params[k].mul_(self.decay).add_(model_params[k], alpha=1 - self.decay)
def forward(self, x):
return self.ema_model(x)
数据加载与预处理
我们使用torchvision
库来加载和预处理MNIST数据集。
# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('../data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)
训练与评估
我们进行4个epoch的训练,并在每个epoch结束后评估模型和EMA模型的准确率。
# 训练和评估
num_epochs = 4
results = []
for epoch in range(num_epochs):
model.train()
for inputs, targets in train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
# 更新 EMA 模型
ema_model.update(model)
# 计算每个epoch的准确率
model.eval()
ema_model.eval()
correct = 0
total = 0
ema_correct = 0
ema_total = 0
with torch.no_grad():
for inputs, targets in test_loader:
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += targets.size(0)
correct += (predicted == targets).sum().item()
# 测试 EMA 模型
ema_outputs = ema_model(inputs)
_, ema_predicted = torch.max(ema_outputs.data, 1)
ema_total += targets.size(0)
ema_correct += (ema_predicted == targets).sum().item()
normal_accuracy = 100 * correct / total
ema_accuracy = 100 * ema_correct / ema_total
lag = normal_accuracy - ema_accuracy
results.append({
'epoch': epoch + 1,
'normal_accuracy': normal_accuracy,
'ema_accuracy': ema_accuracy,
'lag': lag
})
results
实验结果分析
实验结果如下表所示:
Epoch | Normal Model Accuracy | EMA Model Accuracy | Lag |
---|---|---|---|
1 | 91.09 | 90.97 | 0.12 |
2 | 92.54 | 92.46 | 0.08 |
3 | 93.53 | 93.50 | 0.03 |
4 | 94.03 | 94.13 | -0.10 |
从结果可以看出,在训练的前几轮,EMA模型的准确率稍微滞后于正常模型,但随着训练的进行,两者的准确率逐渐接近,甚至在第四轮时,EMA模型的准确率略高于正常模型。
结论
通过实验可以看出,EMA技术在一定程度上平滑了模型参数的波动,使得模型在测试集上的表现更加稳定。尽管在训练的初期EMA模型的准确率稍有滞后,但随着训练的进行,EMA模型的表现逐渐赶上并超过了正常模型。这表明EMA技术对于提高模型的稳定性和性能具有重要作用。