# 定义 LeNet-5 神经网络
import torch
from torch import nn
from torch.nn import functional as F
class Lenet5(nn.Module):
    """LeNet-5-style classifier for CIFAR-10 sized images.

    Input batches of shape [b, 3, 32, 32] pass through two conv/avg-pool
    stages that produce a [b, 16, 5, 5] feature map. That map is flattened
    and fed through three fully connected layers, which output
    [b, num_classes] logits.

    Args:
        num_classes: width of the output logit vector. Defaults to 10,
            the number of CIFAR-10 classes.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # NOTE(review): no activation sits between the conv/pool layers, so
        # this stage is a purely linear map. Inserting activations would
        # shift the Sequential indices (and saved state_dict keys), so the
        # layout is deliberately left unchanged.
        self.conv_unit = nn.Sequential(
            # [b, 3, 32, 32] => [b, 6, 28, 28]
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
            # [b, 6, 28, 28] => [b, 6, 14, 14]
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            # [b, 6, 14, 14] => [b, 16, 10, 10]
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            # [b, 16, 10, 10] => [b, 16, 5, 5]
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        )
        # Fully connected head. After the conv/pool stages each sample is a
        # 16x5x5 map, which is flattened to 16*5*5 features before the
        # first Linear layer; that layer outputs 120 features.
        self.fc_unit = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )
        # Loss kept as an attribute for callers; forward() does not use it.
        # nn.MSELoss() would also work, but cross-entropy suits
        # classification better.
        self.criteon = nn.CrossEntropyLoss()

    def forward(self, x):
        """Compute class logits for a batch of images.

        :param x: input batch of shape [b, 3, 32, 32]
        :return: raw (un-softmaxed) logits of shape [b, num_classes]
        """
        batchsz = x.size(0)
        # [b, 3, 32, 32] => [b, 16, 5, 5]
        x = self.conv_unit(x)
        # [b, 16, 5, 5] => [b, 16*5*5]. The explicit size (rather than -1)
        # makes an unexpected input resolution fail right here.
        x = x.view(batchsz, 16 * 5 * 5)
        # [b, 16*5*5] => [b, num_classes]
        logits = self.fc_unit(x)
        return logits
def main():
    """Smoke-test Lenet5 on a random batch and print the logits' shape.

    A fake batch of two 3x32x32 images is pushed through a freshly built
    model; the printed shape shows the forward pass is wired correctly.
    """
    model = Lenet5()
    fake_images = torch.randn(2, 3, 32, 32)
    predictions = model(fake_images)
    print('lenet out:', predictions.shape)


if __name__ == '__main__':
    main()
# 利用 CIFAR-10 数据集训练 LeNet-5 神经网络
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from Lenet5 import Lenet5
import torch.nn as nn
import torch.optim as optim
def _build_loaders(batchsz):
    """Return (train_loader, test_loader) for CIFAR-10 stored under ./cifar."""
    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]), download=True)
    # Evaluation order does not affect accuracy, so the test set is not shuffled.
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=False)
    return cifar_train, cifar_test


def _train_one_epoch(model, loader, criteon, optimizer):
    """Run one optimisation pass over `loader`.

    Returns the loss of the final mini-batch as a float, or NaN if the
    loader yielded no batches.
    """
    model.train()
    loss = None
    for x, label in loader:
        # x: [b, 3, 32, 32], label: [b]
        logits = model(x)  # [b, 10]
        loss = criteon(logits, label)
        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return float('nan') if loss is None else loss.item()


def _evaluate(model, loader):
    """Return top-1 accuracy of `model` on `loader`, computed without gradients."""
    model.eval()
    total_correct = 0
    total_num = 0
    with torch.no_grad():
        for x, label in loader:
            # [b, 3, 32, 32] => [b, 10] => [b]
            pred = model(x).argmax(dim=1)
            # [b] vs [b] => scalar count of correct predictions
            total_correct += torch.eq(pred, label).float().sum().item()
            total_num += x.size(0)
    return total_correct / total_num


def main(epochs=1000, batchsz=32):
    """Train Lenet5 on CIFAR-10 with Adam, reporting test accuracy every epoch.

    Args:
        epochs: number of passes over the training set (default 1000).
        batchsz: mini-batch size for both loaders (default 32).
    """
    cifar_train, cifar_test = _build_loaders(batchsz)

    # Peek at one batch only to show tensor shapes. DataLoader iterators
    # have no `.next()` method in current PyTorch, so use the builtin next().
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    model = Lenet5()
    criteon = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(epochs):
        last_loss = _train_one_epoch(model, cifar_train, criteon, optimizer)
        # This is the loss of the last mini-batch, not an epoch average.
        print(epoch, 'loss:', last_loss)

        acc = _evaluate(model, cifar_test)
        print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()
# 结果：