mnist torch加载fashion_Python:PyTorch 分类 Fashion-MNIST 数据集 (七十八)

分类 Fashion-MNIST 数据集

分类 Fashion-MNIST 数据集

现在轮到你来构建一个神经网络了。你将使用的是 Fashion-MNIST 数据集,这是 MNIST 数据集的替代品。对于神经网络而言,原始的 MNIST 数据集体量太小,因而你可以轻易达到 97% 以上的准确率。而 Fashion-MNIST 数据集是一组有关衣物的 28x28 灰阶图像。这个数据集比 MNIST 复杂得多,因此你能更好地判断神经网络的性能,它也更加接近你在现实世界中使用的数据集。

在这个 notebook 中,你将构建专属于你的神经网络。在大多数情况下,你可以直接复制粘贴第三部分的代码,但这样一来你很难学到知识。因此我们推荐你自己编写代码来运行程序,这十分重要。不过在完成这个任务时,你也可以随时查阅和参考之前的 notebook。

首先,我们通过 torchvision 来加载数据集。

import torch

from torchvision import datasets, transforms

import helper

# Define a transform to normalize the data

transform = transforms.Compose([transforms.ToTensor(),

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Download and load the training data

trainset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Download and load the test data

testset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=False, transform=transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

print('trainloader-', trainloader)

print('trainloader-iter-', iter(trainloader))

print('trainloader-iter-', iter(trainloader))

trainloader-

trainloader-iter-

trainloader-iter-

在这里,我们能看到其中一张图片。

image, label = next(iter(trainloader))

print('image0=', image[0,:])

print('label=', label)

image0= tensor([[[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.9843,

-1.0000, -1.0000, -1.0000, -0.6784, -1.0000, -1.0000, -1.0000,

-0.9922, -0.9922, -0.9922, -0.9922, -0.9686, -0.9843, -0.9922],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -0.9922, -0.9843, -1.0000, -1.0000,

-1.0000, -0.3961, 0.1294, 1.0000, 0.1373, -1.0000, -1.0000,

-0.9765, -0.9922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -0.9608, -1.0000, -1.0000, -1.0000, -0.4431,

0.4431, 0.7569, 0.5686, 0.6706, 0.8275, -0.1608, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -0.4275, 0.3020, 0.7412,

0.6471, 0.6706, 0.7725, 0.6549, 0.7725, 1.0000, 0.6471,

-0.2784, -1.0000, -1.0000, -0.8902, 1.0000, 0.8745, -0.5451],

[-0.9843, -1.0000, -1.0000, -0.9765, -0.9529, -0.9843, -1.0000,

-1.0000, -1.0000, -0.0667, 0.5686, 0.6549, 0.8118, 0.4902,

0.8980, 0.7490, 0.7647, 0.8510, 0.7882, 0.7804, 0.8588,

1.0000, 0.9608, 0.6706, 1.0000, 0.7804, 0.8588, 0.2235],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-0.2941, 0.5137, 0.6078, 0.6706, 0.6078, 0.8510, 0.8118,

0.7804, 0.8196, 0.8510, 0.7725, 0.7569, 0.8588, 0.8196,

0.7882, 0.8353, 0.9529, 0.7569, 0.6235, 0.7804, 0.2235],

[-1.0000, -1.0000, -1.0000, -0.9765, -0.5059, -0.0353, 0.6235,

0.7020, 0.7412, 0.6784, 0.8431, 0.8353, 0.7569, 0.7725,

0.7804, 0.7804, 0.7882, 0.8196, 0.8588, 0.8196, 0.8431,

0.8588, 0.8118, 0.8431, 0.8118, 0.8431, 0.8353, 0.5529],

[-1.0000, -0.3882, 0.4196, 0.6157, 0.7176, 0.7176, 0.6627,

0.6392, 0.6863, 0.7255, 0.6471, 0.6549, 0.7255, 0.6941,

0.8196, 0.7333, 0.6627, 0.7647, 0.7333, 0.7412, 0.7725,

0.7725, 0.7882, 0.7569, 0.7020, 0.6392, 0.7725, 0.6157],

[-0.9451, 0.1529, 0.5216, 0.6235, 0.6784, 0.7333, 0.7490,

0.7490, 0.7569, 0.7490, 0.7882, 0.7882, 0.8275, 0.8353,

0.8588, 0.8980, 0.8824, 0.9686, 0.9765, 0.9922, 0.9922,

1.0000, 1.0000, 0.9216, 1.0000, 0.9686, 1.0000, 0.5373],

[-0.9294, -0.6078, -0.4745, -0.2314, 0.0431, 0.2941, 0.4980,

0.6078, 0.7725, 0.8118, 0.8039, 0.8118, 0.7804, 0.7804,

0.7647, 0.7176, 0.6706, 0.6314, 0.5373, 0.4745, 0.4353,

0.4275, 0.3725, 0.2000, 0.1529, 0.1059, 0.0353, -0.4353],

[-0.8196, -0.6314, -0.7882, -0.8275, -0.8039, -0.8196, -0.7490,

-0.6392, -0.3569, 0.0353, -0.0353, -0.0353, -0.0510, -0.1137,

-0.1529, -0.1608, -0.2000, -0.2157, -0.2078, -0.2392, -0.2471,

-0.2471, -0.2549, -0.2784, -0.2706, -0.2863, -0.3882, -0.5922],

[-1.0000, -0.8745, -0.6941, -0.6627, -0.7098, -0.7725, -0.7882,

-0.9216, -0.8196, -0.4588, -0.5137, -0.5294, -0.5529, -0.6314,

-0.6784, -0.6706, -0.6863, -0.6863, -0.6627, -0.6941, -0.7020,

-0.6941, -0.7020, -0.7490, -0.7961, -0.7961, -0.7647, -0.8745],

[-1.0000, -1.0000, -1.0000, -1.0000, -0.9059, -0.9216, -0.8039,

-0.7098, -0.7647, -0.8824, -0.8588, -0.8588, -0.9059, -0.8196,

-0.7725, -0.8196, -0.8588, -0.8745, -0.9137, -0.9216, -0.9216,

-0.9059, -0.9137, -0.9451, -0.9294, -0.9529, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],

[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,

-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000]]])

label= tensor([ 7, 9, 9, 3, 6, 4, 8, 8, 7, 5, 4, 7, 5, 0,

7, 7, 2, 6, 8, 0, 2, 0, 2, 9, 7, 4, 3, 2,

8, 6, 9, 5, 5, 8, 3, 2, 6, 1, 0, 1, 9, 3,

6, 6, 3, 2, 9, 7, 3, 6, 9, 5, 2, 5, 6, 2,

5, 4, 7, 5, 8, 2, 4, 8])

# show images

helper.imshow(image[0,:]);

在加载数据之后,我们应该导入一些必要的包了。

%matplotlib inline

%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt

import numpy as np

import time

import torch

from torch import nn

from torch import optim

import torch.nn.functional as F

from torch.autograd import Variable

from torchvision import datasets, transforms

import helper

构建网络

在这里,你应该定义你的网络。如同 MNIST 数据集一样,这里的每张图片的像素为 28x28,共有 784 个像素点和 10 个类。你至少需要添加一个隐藏层。对于这些层,我们推荐你使用 ReLU 激活函数,并通过前向传播来返回 logits。层的数量和大小都由你来决定。

# TODO: Define your network architecture here

class Network(nn.Module):

def __init__(self):

# 初始化父类

super().__init__()

# Defining the layers, 200, 50, 10 units each

self.fc1 = nn.Linear(784, 200)

self.fc2 = nn.Linear(200, 50)

# Output layer, 10 units - one for each digit

self.fc3 = nn.Linear(50, 10)

def forward(self, x):

''' Forward pass through the network, returns the output logits '''

x = self.fc1(x)

x = F.relu(x)

x = self.fc2(x)

x = F.relu(x)

x = self.fc3(x)

return x

def predict(self, x):

''' This function for predicts classes by calculating the softmax '''

logits = self.forward(x)

return F.softmax(logits, dim=1)

训练网络

现在,你应该构建你的网络并训练它了。首先,你需要定义条件(比如 nn.CrossEntropyLoss)以及优化器(比如 optim.SGD 或 optim.Adam)。

接着,你需要编写训练代码。请记住,训练传播是一个十分简明的过程:

在网络中进行前向传播来获取 logits

使用 logits 来计算损失

使用 loss.backward() 在网络中进行后向传播来计算梯度

使用优化器执行一个学习步来更新权重

通过调整超参数(隐藏单元、学习速率等),你应该可以将训练损失控制在 0.4 以下。

# TODO: Create the network, define the criterion and optimizer

net = Network()

criterion = nn.CrossEntropyLoss()

# lr = learning rate

optimizer = optim.SGD(net.parameters(), lr=0.01)

# TODO: Train the network here

print('Initial weights - ', net.fc1.weight)

print('print trainloader- ', trainloader)

dataiter = iter(trainloader)

images, labels = dataiter.next()

images.resize_(64, 784)

# Create Variables for the inputs and targets

inputs = Variable(images)

targets = Variable(labels)

# Clear the gradients from all Variables

optimizer.zero_grad()

# Forward pass, then backward pass, then update weights

output = net.forward(inputs)

loss = criterion(output, targets)

loss.backward()

print('Gradient -', net.fc1.weight.grad)

optimizer.step()

Initial weights - Parameter containing:

tensor([[-6.5417e-03, -2.4802e-02, -1.5146e-02, ..., -1.2971e-02,

-1.6070e-02, 2.6286e-02],

[ 5.7576e-03, -1.0399e-02, 4.4446e-03, ..., -2.5082e-03,

-4.3815e-03, 1.6680e-02],

[-3.4878e-02, -2.0759e-02, 1.6003e-02, ..., 8.6531e-04,

-1.9558e-02, 2.1282e-02],

...,

[ 2.6933e-02, 2.3026e-02, -3.3443e-02, ..., -2.4827e-02,

-3.5710e-02, -6.9900e-03],

[ 3.5409e-03, -3.0244e-02, 9.5727e-03, ..., -1.0316e-02,

-1.9417e-02, 1.2862e-04],

[-3.3738e-03, -3.0613e-02, 1.6543e-02, ..., 1.7032e-02,

2.3136e-02, -1.5136e-02]])

print trainloader-

Gradient - tensor([[ 1.9819e-03, 1.9818e-03, 1.9818e-03, ..., 1.9833e-03,

1.9834e-03, 1.9818e-03],

[-1.0067e-03, -1.0061e-03, -1.0094e-03, ..., -1.0024e-03,

-1.0067e-03, -1.0061e-03],

[-1.8131e-03, -1.8103e-03, -1.8119e-03, ..., -1.7803e-03,

-1.8102e-03, -1.8103e-03],

...,

[ 5.4580e-04, 5.4580e-04, 5.4289e-04, ..., 5.5021e-04,

5.4945e-04, 5.4580e-04],

[ 1.3846e-04, 1.3887e-04, 1.3865e-04, ..., 1.3844e-04,

1.3593e-04, 1.3887e-04],

[-7.1335e-04, -7.1335e-04, -7.1335e-04, ..., -7.0985e-04,

-7.1335e-04, -7.1335e-04]])

# 实际训练

epochs = 5

steps = 0

running_loss = 0

print_every = 40

for e in range(epochs):

for images,labels in iter(trainloader):

steps += 1

# Flatten MNIST images into a 784 long vector

images.resize_(images.size()[0], 784)

# Wrap images and labels in Variables so we can calculate gradients

inputs = Variable(images)

targets = Variable(labels)

optimizer.zero_grad()

output = net.forward(inputs)

loss = criterion(output, targets)

loss.backward()

optimizer.step()

running_loss += loss.data[0]

if steps % print_every == 0:

# Test accuracy

accuracy = 0

for ii, (images, labels) in enumerate(testloader):

images = images.resize_(images.size()[0], 784)

inputs = Variable(images, volatile=True)

predicted = net.predict(inputs).data

equality = (labels == predicted.max(1)[1])

accuracy += equality.type_as(torch.FloatTensor()).mean()

print("Epoch: {}/{}".format(e+1, epochs),

"Loss: {:.4f}".format(running_loss/print_every),

"Test accuracy: {:.4f}".format(accuracy/(ii+1)))

running_loss = 0

/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py:23: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number

/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py:31: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.

Epoch: 1/5 Loss: 2.2696 Test accuracy: 0.2491

Epoch: 1/5 Loss: 2.1425 Test accuracy: 0.4081

Epoch: 1/5 Loss: 1.9469 Test accuracy: 0.4537

Epoch: 1/5 Loss: 1.7148 Test accuracy: 0.5521

Epoch: 1/5 Loss: 1.4649 Test accuracy: 0.6200

Epoch: 1/5 Loss: 1.2773 Test accuracy: 0.6135

Epoch: 1/5 Loss: 1.1315 Test accuracy: 0.6715

Epoch: 1/5 Loss: 1.0100 Test accuracy: 0.6924

Epoch: 1/5 Loss: 0.9619 Test accuracy: 0.7131

Epoch: 1/5 Loss: 0.8913 Test accuracy: 0.7143

Epoch: 1/5 Loss: 0.8496 Test accuracy: 0.7067

Epoch: 1/5 Loss: 0.8156 Test accuracy: 0.7199

Epoch: 1/5 Loss: 0.7611 Test accuracy: 0.7322

Epoch: 1/5 Loss: 0.7401 Test accuracy: 0.7388

Epoch: 1/5 Loss: 0.7371 Test accuracy: 0.7433

Epoch: 1/5 Loss: 0.7028 Test accuracy: 0.7266

Epoch: 1/5 Loss: 0.7179 Test accuracy: 0.7473

Epoch: 1/5 Loss: 0.6788 Test accuracy: 0.7420

Epoch: 1/5 Loss: 0.6824 Test accuracy: 0.7507

Epoch: 1/5 Loss: 0.6545 Test accuracy: 0.7603

Epoch: 1/5 Loss: 0.6479 Test accuracy: 0.7602

Epoch: 1/5 Loss: 0.6444 Test accuracy: 0.7579

Epoch: 1/5 Loss: 0.6262 Test accuracy: 0.7652

Epoch: 2/5 Loss: 0.6556 Test accuracy: 0.7607

Epoch: 2/5 Loss: 0.6199 Test accuracy: 0.7681

Epoch: 2/5 Loss: 0.5855 Test accuracy: 0.7698

Epoch: 2/5 Loss: 0.6072 Test accuracy: 0.7705

Epoch: 2/5 Loss: 0.5928 Test accuracy: 0.7747

Epoch: 2/5 Loss: 0.6003 Test accuracy: 0.7717

Epoch: 2/5 Loss: 0.5904 Test accuracy: 0.7787

Epoch: 2/5 Loss: 0.5780 Test accuracy: 0.7787

Epoch: 2/5 Loss: 0.5626 Test accuracy: 0.7831

Epoch: 2/5 Loss: 0.5845 Test accuracy: 0.7889

Epoch: 2/5 Loss: 0.5900 Test accuracy: 0.7892

Epoch: 2/5 Loss: 0.5319 Test accuracy: 0.7854

Epoch: 2/5 Loss: 0.5623 Test accuracy: 0.7902

Epoch: 2/5 Loss: 0.5561 Test accuracy: 0.7937

Epoch: 2/5 Loss: 0.5558 Test accuracy: 0.7892

Epoch: 2/5 Loss: 0.5508 Test accuracy: 0.7948

Epoch: 2/5 Loss: 0.5395 Test accuracy: 0.7956

Epoch: 2/5 Loss: 0.5332 Test accuracy: 0.7911

Epoch: 2/5 Loss: 0.5422 Test accuracy: 0.8009

Epoch: 2/5 Loss: 0.5310 Test accuracy: 0.8002

Epoch: 2/5 Loss: 0.5016 Test accuracy: 0.8025

Epoch: 2/5 Loss: 0.5284 Test accuracy: 0.8026

Epoch: 2/5 Loss: 0.5133 Test accuracy: 0.7999

Epoch: 3/5 Loss: 0.5076 Test accuracy: 0.8068

Epoch: 3/5 Loss: 0.4983 Test accuracy: 0.8050

Epoch: 3/5 Loss: 0.4966 Test accuracy: 0.8049

Epoch: 3/5 Loss: 0.4893 Test accuracy: 0.7986

Epoch: 3/5 Loss: 0.5068 Test accuracy: 0.8085

Epoch: 3/5 Loss: 0.5018 Test accuracy: 0.8138

Epoch: 3/5 Loss: 0.5217 Test accuracy: 0.8073

Epoch: 3/5 Loss: 0.5191 Test accuracy: 0.8148

Epoch: 3/5 Loss: 0.5063 Test accuracy: 0.8132

Epoch: 3/5 Loss: 0.4712 Test accuracy: 0.8124

Epoch: 3/5 Loss: 0.4820 Test accuracy: 0.8118

Epoch: 3/5 Loss: 0.4893 Test accuracy: 0.8147

Epoch: 3/5 Loss: 0.5146 Test accuracy: 0.8163

Epoch: 3/5 Loss: 0.5124 Test accuracy: 0.8161

Epoch: 3/5 Loss: 0.4974 Test accuracy: 0.8133

Epoch: 3/5 Loss: 0.5093 Test accuracy: 0.8194

Epoch: 3/5 Loss: 0.4760 Test accuracy: 0.8176

Epoch: 3/5 Loss: 0.4960 Test accuracy: 0.8195

Epoch: 3/5 Loss: 0.4649 Test accuracy: 0.8154

Epoch: 3/5 Loss: 0.4778 Test accuracy: 0.8169

Epoch: 3/5 Loss: 0.5091 Test accuracy: 0.8137

Epoch: 3/5 Loss: 0.4302 Test accuracy: 0.8176

Epoch: 3/5 Loss: 0.4675 Test accuracy: 0.8230

Epoch: 3/5 Loss: 0.4825 Test accuracy: 0.8195

Epoch: 4/5 Loss: 0.4620 Test accuracy: 0.8243

Epoch: 4/5 Loss: 0.4741 Test accuracy: 0.8249

Epoch: 4/5 Loss: 0.4340 Test accuracy: 0.8180

Epoch: 4/5 Loss: 0.4573 Test accuracy: 0.8262

Epoch: 4/5 Loss: 0.4699 Test accuracy: 0.8250

Epoch: 4/5 Loss: 0.4633 Test accuracy: 0.8204

Epoch: 4/5 Loss: 0.4966 Test accuracy: 0.8244

Epoch: 4/5 Loss: 0.4854 Test accuracy: 0.8264

Epoch: 4/5 Loss: 0.4844 Test accuracy: 0.8253

Epoch: 4/5 Loss: 0.4672 Test accuracy: 0.8208

Epoch: 4/5 Loss: 0.4508 Test accuracy: 0.8267

Epoch: 4/5 Loss: 0.4514 Test accuracy: 0.8280

Epoch: 4/5 Loss: 0.4458 Test accuracy: 0.8267

Epoch: 4/5 Loss: 0.4318 Test accuracy: 0.8271

Epoch: 4/5 Loss: 0.4639 Test accuracy: 0.8262

Epoch: 4/5 Loss: 0.4509 Test accuracy: 0.8305

Epoch: 4/5 Loss: 0.4320 Test accuracy: 0.8266

Epoch: 4/5 Loss: 0.4579 Test accuracy: 0.8284

Epoch: 4/5 Loss: 0.4521 Test accuracy: 0.8237

Epoch: 4/5 Loss: 0.4405 Test accuracy: 0.8318

Epoch: 4/5 Loss: 0.4559 Test accuracy: 0.8295

Epoch: 4/5 Loss: 0.4785 Test accuracy: 0.8279

Epoch: 4/5 Loss: 0.4291 Test accuracy: 0.8318

Epoch: 5/5 Loss: 0.4580 Test accuracy: 0.8288

Epoch: 5/5 Loss: 0.4441 Test accuracy: 0.8292

Epoch: 5/5 Loss: 0.4358 Test accuracy: 0.8358

Epoch: 5/5 Loss: 0.4435 Test accuracy: 0.8337

Epoch: 5/5 Loss: 0.4557 Test accuracy: 0.8332

Epoch: 5/5 Loss: 0.4531 Test accuracy: 0.8322

Epoch: 5/5 Loss: 0.4062 Test accuracy: 0.8346

Epoch: 5/5 Loss: 0.4480 Test accuracy: 0.8331

Epoch: 5/5 Loss: 0.4449 Test accuracy: 0.8346

Epoch: 5/5 Loss: 0.4486 Test accuracy: 0.8310

Epoch: 5/5 Loss: 0.4481 Test accuracy: 0.8369

Epoch: 5/5 Loss: 0.4624 Test accuracy: 0.8359

Epoch: 5/5 Loss: 0.4464 Test accuracy: 0.8340

Epoch: 5/5 Loss: 0.4372 Test accuracy: 0.8350

Epoch: 5/5 Loss: 0.4079 Test accuracy: 0.8349

Epoch: 5/5 Loss: 0.3984 Test accuracy: 0.8368

Epoch: 5/5 Loss: 0.4247 Test accuracy: 0.8350

Epoch: 5/5 Loss: 0.4390 Test accuracy: 0.8332

Epoch: 5/5 Loss: 0.4108 Test accuracy: 0.8367

Epoch: 5/5 Loss: 0.4279 Test accuracy: 0.8362

Epoch: 5/5 Loss: 0.4078 Test accuracy: 0.8381

Epoch: 5/5 Loss: 0.4241 Test accuracy: 0.8380

Epoch: 5/5 Loss: 0.4210 Test accuracy: 0.8366

Epoch: 5/5 Loss: 0.4117 Test accuracy: 0.8317

# Test out your network!

dataiter = iter(testloader)

images, labels = dataiter.next()

img = images[0]

# Convert 2D image to 1D vector

img = img.resize_(1, 784)

# TODO: Calculate the class probabilities (softmax) for img

ps = net.predict(Variable(img.resize_(1, 784)))

# Plot the image and probabilities

helper.view_classify(img.resize_(1, 28, 28), ps)

训练好神经网络之后,你应该希望保存这个网络以便下次加载,而不是重新训练。很明显,每次使用时都重新训练网络并不现实。在实际操作中,你将会在训练网络之后将模型保存,接着重新加载网络以进行训练或是预测。在下一部分,我将为你展示如何保存和加载训练好的模型。

为者常成,行者常至

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值