import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
# Preprocessing pipeline: convert every image into a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Download (if needed) and load the MNIST training and test splits.
train_data, test_data = (
    MNIST(root='sdata', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Data loaders: the training split is reshuffled, the test split keeps its order.
batch_size = 100
train_iter = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_iter = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)
# LeNet network definition
class LeNet(torch.nn.Module):
    """LeNet-style convolutional classifier with 10 output classes.

    The layer sizes assume single-channel 28x28 inputs (as produced by the
    MNIST loaders in this script): two 5x5 convolutions, each followed by a
    2x2 max-pool, leave 16 feature maps of size 4x4, which is what the first
    fully connected layer (16 * 4 * 4 inputs) expects.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor. Attribute names and layer order are kept as-is
        # so that state_dict keys of saved checkpoints stay compatible.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 6, 5),
            torch.nn.Sigmoid(),
            torch.nn.MaxPool2d(2, 2),
            torch.nn.Conv2d(6, 16, 5),
            torch.nn.Sigmoid(),
            torch.nn.MaxPool2d(2, 2),
        )
        # Classifier head producing one raw score (logit) per class.
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(16 * 4 * 4, 120),
            torch.nn.Sigmoid(),
            torch.nn.Linear(120, 84),
            torch.nn.Sigmoid(),
            torch.nn.Linear(84, 10),
        )

    def forward(self, img):
        """Return the class logits for a batch of images.

        Args:
            img: Image batch; the first dimension is the batch dimension.

        Returns:
            Tensor with one row of 10 logits per input image.
        """
        feature = self.conv(img)
        # flatten() falls back to a copy when the feature maps are not
        # contiguous (e.g. channels_last), where .view() would raise; it also
        # handles an empty batch, for which view(0, -1) is ambiguous.
        return self.fc(torch.flatten(feature, start_dim=1))
# Instantiate the network and print its layer structure
net = LeNet()
print(net)
# Training function definition
def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    """Train ``net`` with cross-entropy loss and print per-epoch statistics.

    Args:
        net: Model to train; it is moved to ``device`` and put in train mode.
        train_iter: Iterable yielding ``(inputs, labels)`` batches.
        test_iter: Not used by this function; kept so existing callers that
            pass it keep working.
        batch_size: Not used by this function; kept for the same reason.
        optimizer: Optimizer already bound to ``net``'s parameters.
        device: Device on which the model and every batch are placed.
        num_epochs: Number of passes over ``train_iter``.

    After each epoch one line is printed with the mean batch loss, the
    training accuracy and the elapsed wall-clock time. Nothing is returned.
    """
    net = net.to(device)
    # Callers may hand in a model that was left in eval mode; make sure
    # mode-dependent layers (dropout, batch norm) behave as during training.
    net.train()
    loss = torch.nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, batch_count = 0.0, 0.0, 0, 0
        start = time.time()
        for batch_x, batch_y in train_iter:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            y_hat = net(batch_x)
            loss_ = loss(y_hat, batch_y)
            optimizer.zero_grad()
            loss_.backward()
            optimizer.step()
            train_l_sum += loss_.item()
            train_acc_sum += (y_hat.argmax(dim=1) == batch_y).sum().item()
            n += batch_y.shape[0]
            batch_count += 1
        # max(..., 1) keeps the report from raising ZeroDivisionError when
        # the loader yields no batches; the statistics are then printed as 0.
        print('epoch %d, loss %.4f, train acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / max(batch_count, 1),
                 train_acc_sum / max(n, 1), time.time() - start))
# Training setup: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Hyper-parameters and optimizer.
lr = 0.001
num_epochs = 5
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)

# Train the network.
train_ch5(
    net=net,
    train_iter=train_iter,
    test_iter=test_iter,
    batch_size=batch_size,
    optimizer=optimizer,
    device=device,
    num_epochs=num_epochs,
)
# Save the trained weights, then load them back from disk.
torch.save(net.state_dict(), 'mnist.pth')
# map_location remaps the stored tensors onto the device in use, so a
# checkpoint written on a CUDA machine still loads on a CPU-only one.
net.load_state_dict(torch.load('mnist.pth', map_location=device))
# Accuracy on the test set
net.eval()
test_acc_sum, n = 0, 0
# Evaluation needs no gradients; no_grad() avoids recording an autograd
# graph for every batch, which saves memory and time.
with torch.no_grad():
    for test_x, test_y in test_iter:
        test_x = test_x.to(device)
        test_y = test_y.to(device)
        y_hat = net(test_x)
        test_acc_sum += (y_hat.argmax(dim=1) == test_y).sum().item()
        n += test_y.shape[0]
print('test acc %.3f' % (test_acc_sum / n))
# Show one test image together with its true and predicted labels
test_x, test_y = next(iter(test_iter))
test_x = test_x.to(device)
test_y = test_y.to(device)
# Pure inference: skip autograd bookkeeping for this forward pass.
with torch.no_grad():
    y_hat = net(test_x)
# The first image of the batch is moved to the CPU and squeezed for plotting.
plt.imshow(test_x[0].squeeze().cpu().numpy(), cmap='gray')
print('True label:', test_y[0].item(), 'Predicted label:', y_hat[0].argmax(dim=0).item())
plt.show()