用PyTorch搭建卷积神经网络实现MNIST手写数字识别
在深度学习领域,卷积神经网络(Convolutional Neural Network,简称CNN)是处理图像数据的强大工具。它通过卷积层、池化层和全连接层等组件,自动提取图像特征,在图像分类、目标检测等任务中表现卓越。本文将使用PyTorch框架,搭建一个CNN模型来实现MNIST手写数字识别,并详细解析每一步代码。
一、MNIST数据集介绍
MNIST数据集是深度学习领域经典的入门数据集,包含70,000张手写数字图像,其中60,000张用于训练,10,000张用于测试。这些图像均为灰度图,尺寸是28x28像素,并且已经做了居中处理,这在一定程度上减少了预处理的工作量,能够加快模型的训练和运行速度。
二、环境准备与数据加载
2.1 导入必要的库
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
上述代码导入了PyTorch的核心库、神经网络模块、数据加载工具以及用于图像数据处理和数据集管理的库。
2.2 下载并加载数据集
training_data = datasets.MNIST(
root='data',
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.MNIST(
root='data',
train=False,
download=True,
transform=ToTensor()
)
通过datasets.MNIST
函数分别下载训练集和测试集。root
参数指定数据下载的路径;train=True
表示下载训练集数据,train=False
则表示下载测试集数据;download=True
确保如果数据尚未下载,会自动进行下载;transform=ToTensor()
将图像数据转换为PyTorch能够处理的张量格式。
2.3 数据可视化
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):
img, label = training_data[i + 59000]
figure.add_subplot(3, 3, i + 1)
plt.title(label)
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
这段代码使用matplotlib
库展示了训练数据集中的部分手写数字图像,通过plt.imshow
函数将张量格式的图像数据可视化,直观感受MNIST数据集的内容。
2.4 创建数据加载器
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
DataLoader
用于将数据集打包成批次,batch_size
参数指定每个批次包含的数据样本数量。将数据集分成批次进行训练,能够有效减少内存使用,并提高训练速度。
三、设备配置
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
这段代码检测当前设备是否支持GPU(CUDA)或苹果M系列芯片的GPU(MPS),如果都不支持,则使用CPU进行计算。后续模型和数据都会被移动到选定的设备上运行,以充分利用硬件资源加速训练。
四、定义卷积神经网络模型
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=16,
kernel_size=5,
stride=1,
padding=2
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 5, 1, 2),
nn.ReLU(),
nn.Conv2d(32, 32, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.conv3 = nn.Sequential(
nn.Conv2d(32, 64, 5, 1, 2),
nn.ReLU()
)
self.out = nn.Linear(64 * 7 * 7, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = x.view(x.size(0), -1)
output = self.out(x)
return output
在这个自定义的CNN
类中,继承自nn.Module
。__init__
方法中定义了网络的结构:
- 卷积层(
nn.Conv2d
):用于提取图像特征,通过设置in_channels
(输入通道数)、out_channels
(输出通道数,即卷积核个数)、kernel_size
(卷积核大小)、stride
(步长)和padding
(填充)等参数,控制卷积操作。 - 激活函数层(
nn.ReLU
):引入非线性,增强网络的表达能力。 - 池化层(
nn.MaxPool2d
):对特征图进行下采样,减少数据量和计算量,同时保留主要特征。 - 全连接层(
nn.Linear
):将卷积层和池化层提取的特征映射到输出类别(MNIST数据集中有10个数字类别)。
forward
方法定义了数据在网络中的前向传播路径,确保数据按照网络结构依次经过各层处理,最终输出预测结果。
五、训练与测试模型
5.1 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
nn.CrossEntropyLoss
是适用于多分类任务的交叉熵损失函数,用于计算模型预测结果与真实标签之间的差距。torch.optim.Adam
是一种常用的优化器,通过调整模型的参数(model.parameters()
)来最小化损失函数,lr
参数设置学习率,控制参数更新的步长。
5.2 训练函数
def train(dataloader, model, loss_fn, optimizer):
model.train()
batch_size_num = 1
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_value = loss.item()
if batch_size_num % 100 == 0:
print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')
batch_size_num += 1
在训练函数中:
model.train()
将模型设置为训练模式,此时模型中的一些层(如Dropout层)会按照训练规则工作。- 遍历数据加载器中的每一个批次数据,将数据和标签移动到指定设备上。
- 通过模型进行预测,计算损失值。
- 使用
optimizer.zero_grad()
清零梯度,loss.backward()
进行反向传播计算梯度,optimizer.step()
根据梯度更新模型参数。 - 每隔100个批次,打印当前的损失值,以便观察训练过程中的损失变化。
5.3 测试函数
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f'Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}')
测试函数中:
model.eval()
将模型设置为测试模式,关闭一些在训练过程中起作用但在测试时不需要的操作(如Dropout)。- 使用
with torch.no_grad()
上下文管理器,关闭梯度计算,因为在测试阶段不需要更新模型参数,这样可以节省计算资源。 - 遍历测试数据,计算每个批次的损失值并累加,同时统计预测正确的样本数量。
- 最后计算并打印测试集上的平均损失和准确率,评估模型的性能。
5.4 执行训练和测试
epoch = 9
for i in range(epoch):
print(i + 1)
train(train_dataloader, model, loss_fn, optimizer)
test(test_dataloader, model, loss_fn)
通过设置训练轮数(epoch
),循环调用训练函数进行模型训练,每一轮训练结束后,调用测试函数评估模型在测试集上的性能。
六、总结
本文通过详细的代码解析,展示了如何使用PyTorch搭建一个简单的卷积神经网络来实现MNIST手写数字识别任务。从数据加载、模型定义,到训练和测试,每一个步骤都体现了CNN在图像分类任务中的核心思想和实现方法。通过不断调整模型结构、超参数等,还可以进一步提升模型的性能。卷积神经网络在图像领域的应用远不止于此,它在更复杂的图像任务和其他领域也有着广泛的应用前景,希望本文能为大家深入学习深度学习提供一个良好的开端。