Classifying Fashion-MNIST
Now it's your turn to build and train a neural network. You'll be using the Fashion-MNIST dataset, a drop-in replacement for MNIST. The original MNIST dataset is too easy for neural networks, so you can reach accuracies above 97% without much effort. Fashion-MNIST is a set of 28x28 greyscale images of clothes. It's more complex than MNIST, so it gives a better measure of your network's performance and is closer to the kinds of datasets you'll use in the real world.
In this notebook, you'll build your own neural network. For the most part you could just copy and paste the code from Part 3, but you wouldn't learn much that way, so it's important that you write the code yourself and get it working. That said, feel free to consult the previous notebooks as you work through this.
First, let's load the dataset through torchvision.
import torch
from torchvision import datasets, transforms
import helper
# Define a transform to normalize the data
# (Fashion-MNIST images are single-channel, so use a single mean and std)
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
# Download and load the training data
trainset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
# Download and load the test data
testset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
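ToTensor scales pixel values to [0, 1], and Normalize((0.5,), (0.5,)) then maps them to [-1, 1] via x' = (x - 0.5) / 0.5, which is why the tensor printed below is full of values between -1 and 1. A quick shape-and-range check on one batch (a minimal sketch, assuming the loaders above):
# Sanity check: one batch of 64 single-channel 28x28 images, scaled to [-1, 1]
images, labels = next(iter(trainloader))
print(images.shape)   # torch.Size([64, 1, 28, 28])
print(images.min().item(), images.max().item())   # -1.0 and (close to) 1.0
print(labels.shape)   # torch.Size([64])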
Here we can see one of the images.
image, label = next(iter(trainloader))
print('image0=', image[0,:])
print('label=', label)
image0= tensor([[[-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
         ...,
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000]]])
label= tensor([ 7, 9, 9, 3, 6, 4, 8, 8, 7, 5, 4, 7, 5, 0, 7, 7, 2, 6,
        8, 0, 2, 0, 2, 9, 7, 4, 3, 2, 8, 6, 9, 5, 5, 8, 3, 2,
        6, 1, 0, 1, 9, 3, 6, 6, 3, 2, 9, 7, 3, 6, 9, 5, 2, 5,
        6, 2, 5, 4, 7, 5, 8, 2, 4, 8])
# show images
helper.imshow(image[0,:]);
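helper here is the course-supplied helper.py module. If you don't have it handy, a rough stand-in for imshow might look like the sketch below (an assumption about its behavior, not the course's actual code):
# Rough stand-in for helper.imshow: un-normalize and display a 1x28x28 tensor
import matplotlib.pyplot as plt

def imshow(image, ax=None):
    if ax is None:
        fig, ax = plt.subplots()
    # Undo Normalize((0.5,), (0.5,)) and drop the channel dimension
    ax.imshow(image.numpy().squeeze() * 0.5 + 0.5, cmap='gray')
    ax.axis('off')
    return ax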
With the data loaded, it's time to import the rest of the packages we'll need.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
import numpy as np
import time
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import helper
Building the network
Here you should define your network. As with MNIST, each image is 28x28, for 784 pixels in total, and there are 10 classes. You should include at least one hidden layer. We suggest using ReLU activations for the layers and returning the logits from the forward pass. It's up to you how many layers you add and how large those layers are.
# TODO: Define your network architecture here
class Network(nn.Module):
    def __init__(self):
        # Initialize the parent nn.Module class
        super().__init__()
        # Define the layers: 200, 50, and 10 units each
        self.fc1 = nn.Linear(784, 200)
        self.fc2 = nn.Linear(200, 50)
        # Output layer, 10 units - one for each class
        self.fc3 = nn.Linear(50, 10)

    def forward(self, x):
        ''' Forward pass through the network, returns the output logits '''
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def predict(self, x):
        ''' Predict class probabilities by applying softmax to the logits '''
        logits = self.forward(x)
        return F.softmax(logits, dim=1)
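As a quick sanity check that forward and predict behave as expected, you can run random inputs through a throwaway instance (a minimal sketch; check_net is not the model trained below):
# Sanity check on a throwaway instance of the network above
check_net = Network()
x = torch.randn(4, 784)                 # a fake batch of four flattened images
print(check_net(x).shape)               # torch.Size([4, 10]) -- one logit per class
print(check_net.predict(x).sum(dim=1))  # each row sums to ~1.0 after softmax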
Training the network
Now you should create your network and train it. First you'll want to define the criterion (something like nn.CrossEntropyLoss) and the optimizer (typically optim.SGD or optim.Adam).
Then write the training code. Remember, the training pass is a fairly straightforward process:
Make a forward pass through the network to get the logits
Use the logits to calculate the loss
Perform a backward pass through the network with loss.backward() to calculate the gradients
Take a step with the optimizer to update the weights
By adjusting the hyperparameters (hidden units, learning rate, etc.), you should be able to get the training loss below 0.4.
# TODO: Create the network, define the criterion and optimizer
net = Network()
criterion = nn.CrossEntropyLoss()
# lr = learning rate
optimizer = optim.SGD(net.parameters(), lr=0.01)
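Note that nn.CrossEntropyLoss applies log-softmax itself and then computes the negative log-likelihood, which is why forward should return raw logits rather than softmax probabilities. A minimal demonstration on random data:
# CrossEntropyLoss on raw logits == NLLLoss on log-softmax outputs
logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
loss_a = nn.CrossEntropyLoss()(logits, labels)
loss_b = nn.NLLLoss()(F.log_softmax(logits, dim=1), labels)
print(loss_a.item(), loss_b.item())  # identical values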
# TODO: Train the network here
# First, inspect a single forward/backward pass on one batch
print('Initial weights - ', net.fc1.weight)

dataiter = iter(trainloader)
images, labels = next(dataiter)
# Flatten the batch of images into 784-long vectors
images = images.view(images.shape[0], -1)

# Clear the gradients, since gradients are accumulated by default
optimizer.zero_grad()

# Forward pass, then backward pass, then update weights
output = net(images)
loss = criterion(output, labels)
loss.backward()
print('Gradient -', net.fc1.weight.grad)
optimizer.step()
Initial weights - Parameter containing:
tensor([[-6.5417e-03, -2.4802e-02, -1.5146e-02, ..., -1.2971e-02,
-1.6070e-02, 2.6286e-02],
[ 5.7576e-03, -1.0399e-02, 4.4446e-03, ..., -2.5082e-03,
-4.3815e-03, 1.6680e-02],
[-3.4878e-02, -2.0759e-02, 1.6003e-02, ..., 8.6531e-04,
-1.9558e-02, 2.1282e-02],
...,
[ 2.6933e-02, 2.3026e-02, -3.3443e-02, ..., -2.4827e-02,
-3.5710e-02, -6.9900e-03],
[ 3.5409e-03, -3.0244e-02, 9.5727e-03, ..., -1.0316e-02,
-1.9417e-02, 1.2862e-04],
[-3.3738e-03, -3.0613e-02, 1.6543e-02, ..., 1.7032e-02,
2.3136e-02, -1.5136e-02]])
Gradient - tensor([[ 1.9819e-03, 1.9818e-03, 1.9818e-03, ..., 1.9833e-03,
1.9834e-03, 1.9818e-03],
[-1.0067e-03, -1.0061e-03, -1.0094e-03, ..., -1.0024e-03,
-1.0067e-03, -1.0061e-03],
[-1.8131e-03, -1.8103e-03, -1.8119e-03, ..., -1.7803e-03,
-1.8102e-03, -1.8103e-03],
...,
[ 5.4580e-04, 5.4580e-04, 5.4289e-04, ..., 5.5021e-04,
5.4945e-04, 5.4580e-04],
[ 1.3846e-04, 1.3887e-04, 1.3865e-04, ..., 1.3844e-04,
1.3593e-04, 1.3887e-04],
[-7.1335e-04, -7.1335e-04, -7.1335e-04, ..., -7.0985e-04,
-7.1335e-04, -7.1335e-04]])
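For plain SGD, the update applied by optimizer.step() is simply w ← w − lr · grad. You can confirm this by snapshotting a weight matrix before stepping; a minimal sketch on one more batch, reusing the objects defined above:
# Confirm the SGD update rule: new_weight = old_weight - lr * gradient
images, labels = next(iter(trainloader))
optimizer.zero_grad()
loss = criterion(net(images.view(images.shape[0], -1)), labels)
loss.backward()
old = net.fc1.weight.detach().clone()
optimizer.step()
print(torch.allclose(net.fc1.weight, old - 0.01 * net.fc1.weight.grad))  # True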
# The actual training loop
epochs = 5
steps = 0
running_loss = 0
print_every = 40
for e in range(epochs):
    for images, labels in trainloader:
        steps += 1
        # Flatten the images into 784-long vectors
        images = images.view(images.shape[0], -1)
        optimizer.zero_grad()
        output = net(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if steps % print_every == 0:
            # Measure accuracy on the test set, without tracking gradients
            accuracy = 0
            with torch.no_grad():
                for ii, (images, labels) in enumerate(testloader):
                    images = images.view(images.shape[0], -1)
                    predicted = net.predict(images)
                    equality = (labels == predicted.max(1)[1])
                    accuracy += equality.float().mean().item()
            print("Epoch: {}/{}".format(e+1, epochs),
                  "Loss: {:.4f}".format(running_loss/print_every),
                  "Test accuracy: {:.4f}".format(accuracy/(ii+1)))
            running_loss = 0
Epoch: 1/5 Loss: 2.2696 Test accuracy: 0.2491
Epoch: 1/5 Loss: 2.1425 Test accuracy: 0.4081
Epoch: 1/5 Loss: 1.9469 Test accuracy: 0.4537
Epoch: 1/5 Loss: 1.7148 Test accuracy: 0.5521
Epoch: 1/5 Loss: 1.4649 Test accuracy: 0.6200
Epoch: 1/5 Loss: 1.2773 Test accuracy: 0.6135
Epoch: 1/5 Loss: 1.1315 Test accuracy: 0.6715
Epoch: 1/5 Loss: 1.0100 Test accuracy: 0.6924
Epoch: 1/5 Loss: 0.9619 Test accuracy: 0.7131
Epoch: 1/5 Loss: 0.8913 Test accuracy: 0.7143
Epoch: 1/5 Loss: 0.8496 Test accuracy: 0.7067
Epoch: 1/5 Loss: 0.8156 Test accuracy: 0.7199
Epoch: 1/5 Loss: 0.7611 Test accuracy: 0.7322
Epoch: 1/5 Loss: 0.7401 Test accuracy: 0.7388
Epoch: 1/5 Loss: 0.7371 Test accuracy: 0.7433
Epoch: 1/5 Loss: 0.7028 Test accuracy: 0.7266
Epoch: 1/5 Loss: 0.7179 Test accuracy: 0.7473
Epoch: 1/5 Loss: 0.6788 Test accuracy: 0.7420
Epoch: 1/5 Loss: 0.6824 Test accuracy: 0.7507
Epoch: 1/5 Loss: 0.6545 Test accuracy: 0.7603
Epoch: 1/5 Loss: 0.6479 Test accuracy: 0.7602
Epoch: 1/5 Loss: 0.6444 Test accuracy: 0.7579
Epoch: 1/5 Loss: 0.6262 Test accuracy: 0.7652
Epoch: 2/5 Loss: 0.6556 Test accuracy: 0.7607
Epoch: 2/5 Loss: 0.6199 Test accuracy: 0.7681
Epoch: 2/5 Loss: 0.5855 Test accuracy: 0.7698
Epoch: 2/5 Loss: 0.6072 Test accuracy: 0.7705
Epoch: 2/5 Loss: 0.5928 Test accuracy: 0.7747
Epoch: 2/5 Loss: 0.6003 Test accuracy: 0.7717
Epoch: 2/5 Loss: 0.5904 Test accuracy: 0.7787
Epoch: 2/5 Loss: 0.5780 Test accuracy: 0.7787
Epoch: 2/5 Loss: 0.5626 Test accuracy: 0.7831
Epoch: 2/5 Loss: 0.5845 Test accuracy: 0.7889
Epoch: 2/5 Loss: 0.5900 Test accuracy: 0.7892
Epoch: 2/5 Loss: 0.5319 Test accuracy: 0.7854
Epoch: 2/5 Loss: 0.5623 Test accuracy: 0.7902
Epoch: 2/5 Loss: 0.5561 Test accuracy: 0.7937
Epoch: 2/5 Loss: 0.5558 Test accuracy: 0.7892
Epoch: 2/5 Loss: 0.5508 Test accuracy: 0.7948
Epoch: 2/5 Loss: 0.5395 Test accuracy: 0.7956
Epoch: 2/5 Loss: 0.5332 Test accuracy: 0.7911
Epoch: 2/5 Loss: 0.5422 Test accuracy: 0.8009
Epoch: 2/5 Loss: 0.5310 Test accuracy: 0.8002
Epoch: 2/5 Loss: 0.5016 Test accuracy: 0.8025
Epoch: 2/5 Loss: 0.5284 Test accuracy: 0.8026
Epoch: 2/5 Loss: 0.5133 Test accuracy: 0.7999
Epoch: 3/5 Loss: 0.5076 Test accuracy: 0.8068
Epoch: 3/5 Loss: 0.4983 Test accuracy: 0.8050
Epoch: 3/5 Loss: 0.4966 Test accuracy: 0.8049
Epoch: 3/5 Loss: 0.4893 Test accuracy: 0.7986
Epoch: 3/5 Loss: 0.5068 Test accuracy: 0.8085
Epoch: 3/5 Loss: 0.5018 Test accuracy: 0.8138
Epoch: 3/5 Loss: 0.5217 Test accuracy: 0.8073
Epoch: 3/5 Loss: 0.5191 Test accuracy: 0.8148
Epoch: 3/5 Loss: 0.5063 Test accuracy: 0.8132
Epoch: 3/5 Loss: 0.4712 Test accuracy: 0.8124
Epoch: 3/5 Loss: 0.4820 Test accuracy: 0.8118
Epoch: 3/5 Loss: 0.4893 Test accuracy: 0.8147
Epoch: 3/5 Loss: 0.5146 Test accuracy: 0.8163
Epoch: 3/5 Loss: 0.5124 Test accuracy: 0.8161
Epoch: 3/5 Loss: 0.4974 Test accuracy: 0.8133
Epoch: 3/5 Loss: 0.5093 Test accuracy: 0.8194
Epoch: 3/5 Loss: 0.4760 Test accuracy: 0.8176
Epoch: 3/5 Loss: 0.4960 Test accuracy: 0.8195
Epoch: 3/5 Loss: 0.4649 Test accuracy: 0.8154
Epoch: 3/5 Loss: 0.4778 Test accuracy: 0.8169
Epoch: 3/5 Loss: 0.5091 Test accuracy: 0.8137
Epoch: 3/5 Loss: 0.4302 Test accuracy: 0.8176
Epoch: 3/5 Loss: 0.4675 Test accuracy: 0.8230
Epoch: 3/5 Loss: 0.4825 Test accuracy: 0.8195
Epoch: 4/5 Loss: 0.4620 Test accuracy: 0.8243
Epoch: 4/5 Loss: 0.4741 Test accuracy: 0.8249
Epoch: 4/5 Loss: 0.4340 Test accuracy: 0.8180
Epoch: 4/5 Loss: 0.4573 Test accuracy: 0.8262
Epoch: 4/5 Loss: 0.4699 Test accuracy: 0.8250
Epoch: 4/5 Loss: 0.4633 Test accuracy: 0.8204
Epoch: 4/5 Loss: 0.4966 Test accuracy: 0.8244
Epoch: 4/5 Loss: 0.4854 Test accuracy: 0.8264
Epoch: 4/5 Loss: 0.4844 Test accuracy: 0.8253
Epoch: 4/5 Loss: 0.4672 Test accuracy: 0.8208
Epoch: 4/5 Loss: 0.4508 Test accuracy: 0.8267
Epoch: 4/5 Loss: 0.4514 Test accuracy: 0.8280
Epoch: 4/5 Loss: 0.4458 Test accuracy: 0.8267
Epoch: 4/5 Loss: 0.4318 Test accuracy: 0.8271
Epoch: 4/5 Loss: 0.4639 Test accuracy: 0.8262
Epoch: 4/5 Loss: 0.4509 Test accuracy: 0.8305
Epoch: 4/5 Loss: 0.4320 Test accuracy: 0.8266
Epoch: 4/5 Loss: 0.4579 Test accuracy: 0.8284
Epoch: 4/5 Loss: 0.4521 Test accuracy: 0.8237
Epoch: 4/5 Loss: 0.4405 Test accuracy: 0.8318
Epoch: 4/5 Loss: 0.4559 Test accuracy: 0.8295
Epoch: 4/5 Loss: 0.4785 Test accuracy: 0.8279
Epoch: 4/5 Loss: 0.4291 Test accuracy: 0.8318
Epoch: 5/5 Loss: 0.4580 Test accuracy: 0.8288
Epoch: 5/5 Loss: 0.4441 Test accuracy: 0.8292
Epoch: 5/5 Loss: 0.4358 Test accuracy: 0.8358
Epoch: 5/5 Loss: 0.4435 Test accuracy: 0.8337
Epoch: 5/5 Loss: 0.4557 Test accuracy: 0.8332
Epoch: 5/5 Loss: 0.4531 Test accuracy: 0.8322
Epoch: 5/5 Loss: 0.4062 Test accuracy: 0.8346
Epoch: 5/5 Loss: 0.4480 Test accuracy: 0.8331
Epoch: 5/5 Loss: 0.4449 Test accuracy: 0.8346
Epoch: 5/5 Loss: 0.4486 Test accuracy: 0.8310
Epoch: 5/5 Loss: 0.4481 Test accuracy: 0.8369
Epoch: 5/5 Loss: 0.4624 Test accuracy: 0.8359
Epoch: 5/5 Loss: 0.4464 Test accuracy: 0.8340
Epoch: 5/5 Loss: 0.4372 Test accuracy: 0.8350
Epoch: 5/5 Loss: 0.4079 Test accuracy: 0.8349
Epoch: 5/5 Loss: 0.3984 Test accuracy: 0.8368
Epoch: 5/5 Loss: 0.4247 Test accuracy: 0.8350
Epoch: 5/5 Loss: 0.4390 Test accuracy: 0.8332
Epoch: 5/5 Loss: 0.4108 Test accuracy: 0.8367
Epoch: 5/5 Loss: 0.4279 Test accuracy: 0.8362
Epoch: 5/5 Loss: 0.4078 Test accuracy: 0.8381
Epoch: 5/5 Loss: 0.4241 Test accuracy: 0.8380
Epoch: 5/5 Loss: 0.4210 Test accuracy: 0.8366
Epoch: 5/5 Loss: 0.4117 Test accuracy: 0.8317
# Test out your network!
dataiter = iter(testloader)
images, labels = next(dataiter)
img = images[0]
# Convert the 2D image to a 1D vector
img = img.view(1, 784)

# TODO: Calculate the class probabilities (softmax) for img
with torch.no_grad():
    ps = net.predict(img)

# Plot the image and probabilities
helper.view_classify(img.view(1, 28, 28), ps)
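To turn the prediction into something readable, you can map the argmax of ps to the dataset's class names (the list below follows Fashion-MNIST's documented label order):
# Map the most likely class index to its Fashion-MNIST class name
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
print(classes[ps.argmax(dim=1).item()])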
Now that your network is trained, you'll want to save it so you can load it later instead of training it from scratch each time. Obviously it's impractical to retrain a network every time you need it. In practice, you'll train the network once, save the model, then reload it for further training or for making predictions. In the next part, I'll show you how to save and load trained models.
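As a preview, the core of it is just round-tripping the model's state_dict; a minimal sketch, with 'checkpoint.pth' as a hypothetical file name:
# Save the trained parameters, then restore them into a fresh network
torch.save(net.state_dict(), 'checkpoint.pth')   # hypothetical file name
net2 = Network()
net2.load_state_dict(torch.load('checkpoint.pth'))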
Those who act often succeed; those who journey often arrive.