手写数字图像识别是深度学习中的一个经典问题,可以使用卷积神经网络(CNN)来解决。在本教程中,我们将使用PyTorch和MNIST数据集来构建一个简单的CNN模型,以识别手写数字图像。
首先,我们需要导入必要的库,包括PyTorch、NumPy和Matplotlib:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
接下来,我们可以定义一些超参数,例如批次大小、学习率和训练轮数:
batch_size = 64
learning_rate = 0.001
num_epochs = 10
然后,我们可以加载MNIST数据集,并将其转换为张量:
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=b