代码示例:
在PyTorch中,你可以使用torchvision
库来方便地下载和加载MNIST数据集。以下是一个详细的示例,展示了如何使用PyTorch和torchvision
来下载、加载并使用MNIST数据集:
-
安装必要的库:
首先,确保你已经安装了
torch
和torchvision
。可以使用以下命令进行安装:pip install torch torchvision
-
下载和加载数据集:
使用
torchvision.datasets
中的MNIST
类来下载和加载数据集。同时,使用torch.utils.data.DataLoader
来创建一个数据加载器,以便于在训练过程中批量处理数据。import torch import torchvision import torchvision.transforms as transforms # 定义数据转换 transform = transforms.Compose([ transforms.ToTensor(), # 将图片转换为Tensor transforms.Normalize((0.5,), (0.5,)) # 归一化处理 ]) # 下载并加载训练集 trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) # 下载并加载测试集 testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2) # 检查数据 dataiter = iter(trainloader) images, labels = dataiter.next() print(images.shape, labels.shape)
-
创建一个简单的神经网络模型:
使用
torch.nn
模块来定义一个简单的神经网络模型。import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(28*28, 512) # 28*28是输入图片的像素数 self.fc2 = nn.Linear(512, 256) self.fc3 = nn.Linear(256, 10) # 10是MNIST数据集的类别数 def forward(self, x): x = x.view(-1, 28*28) # 展平图片 x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return F.log_softmax(x, dim=1) net = Net()
-
训练模型:
使用定义好的模型和数据加载器来训练模型。
import torch.optim as optim criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) # 训练模型 for epoch in range(2): # 运行多个周期 running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = data optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if i % 2000 == 1999: # 每2000个小批量打印一次训练状态 print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}') running_loss = 0.0 print('Finished Training')
-
测试模型:
在测试集上评估模型的性能。
correct = 0 total = 0 with torch.no_grad(): for data in testloader: images, labels = data outputs = net(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')
这个示例展示了如何使用PyTorch和torchvision
来下载、加载、训练和测试MNIST数据集。你可以根据需要调整模型结构和训练参数。
喜欢本文,请点赞、收藏和关注!