import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
# Fetch the MNIST training split (downloaded into ./data on first run).
training_data = datasets.MNIST(
    root="data", train=True, download=True, transform=ToTensor()
)

# Fetch the MNIST test split the same way.
test_data = datasets.MNIST(
    root="data", train=False, download=True, transform=ToTensor()
)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
9913344/? [00:00<00:00, 37011789.51it/s]
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
29696/? [00:00<00:00, 486019.07it/s]
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
1649664/? [00:00<00:00, 6943416.95it/s]
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
5120/? [00:00<00:00, 214900.94it/s]
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw
batch_size = 128

# Wrap the datasets in loaders that yield mini-batches of `batch_size`.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Peek at one test batch to confirm the tensor shapes.
X, y = next(iter(test_dataloader))
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")
Shape of X [N, C, H, W]: torch.Size([128, 1, 28, 28]) Shape of y: torch.Size([128]) torch.int64
# Pick the compute backend: CUDA when a GPU is available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")
# Define model
class CNN(nn.Module):
    """Small convolutional classifier for 28x28 single-channel images.

    Two conv+ReLU+maxpool stages reduce a (N, 1, 28, 28) input to
    (N, 128, 5, 5); a fully-connected head then maps the flattened
    features to 10 raw class logits (no softmax — pair with
    nn.CrossEntropyLoss).
    """

    def __init__(self) -> None:
        super().__init__()  # modern zero-arg super() instead of super(CNN, self)
        # Feature extractor: (N,1,28,28) -> (N,64,13,13) -> (N,128,5,5)
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 3),   # in_channels, out_channels, kernel_size
            nn.ReLU(),
            nn.MaxPool2d(2, 2),    # kernel_size, stride
            nn.Conv2d(64, 128, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # Classifier head: flatten the 128*5*5 feature map down to 10 logits.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 5 * 5, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits of shape (N, 10) for a batch of images."""
        feature = self.conv(x)
        logits = self.fc(feature)
        return logits
# Instantiate the network, then move its parameters to the chosen device.
model = CNN()
model = model.to(device)
print(model)
Using cuda device CNN( (conv): Sequential( (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1)) (1): ReLU() (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1)) (4): ReLU() (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (fc): Sequential( (0): Flatten(start_dim=1, end_dim=-1) (1): Linear(in_features=3200, out_features=128, bias=True) (2): ReLU() (3): Linear(in_features=128, out_features=10, bias=True) ) )
!pip install torchsummary
from torchsummary import summary
Collecting torchsummary Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB) Installing collected packages: torchsummary Successfully installed torchsummary-1.5.1 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
summary(model,(1,28,28))
---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ Conv2d-1 [-1, 64, 26, 26] 640 ReLU-2 [-1, 64, 26, 26] 0 MaxPool2d-3 [-1, 64, 13, 13] 0 Conv2d-4 [-1, 128, 11, 11] 73,856 ReLU-5 [-1, 128, 11, 11] 0 MaxPool2d-6 [-1, 128, 5, 5] 0 Flatten-7 [-1, 3200] 0 Linear-8 [-1, 128] 409,728 ReLU-9 [-1, 128] 0 Linear-10 [-1, 10] 1,290 ================================================================ Total params: 485,514 Trainable params: 485,514 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.00 Forward/backward pass size (MB): 1.03 Params size (MB): 1.85 Estimated Total Size (MB): 2.88 ----------------------------------------------------------------
# Adam with the standard 1e-3 learning rate; cross-entropy on raw logits.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch; print per-100-batch progress and epoch metrics.

    Args:
        dataloader: yields (X, y) mini-batches of inputs and integer targets.
        model: the network being optimized; its parameters define the device.
        loss_fn: criterion mapping (logits, targets) to a scalar loss.
        optimizer: optimizer wrapping model.parameters().
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    # Derive the device from the model itself instead of relying on a
    # module-level global, so the function works wherever the model lives.
    device = next(model.parameters()).device
    model.train()
    train_loss, correct = 0.0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Reuse the already-computed loss value rather than evaluating
        # loss_fn(pred, y) a second time (the original recomputed it,
        # building a fresh graph node after backward for no benefit).
        loss_value = loss.item()
        train_loss += loss_value
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        if batch % 100 == 0:
            current = batch * len(X)
            print(f"loss: {loss_value:>7f} [{current:>5d}/{size:>5d}]")
    correct /= size
    train_loss /= num_batches
    print(f"train Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {train_loss:>8f} \n")
def val(dataloader, model, loss_fn):
    """Evaluate the model on a held-out set and print accuracy and avg loss.

    Args:
        dataloader: yields (X, y) mini-batches of inputs and integer targets.
        model: the trained network; its parameters define the device.
        loss_fn: criterion mapping (logits, targets) to a scalar loss.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    # Derive the device from the model instead of a module-level global,
    # matching how train() should locate it.
    device = next(model.parameters()).device
    model.eval()
    val_loss, correct = 0.0, 0
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            val_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    val_loss /= num_batches
    correct /= size
    print(f"val Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {val_loss:>8f} \n")
epochs = 10
# Alternate one training pass and one validation pass per epoch.
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    val(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1 ------------------------------- loss: 2.315103 [ 0/60000] loss: 0.152633 [12800/60000] loss: 0.107560 [25600/60000] loss: 0.076913 [38400/60000] loss: 0.128909 [51200/60000] train Error: Accuracy: 93.5%, Avg loss: 0.217212 val Error: Accuracy: 98.0%, Avg loss: 0.059248 Epoch 2 ------------------------------- loss: 0.071140 [ 0/60000] loss: 0.039843 [12800/60000] loss: 0.048955 [25600/60000] loss: 0.086715 [38400/60000] loss: 0.090539 [51200/60000] train Error: Accuracy: 98.4%, Avg loss: 0.053457 val Error: Accuracy: 98.4%, Avg loss: 0.045729 Epoch 3 ------------------------------- loss: 0.056194 [ 0/60000] loss: 0.029174 [12800/60000] loss: 0.032370 [25600/60000] loss: 0.068097 [38400/60000] loss: 0.062042 [51200/60000] train Error: Accuracy: 98.9%, Avg loss: 0.036438 val Error: Accuracy: 98.5%, Avg loss: 0.040602 Epoch 4 ------------------------------- loss: 0.046150 [ 0/60000] loss: 0.032029 [12800/60000] loss: 0.034310 [25600/60000] loss: 0.055553 [38400/60000] loss: 0.047306 [51200/60000] train Error: Accuracy: 99.1%, Avg loss: 0.028400 val Error: Accuracy: 98.6%, Avg loss: 0.041201 Epoch 5 ------------------------------- loss: 0.021812 [ 0/60000] loss: 0.035993 [12800/60000] loss: 0.037362 [25600/60000] loss: 0.030702 [38400/60000] loss: 0.045520 [51200/60000] train Error: Accuracy: 99.4%, Avg loss: 0.021017 val Error: Accuracy: 98.4%, Avg loss: 0.047126 Epoch 6 ------------------------------- loss: 0.017351 [ 0/60000] loss: 0.044284 [12800/60000] loss: 0.026712 [25600/60000] loss: 0.021054 [38400/60000] loss: 0.050305 [51200/60000] train Error: Accuracy: 99.5%, Avg loss: 0.016301 val Error: Accuracy: 98.6%, Avg loss: 0.042306 Epoch 7 ------------------------------- loss: 0.006217 [ 0/60000] loss: 0.046209 [12800/60000] loss: 0.029611 [25600/60000] loss: 0.012666 [38400/60000] loss: 0.016735 [51200/60000] train Error: Accuracy: 99.6%, Avg loss: 0.013194 val Error: Accuracy: 98.5%, Avg loss: 0.053701 Epoch 8 ------------------------------- loss: 
0.008757 [ 0/60000] loss: 0.008058 [12800/60000] loss: 0.021122 [25600/60000] loss: 0.008791 [38400/60000] loss: 0.039546 [51200/60000] train Error: Accuracy: 99.7%, Avg loss: 0.010438 val Error: Accuracy: 98.9%, Avg loss: 0.039033 Epoch 9 ------------------------------- loss: 0.007781 [ 0/60000] loss: 0.019456 [12800/60000] loss: 0.003024 [25600/60000] loss: 0.023416 [38400/60000] loss: 0.007457 [51200/60000] train Error: Accuracy: 99.7%, Avg loss: 0.009333 val Error: Accuracy: 98.9%, Avg loss: 0.040847 Epoch 10 ------------------------------- loss: 0.007596 [ 0/60000] loss: 0.004938 [12800/60000] loss: 0.007092 [25600/60000] loss: 0.026191 [38400/60000] loss: 0.066239 [51200/60000] train Error: Accuracy: 99.7%, Avg loss: 0.007795 val Error: Accuracy: 98.9%, Avg loss: 0.044136 Done!