模型结构及各层Tensor的shape
新版torch的代码实现及训练准确度
"""
author: zliu.elliot
@time: 2021-08-26 17:
@file: mnistWithCNN.py
"""
import torch.optim
from torchvision.transforms import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as functional
def fit(epoch, model, dataloader, criterion, optimizer, phase="train"):
    """Run one pass over ``dataloader`` and return the mean loss and accuracy.

    With ``phase == "train"`` the model is put in training mode and its
    parameters are updated with ``optimizer`` after every batch. Any other
    phase value (e.g. ``"validation"``) only evaluates: the model is put in
    eval mode, gradient tracking is disabled and ``optimizer`` is not used.

    Args:
        epoch: Epoch index; only used in the printed summary line.
        model: Module producing one row of scores per sample. The reported
            loss is computed with ``nll_loss``, so this assumes the output is
            log-probabilities (e.g. a ``LogSoftmax`` head).
        dataloader: Yields ``(data, target)`` batches; ``dataloader.dataset``
            must support ``len()``.
        criterion: Loss that is backpropagated during the training phase.
        optimizer: Optimizer stepped during the training phase.
        phase: ``"train"`` to update weights, anything else to only evaluate.

    Returns:
        ``(loss, accuracy)``: the summed NLL divided by the number of samples,
        and the fraction (0..1) of correctly predicted samples.

    Raises:
        ValueError: If ``dataloader.dataset`` is empty.
    """
    is_train = phase == "train"
    num_samples = len(dataloader.dataset)
    if num_samples == 0:
        raise ValueError(f"cannot run phase {phase!r} on an empty dataset")

    model.train(is_train)  # same as model.eval() when is_train is False
    run_loss = 0.0
    run_correct = 0
    # Evaluation needs no autograd graph; building one would only waste memory.
    with torch.set_grad_enabled(is_train):
        for data, target in dataloader:
            output = model(data)

            # Statistics are taken from a detached view so they never extend
            # the autograd graph. The loss is summed (not averaged) so the
            # epoch value is an exact per-sample mean even if the last batch
            # is smaller.
            detached = output.detach()
            run_loss += functional.nll_loss(detached, target, reduction="sum").item()
            run_correct += detached.argmax(dim=1).eq(target).sum().item()

            if is_train:
                optimizer.zero_grad()
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

    epoch_loss = run_loss / num_samples
    epoch_acc = float(run_correct) / num_samples
    # epoch_acc is a fraction; the `%` format spec scales it by 100 for display.
    print(f"[{epoch}]{phase} loss is {epoch_loss:.3f}, accuracy is {epoch_acc:.3%}")
    return epoch_loss, epoch_acc
DATA_ROOT = "./mnistDataset/"
BATCH_SIZE = 32

# ToTensor converts each image to a float tensor; Normalize then standardizes
# the single channel with mean 0.1307 / std 0.3081 (the figures commonly
# quoted for the MNIST training set).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = datasets.MNIST(root=DATA_ROOT, train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root=DATA_ROOT, train=False, transform=transform, download=True)

# Shuffle only the training data. Evaluation accumulates order-independent
# totals, so shuffling the test set would just waste work and draw from the
# global RNG.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# Small CNN for single-channel 28x28 images. Shape comments are (N, C, H, W)
# for a 28x28 input; the 320 input features of the first Linear layer
# (20 channels * 4 * 4) only line up for that input size.
net = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5),   # -> (N, 10, 24, 24)
    nn.MaxPool2d(kernel_size=2),                                # -> (N, 10, 12, 12)
    nn.ReLU(),
    nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5),  # -> (N, 20, 8, 8)
    nn.Dropout2d(),                                             # drops whole channels in train mode
    nn.MaxPool2d(kernel_size=2),                                # -> (N, 20, 4, 4)
    nn.ReLU(),
    nn.Flatten(),                                               # -> (N, 320)
    nn.Linear(320, 50),                                         # -> (N, 50)
    nn.ReLU(),
    nn.Dropout(),                                               # active only in train mode
    nn.Linear(50, 10),                                          # -> (N, 10) class scores
    # dim is explicit: the implicit-dim form is deprecated and emits a warning.
    nn.LogSoftmax(dim=1),                                       # -> (N, 10) log-probabilities
)
# NLLLoss expects log-probabilities, matching net's LogSoftmax output layer.
loss_fn = nn.NLLLoss()
optim = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)
NUM_EPOCHS = 20

# Per-epoch history returned by fit(): mean per-sample loss and accuracy
# as a fraction of the dataset.
train_loss, train_acc = [], []
val_loss, val_acc = [], []
for epoch in range(NUM_EPOCHS):
    epoch_loss, epoch_acc = fit(epoch, net, train_dataloader, loss_fn, optim)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)

    # fit() only steps the optimizer when phase == "train", so this pass
    # leaves the weights unchanged.
    epoch_loss, epoch_acc = fit(epoch, net, test_dataloader, loss_fn, optim, phase="validation")
    val_loss.append(epoch_loss)
    val_acc.append(epoch_acc)

print(train_loss)
print(train_acc)
print(val_loss)
print(val_acc)