刚接触深度学习的小伙伴们,是不是经常听说 MNIST 数据集和 PyTorch 框架?今天就带大家从零开始,用 PyTorch 实现 MNIST 手写数字识别,轻松迈出深度学习实践的第一步!
一、MNIST 数据集:深度学习界的 “Hello World”
MNIST 数据集就像是深度学习领域的 “新手村”,里面包含了 6 万张手写数字训练图片和 1 万张测试图片,每张图片都是 28×28 像素的灰度图像,对应的数字标签是 0 - 9。就好比是一个装满数字 “小卡片” 的百宝箱,我们的任务就是教会计算机 “看懂” 这些卡片上的数字。
二、PyTorch 基础库导入
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
这段代码是在引入我们需要的工具包。torch是 PyTorch 的核心库,就像搭建房屋的砖块;DataLoader是数据加载器,帮我们把数据分批处理;torchvision里的transforms和MNIST,一个用来转换数据格式,一个用来获取 MNIST 数据集;matplotlib.pyplot则是绘图工具,能帮我们直观看到预测结果;nn和nn.functional用于搭建神经网络模型和定义激活函数等操作。
三、搭建神经网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(28*28, 64)
self.fc2 = nn.Linear(64, 64)
self.fc3 = nn.Linear(64, 64)
self.fc4 = nn.Linear(64, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = self.fc4(x)
return x
这里定义了一个名为Net的神经网络类,继承自nn.Module。__init__函数是初始化操作,fc1到fc4是全连接层,全连接层就像是一个 “信息加工厂”,将输入数据进行变换。比如fc1把 28×28(784)个像素点组成的数据转换为 64 个特征。forward函数定义了数据的前向传播过程,F.relu是激活函数,它就像一个 “开关”,让神经网络具备了学习非线性关系的能力,能让数据在不同层之间更好地传递和处理。
四、数据加载与预处理
def get_data_loader():
train_data = MNIST(root="mnist_data", train=True,
transform=transforms.ToTensor(), download=True)
test_data = MNIST(root="mnist_data", train=False,
transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64)
return train_loader, test_loader
get_data_loader函数负责获取 MNIST 数据集。MNIST函数从指定路径(root)下载数据,transforms.ToTensor()将图片转换为 PyTorch 能处理的张量格式。DataLoader则把数据打包成一批一批的,batch_size=64表示每批有 64 张图片,shuffle=True让训练数据每次都打乱顺序,这样能让模型学习得更好。就像把小卡片分成一叠叠,每次训练随机抽取一叠,避免模型 “记住” 固定顺序。
五、模型评估函数
def evaluate(test_loader, net):
net.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs = inputs.view(-1, 28*28)
outputs = net(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
return correct / total
evaluate函数用来评估模型的准确率。net.eval()将模型设置为评估模式,with torch.no_grad()表示在这个过程中不计算梯度,节省计算资源。遍历测试数据,把图片数据整理成合适格式后输入模型,得到输出结果,用torch.max找到概率最大的类别作为预测结果,最后计算预测正确的比例。
六、主函数:模型训练与测试
def main():
train_loader, test_loader = get_data_loader()
net = Net()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
print("初始准确率:", evaluate(test_loader, net))
for epoch in range(5):
net.train()
running_loss = 0.0
for i, (inputs, labels) in enumerate(train_loader):
inputs = inputs.view(-1, 28*28)
optimizer.zero_grad()
outputs = net(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if (i+1) % 100 == 0:
print(f'Epoch [{epoch+1}/{5}], Batch [{i+1}/{len(train_loader)}], Loss: {running_loss/100:.4f}')
running_loss = 0.0
accuracy = evaluate(test_loader, net)
print(f'Epoch {epoch+1}, 测试准确率: {accuracy:.4f}')
net.eval()
fig, axes = plt.subplots(1, 5, figsize=(15, 3))
for i, (images, labels) in enumerate(test_loader, 1):
if i > 5: break
with torch.no_grad():
outputs = net(images[0].view(-1, 28*28))
_, predicted = torch.max(outputs, 1)
axes[i-1].imshow(images[0].view(28, 28), cmap='gray')
axes[i-1].set_title(f'预测: {predicted.item()}')
axes[i-1].axis('off')
plt.tight_layout()
plt.show()
在main函数里,先获取数据加载器,创建模型,定义损失函数nn.CrossEntropyLoss()(它结合了 Softmax 和交叉熵损失计算)和优化器torch.optim.Adam(用来更新模型参数)。然后进入训练循环,epoch表示训练的轮数,每轮中遍历训练数据,通过前向传播、计算损失、反向传播和参数更新,不断调整模型参数。训练过程中打印损失值,每轮结束后评估模型准确率。最后可视化 5 张测试图片的预测结果,直观看到模型的识别效果。
if __name__ == "__main__":
main()
这行代码确保只有直接运行脚本时才执行main函数,避免被其他脚本导入时意外执行。
通过以上步骤,我们就完成了 MNIST 手写数字识别模型的搭建、训练和测试。希望这篇博客能帮助小白们理解深度学习的基本流程,快动手试试,开启你的 AI 探索之旅吧!如果在实践过程中有任何问题,欢迎在评论区交流~
彩蛋:
点赞+收藏+关注一键三联+评论“冲冲冲”可免费获取下面的项目实战大礼包!