承接上一篇用全连接神经网络实现手写数字识别的文章，本文改用两层卷积神经网络（CNN），并记录了直观测试模式（打印预测标签与真实标签，并显示测试图片）的代码。
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import cv2
from torch.autograd import Variable
# Device configuration: use the first CUDA GPU when one is available, else the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Hyper-parameters
num_epochs = 5          # number of passes over the training set
output_size = 10        # number of output classes (digits 0-9)
batch_size = 100        # samples per mini-batch for both loaders
learning_rate = 0.001   # Adam learning rate
# MNIST datasets. ToTensor() turns each image into a float tensor; with
# download=True torchvision fetches the files into ../../data/ if missing.
train_dataset = torchvision.datasets.MNIST(
    root='../../data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='../../data/',
    train=False,
    transform=transforms.ToTensor(),
)

# Data loaders: reshuffle training data every epoch, keep test order fixed.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False)
# Neural network with two convolutional layers
class ConvNet(nn.Module):
    """Two-block convolutional classifier for single-channel 28x28 images.

    Each block is Conv2d(5x5, stride 1, padding 2) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2x2, stride 2). Padding 2 keeps the spatial size through each
    convolution and each pooling halves it, so a 28x28 input becomes a
    32-channel 7x7 map before the final linear layer.

    Args:
        num_classes: Number of output logits. When omitted, the module-level
            ``output_size`` hyper-parameter is used (the original behavior).
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            # Backward-compatible default: follow the script's hyper-parameter.
            num_classes = output_size
        # Block 1: 1 input channel -> 16 output channels (i.e. 16 conv kernels).
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),                      # batch-normalize the 16 feature maps
            nn.ReLU(),                               # activation
            nn.MaxPool2d(kernel_size=2, stride=2))   # 2x2 max pooling, stride 2 (downsample)
        # Block 2: 16 -> 32 channels, same structure.
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        # The fixed 7*7*32 input size means inputs are expected to be 28x28.
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        ``x`` is expected to be (batch, 1, 28, 28) so that the flattened
        feature map matches the 7*7*32 inputs of ``self.fc``.
        """
        out = self.layer1(x)
        out = self.layer2(out)
        # Flatten everything except the batch dimension: (batch, 32, 7, 7) -> (batch, 7*7*32).
        out = out.reshape(out.size(0), -1)
        return self.fc(out)
# Loss function and optimizer; the network itself is placed on `device`.
criterion = nn.CrossEntropyLoss()
model = ConvNet().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
# Train the model
total_step = len(train_loader)
# Explicitly enter training mode so the BatchNorm layers use per-batch
# statistics even if this loop is re-run after model.eval().
model.train()
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the current mini-batch loss every 100 steps.
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
# Evaluate the model on the test set
model.eval()  # evaluation mode (changes how the BatchNorm layers behave)
with torch.no_grad():  # gradients are not needed for evaluation
    correct = total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # max over dim 1 returns (max values, argmax indices); the indices
        # are the predicted class labels.
        _, predicted = outputs.data.max(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the trained weights (state_dict only) to model.pkl
torch.save(model.state_dict(), 'model.pkl')
# Visual test: predict the first test batch, print predicted vs. real labels,
# and display the batch as a single image grid.
X_test, y_test = next(iter(test_loader))
inputs = X_test.to(device)  # inputs must be on the same device as the model
with torch.no_grad():  # inference only, no autograd graph needed
    pred = model(inputs)
    _, pred = torch.max(pred, 1)
# Convert to plain Python ints so the labels print readably; printing a
# generator expression would only show "<generator object ...>".
print("Predict Label is:", pred.cpu().tolist())
print("Real Label is :", y_test.tolist())

# make_grid tiles the batch into one (C, H, W) tensor; OpenCV wants (H, W, C).
img = torchvision.utils.make_grid(X_test)
img = img.numpy().transpose(1, 2, 0)
# The datasets use only transforms.ToTensor() (no Normalize), so pixels are
# already in [0, 1], the range cv2.imshow expects for float images. No
# de-normalization is applied (the old "* 0.5 + 0.5" washed the image out).
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)
cv2.destroyAllWindows()
Test Accuracy of the model on the 10000 test images: 99.01 %
Predict Label is: <generator object <genexpr> at 0x000002A02B024138>
Real Label is : [tensor(7), tensor(2), tensor(1), tensor(0), tensor(4), tensor(1), tensor(4), tensor(9), tensor(5), tensor(9), tensor(0), tensor(6), tensor(9), tensor(0), tensor(1), tensor(5), tensor(9), tensor(7), tensor(3), tensor(4), tensor(9), tensor(6), tensor(6), tensor(5), tensor(4), tensor(0), tensor(7), tensor(4), tensor(0), tensor(1), tensor(3), tensor(1), tensor(3), tensor(4), tensor(7), tensor(2), tensor(7), tensor(1), tensor(2), tensor(1), tensor(1), tensor(7), tensor(4), tensor(2), tensor(3), tensor(5), tensor(1), tensor(2), tensor(4), tensor(4), tensor(6), tensor(3), tensor(5), tensor(5), tensor(6), tensor(0), tensor(4), tensor(1), tensor(9), tensor(5), tensor(7), tensor(8), tensor(9), tensor(3), tensor(7), tensor(4), tensor(6), tensor(4), tensor(3), tensor(0), tensor(7), tensor(0), tensor(2), tensor(9), tensor(1), tensor(7), tensor(3), tensor(2), tensor(9), tensor(7), tensor(7), tensor(6), tensor(2), tensor(7), tensor(8), tensor(4), tensor(7), tensor(3), tensor(6), tensor(1), tensor(3), tensor(6), tensor(9), tensor(3), tensor(1), tensor(4), tensor(1), tensor(7), tensor(6), tensor(9)]