1. 先加载 MNIST 数据集
# Download (if needed) both MNIST splits into ./Mnist; each PIL image is
# converted to a float tensor in [0, 1] by ToTensor.
mnist_train = torchvision.datasets.MNIST(
    root='Mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
mnist_test = torchvision.datasets.MNIST(
    root='Mnist',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

# Batches of 64. Only the training split is shuffled; the test split is read
# in a fixed order.
train_loader = DataLoader(dataset=mnist_train, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=mnist_test, batch_size=64, shuffle=False)
`shuffle=True` 表示在每个 epoch 开始时打乱训练集的样本顺序；测试集使用 `shuffle=False`，保持固定顺序。
2. 定义模型结构
class Mnist_NN(nn.Module):
    """Small two-stage CNN mapping single-channel 28x28 images to 10 logits.

    Each stage is a 3x3 convolution (stride 1, padding 1, so the spatial size
    is kept), a ReLU and a 2x2 max-pool, taking 28x28 -> 14x14 -> 7x7. The
    resulting 64x7x7 feature map is flattened and passed through a 128-unit
    hidden layer with dropout, then a final linear layer. The output is raw
    logits (no softmax is applied).

    All layers live in one ``nn.Sequential`` stored as ``self.model``.
    Saved state_dict keys therefore look like ``model.<index>.weight``, so the
    layer order must not change or existing checkpoints will stop loading.
    """

    def __init__(self):
        super().__init__()
        feature_extractor = [
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        classifier = [
            nn.Flatten(),
            nn.Linear(in_features=64 * 7 * 7, out_features=128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=128, out_features=10),
        ]
        self.model = nn.Sequential(*feature_extractor, *classifier)

    def forward(self, x):
        """Return logits of shape (N, 10) for a batch ``x`` of shape (N, 1, 28, 28)."""
        logits = self.model(x)
        return logits
3. 训练并保存模型
# NOTE(review): this section relies on `Mnist_NN`, `train_loader`, `test_loader`
# and `device` already existing. `device` is not defined anywhere in this
# listing; something like
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# is presumably expected earlier -- confirm.


def _train_one_epoch(model, loader, loss_fn, optimizer, device):
    """Run one optimisation pass of ``model`` over ``loader``.

    Returns ``(mean_loss, accuracy)``:
    - ``mean_loss`` is averaged per sample, not per batch.
    - ``accuracy`` is the fraction of samples whose arg-max prediction equals
      the label.

    Both are collected in training mode (dropout active) while the weights
    are still being updated, so they are running estimates for the epoch.

    Raises ValueError if ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_size = y.size(0)
        # Assumes loss_fn averages over the batch (reduction='mean', the
        # CrossEntropyLoss default used below). Re-weighting by batch size keeps
        # the smaller final batch from skewing the epoch average.
        loss_sum += loss.item() * batch_size
        correct += (y_pred.argmax(1) == y).sum().item()
        seen += batch_size
    if seen == 0:
        raise ValueError('loader yielded no samples')
    return loss_sum / seen, correct / seen


def _evaluate(model, loader, loss_fn, device):
    """Score ``model`` on ``loader`` in eval mode without tracking gradients.

    Returns ``(mean_loss, accuracy)``, both averaged per sample.

    Raises ValueError if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            y_pred = model(x)
            batch_size = y.size(0)
            # Per-sample weighting; see _train_one_epoch.
            loss_sum += loss_fn(y_pred, y).item() * batch_size
            correct += (y_pred.argmax(1) == y).sum().item()
            seen += batch_size
    if seen == 0:
        raise ValueError('loader yielded no samples')
    return loss_sum / seen, correct / seen


mnist_nn = Mnist_NN()
mnist_nn.to(device)
learning_rate = 0.001
optimizer = torch.optim.Adam(mnist_nn.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()
loss_fn.to(device)
epoch = 10  # total number of training epochs
for i in range(epoch):
    # Training
    print(f'第{i + 1}轮训练开始')
    train_loss, train_accuracy = _train_one_epoch(
        mnist_nn, train_loader, loss_fn, optimizer, device
    )
    print(
        f'第{i + 1}轮训练结束,平均损失为{train_loss},总准确率为{train_accuracy}')

    # Evaluation on the held-out test split
    print(f'第{i + 1}轮测试开始')
    test_loss, test_accuracy = _evaluate(mnist_nn, test_loader, loss_fn, device)
    print(
        f'第{i + 1}轮测试结束,平均损失为{test_loss},总准确率为{test_accuracy}')

    # Save parameters after every epoch: mnist_train_1.pth ... mnist_train_{epoch}.pth
    torch.save(mnist_nn.state_dict(), f'mnist_train_{i + 1}.pth')
4. 用训练好的模型识别自己手写的数字
import torch
import torchvision
from PIL import Image
from torch import nn
class Mnist_NN(nn.Module):
    """Small two-stage CNN mapping single-channel 28x28 images to 10 logits.

    Each stage is a 3x3 convolution (stride 1, padding 1, so the spatial size
    is kept), a ReLU and a 2x2 max-pool, taking 28x28 -> 14x14 -> 7x7. The
    resulting 64x7x7 feature map is flattened and passed through a 128-unit
    hidden layer with dropout, then a final linear layer. The output is raw
    logits (no softmax is applied).

    All layers live in one ``nn.Sequential`` stored as ``self.model``.
    The state_dict keys therefore look like ``model.<index>.weight``. This
    definition must match the one used when the checkpoint was saved, layer for
    layer, or ``load_state_dict`` will fail.
    """

    def __init__(self):
        super().__init__()
        feature_extractor = [
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        classifier = [
            nn.Flatten(),
            nn.Linear(in_features=64 * 7 * 7, out_features=128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=128, out_features=10),
        ]
        self.model = nn.Sequential(*feature_extractor, *classifier)

    def forward(self, x):
        """Return logits of shape (N, 10) for a batch ``x`` of shape (N, 1, 28, 28)."""
        logits = self.model(x)
        return logits
# Classify a single hand-written digit image with the checkpoint saved after
# epoch 10 and print the predicted class index.
model = Mnist_NN()
# Training may have run on a GPU, in which case the saved tensors are CUDA
# tensors. This script keeps the model on the CPU, so map them there; without
# map_location the load fails on machines without CUDA.
model.load_state_dict(
    torch.load('mnist_train_10.pth', map_location='cpu', weights_only=True)
)
model.eval()  # disable dropout for inference

img_path = 'imgs/nine.png'
# Resize to 28x28 and reduce to one grayscale channel, then apply ToTensor
# (values in [0, 1]) -- the same tensor conversion the training data used.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((28, 28)),
    torchvision.transforms.Grayscale(),
    torchvision.transforms.ToTensor(),
])
# NOTE(review): MNIST digits are light strokes on a dark background; a scan or
# photo of dark ink on white paper may need inverting before it is classified
# reliably -- check against your own images.
with Image.open(img_path) as pil_image:
    # Image.open loads lazily, so the transform must run while the file is open.
    image = transform(pil_image)
image = image.unsqueeze(0)  # (1, 28, 28) -> (1, 1, 28, 28): add the batch dimension

with torch.no_grad():
    output = model(image)
idx = output.argmax(1).item()
print(idx)