#!/usr/bin/env python
# coding: utf-8
# In[1]:
import torch
import torchvision
import matplotlib.pyplot as plt
# In[2]:
# Hyperparameters shared by the cells below.
batchSize = 64       # samples per mini-batch (training and validation)
learningRate = 0.1   # SGD step size
epochNum = 10        # number of passes over the training set

# MNIST train/test splits stored under ./data (download=True fetches them if
# missing); each sample is converted to a tensor by ToTensor.
toTensor = torchvision.transforms.ToTensor()
trainDataset, valDataset = (
    torchvision.datasets.MNIST('./data', train=isTrain, transform=toTensor, download=True)
    for isTrain in (True, False)
)

# Both loaders shuffle and drop the trailing partial batch, so every batch
# they yield holds exactly batchSize samples.
trainData, valData = (
    torch.utils.data.DataLoader(dataset, batch_size=batchSize, shuffle=True, drop_last=True)
    for dataset in (trainDataset, valDataset)
)
# In[3]:
def _mlpLayers(widths):
    """Yield a Linear layer for each consecutive pair in *widths*, with a ReLU
    between consecutive Linear layers and none after the last one."""
    pairs = list(zip(widths[:-1], widths[1:]))
    for position, (fanIn, fanOut) in enumerate(pairs):
        yield torch.nn.Linear(fanIn, fanOut)
        if position < len(pairs) - 1:
            yield torch.nn.ReLU()


# Fully connected classifier: 784 input features -> 1024 -> 300 -> 10 outputs.
# There is deliberately no softmax at the end: torch.nn.CrossEntropyLoss (the
# criterion used with this net) expects raw, unnormalized logits.
net = torch.nn.Sequential(*_mlpLayers((784, 1024, 300, 10)))
# In[4]:
# Plain SGD over every parameter of the network, stepping by learningRate.
optimizer = torch.optim.SGD(params=net.parameters(), lr=learningRate)
# Classification loss computed on the network's raw output scores.
criterion = torch.nn.CrossEntropyLoss()
# In[5]:
def _accuracy(logits, labels):
    """Return the percentage (0-100) of rows whose highest-scoring class equals the label."""
    _, predictions = logits.max(1)
    return (predictions == labels).sum().item() / labels.shape[0] * 100


# Per-step history: training loss/accuracy of each SGD step, and loss/accuracy
# on one validation batch evaluated right after that step. The train and
# validation curves therefore share the same x-axis (training step).
losses = []
acces = []
valLosses = []
valAcces = []
for epoch in range(epochNum):
    print(epoch)
    for img, lbl in trainData:
        # --- one SGD step on a training batch ---
        net.train()
        # Flatten each image to one feature row (the first layer is
        # Linear(784, ...)); use the real batch dimension instead of the
        # batchSize constant so a short batch cannot break the reshape.
        img = img.reshape(img.shape[0], -1)
        out = net(img)
        loss = criterion(out, lbl)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        acces.append(_accuracy(out, lbl))

        # --- evaluate on one validation batch ---
        net.eval()
        # Evaluation needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            # A fresh iterator over the shuffled validation loader yields a
            # random batch; it is rebuilt (and reshuffled) on every step.
            valImg, valLbl = next(iter(valData))
            valImg = valImg.reshape(valImg.shape[0], -1)
            valOut = net(valImg)
            valLosses.append(criterion(valOut, valLbl).item())
            valAcces.append(_accuracy(valOut, valLbl))

# Accuracy curves (percent) per training step.
plt.plot(acces, label="train")
plt.plot(valAcces, label="validation")
plt.xlabel("training step")
plt.ylabel("accuracy (%)")
plt.legend()
plt.show()

# Loss curves per training step.
plt.plot(losses, label="train")
plt.plot(valLosses, label="validation")
plt.xlabel("training step")
plt.ylabel("cross-entropy loss")
plt.legend()
plt.show()
# In[ ]:
MNIST Handwritten Digit Recognition in PyTorch (Neural Network)
最新推荐文章于 2024-08-28 01:18:44 发布