# Import necessary packages.
# BUG FIX: `import torch` was fused onto the end of the comment line and
# therefore never executed; every later `torch.*` reference would raise
# NameError. It now lives on its own line.
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Hyper-parameters for the logistic-regression experiment.
input_size = 28 * 28     # each MNIST image is 28x28 pixels, flattened to 784
num_classes = 10         # ten digit classes, 0-9
num_epochs = 10          # full passes over the training set
batch_size = 100         # mini-batch size for both loaders
learning_rate = 0.001    # SGD step size
# Load the MNIST dataset (images and labels).
# The training split is downloaded on first use; the test split reuses the
# same on-disk root. Both convert PIL images to tensors.
train_dataset = torchvision.datasets.MNIST(
    root='../../data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='../../data',
    train=False,
    transform=transforms.ToTensor(),
)

# Data loaders (input pipeline): shuffle training batches each epoch,
# keep the test set in a fixed order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)
# Logistic regression is just a single linear layer: the model outputs raw
# logits, and nn.CrossEntropyLoss applies softmax internally, so no
# activation is needed here.
model = nn.Linear(input_size, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# Train the model.
# BUG FIX: the original had both `for` loops, the backward pass, and the
# progress `print` fused onto single physical lines (no newlines between
# statements), which is a SyntaxError. Restored to properly indented Python;
# the logic and the printed format string are unchanged.
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Flatten each 28x28 image to a 784-dim vector for the linear layer.
        images = images.reshape(-1, input_size)

        # Forward pass.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress every 100 mini-batches.
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss:{:.4f}'.format(
                epoch + 1, num_epochs, i + 1, total_step, loss.item()))
# Test the model.
# In the test phase we don't need to compute gradients (for memory efficiency).
# BUG FIX: the original fused `with torch.no_grad():` into a comment line,
# collapsed several statements onto single lines (SyntaxError), and contained
# a bare line of pasted program output; the output line is preserved below as
# a comment so the file parses.
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, input_size)
        outputs = model(images)
        # The predicted class is the index of the largest logit per row.
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
    print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
    # Expected output from a previous run:
    # Accuracy of the model on the 10000 test images: 85.55999755859375 %
# Save the model checkpoint (parameters only, not the whole module).
torch.save(model.state_dict(), 'model_param.ckpt')
# torch.save(model, 'model.ckpt')  # alternative: pickle the entire module

# Load the model checkpoint.
# BUG FIX: nn.Module.load_state_dict() returns a NamedTuple of
# missing/unexpected keys, NOT the model, so the original
# `model = model.load_state_dict(...)` clobbered `model` with that tuple.
# The parameters are loaded in place; no reassignment is needed.
model.load_state_dict(torch.load('model_param.ckpt'))