(七)卷积神经网络实战
本次根据 LeNet-5 网络结构,使用 CIFAR-10 数据集进行卷积神经网络实战。
LeNet-5 的网络结构如下图所示(若图未显示,可参考原论文中的结构图),我们将根据该结构来构建卷积神经网络。
以下是源代码(可以直接运行)。
lenet5.py:
import torch
from torch import nn
class Lenet5(nn.Module):
    """LeNet-5 style convolutional network for the CIFAR-10 dataset.

    Maps a batch of RGB images of shape [b, 3, 32, 32] to raw class
    scores (logits) of shape [b, 10].
    """

    def __init__(self):
        super(Lenet5, self).__init__()
        # Feature extractor: two conv + average-pool stages.
        # [b, 3, 32, 32] -> conv5x5 -> [b, 6, 28, 28] -> pool -> [b, 6, 14, 14]
        #                -> conv5x5 -> [b, 16, 10, 10] -> pool -> [b, 16, 5, 5]
        self.conv_unit = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        )
        # Classifier head over the flattened 16*5*5 = 400 feature vector.
        self.fc_unit = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        """Run the forward pass.

        :param x: input images, shape [b, 3, 32, 32]
        :return: logits, shape [b, 10]
        """
        batchsz = x.size(0)
        features = self.conv_unit(x)                 # [b, 16, 5, 5]
        flat = features.view(batchsz, 16 * 5 * 5)    # [b, 400]
        logits = self.fc_unit(flat)                  # [b, 10]
        return logits
def main():
    """Smoke test: instantiate the network to check it constructs cleanly."""
    net = Lenet5()


if __name__ == '__main__':
    main()
main.py:
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision import datasets
from torchvision import transforms
from lenet5 import Lenet5
def main():
    """Train LeNet-5 on CIFAR-10 and report test accuracy after each epoch.

    Downloads CIFAR-10 into ./cifar on first run. Trains with Adam and
    cross-entropy loss; prints the model, the last-batch loss per epoch,
    and the full test-set accuracy per epoch.
    """
    batchsz = 32

    cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    # No need to shuffle the test set: accuracy is order-independent.
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=False)

    # Fall back to CPU when no GPU is present; the original hard-coded
    # 'cuda' and crashed on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Lenet5().to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)

            logits = model(x)              # [b, 10]
            # CrossEntropyLoss takes raw logits and integer class labels.
            loss = criteon(logits, label)  # scalar tensor

            # Standard backprop step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Loss of the final batch only -- a rough progress indicator.
        print(epoch, loss.item())

        model.eval()
        with torch.no_grad():
            # Evaluate accuracy over the entire test set.
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                logits = model(x)             # [b, 10]
                pred = logits.argmax(dim=1)   # [b]
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch, acc)


if __name__ == '__main__':
    main()