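The MNIST dataset consists of the handwritten digits 0-9. CIFAR-10, by contrast, is a 10-class object-recognition dataset covering airplanes, automobiles, birds, cats, and so on. Every image is a 32×32 color image with three (RGB) channels. Each class has 6000 images: 5000 randomly selected ones are used for training and the remaining 1000 for testing.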
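As a quick sanity check on these figures, the sketch below does the bookkeeping in plain Python. The class list follows the label order that torchvision's `datasets.CIFAR10` exposes through its `classes` attribute (index 0 = airplane, ..., 9 = truck). `split_sizes` and the constant names are hypothetical helpers for illustration only.

```python
# Label order used by torchvision's CIFAR10 (index 0..9)
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]

IMAGES_PER_CLASS = 6000
TRAIN_PER_CLASS = 5000
TEST_PER_CLASS = IMAGES_PER_CLASS - TRAIN_PER_CLASS  # 1000 test images per class

# Each image: 3 channels x 32 x 32 pixels
CHANNELS, HEIGHT, WIDTH = 3, 32, 32
VALUES_PER_IMAGE = CHANNELS * HEIGHT * WIDTH  # 3072 numbers per image


def split_sizes(num_classes=len(CIFAR10_CLASSES)):
    """Return (train, test, total) image counts for the whole dataset."""
    train = num_classes * TRAIN_PER_CLASS
    test = num_classes * TEST_PER_CLASS
    return train, test, train + test


if __name__ == "__main__":
    print(split_sizes())     # (50000, 10000, 60000)
    print(VALUES_PER_IMAGE)  # 3072
```

So the whole dataset holds 60000 images, split 50000 / 10000 between training and testing.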
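First, load the dataset:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
batch_size = 32
# Training set, downloaded into ./cifar on the first run
cifar_train = datasets.CIFAR10(root='cifar', train=True, transform=transforms.Compose([
    transforms.Resize([32, 32]),  # CIFAR-10 images are already 32x32, so this is effectively a no-op
    transforms.ToTensor(),        # PIL image -> float tensor [3, 32, 32] with values in [0, 1]
]), download=True)
cifar_train = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)
# Test set
cifar_test = datasets.CIFAR10(root='cifar', train=False, transform=transforms.Compose([
    transforms.Resize([32, 32]),
    transforms.ToTensor(),
]), download=True)
cifar_test = DataLoader(cifar_test, batch_size=batch_size, shuffle=True)
# Fetch one batch with the built-in next(); newer PyTorch DataLoader iterators have no .next() method
x, label = next(iter(cifar_train))
print('x:', x.shape, 'label:', label.shape)
# x: torch.Size([32, 3, 32, 32]) label: torch.Size([32])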
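Running the code above requires downloading CIFAR-10. To see what a batch looks like without network access, the sketch below substitutes a synthetic `TensorDataset` of random tensors shaped like CIFAR-10 images after `ToTensor()`. This is a stand-in assumption, not real CIFAR-10 data, and `fake_images`, `fake_labels`, and `fake_set` are hypothetical names.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for CIFAR-10: 100 random "images" of shape [3, 32, 32]
# with float values in [0, 1) and integer labels in 0..9
fake_images = torch.rand(100, 3, 32, 32)
fake_labels = torch.randint(0, 10, (100,))
fake_set = TensorDataset(fake_images, fake_labels)

loader = DataLoader(fake_set, batch_size=32, shuffle=True)

# Built-in next() works on any iterator, including DataLoader iterators
x, label = next(iter(loader))
print('x:', x.shape, 'label:', label.shape)
# x: torch.Size([32, 3, 32, 32]) label: torch.Size([32])

# 100 samples with batch_size=32 -> batches of 32, 32, 32 and a final 4
batch_sizes = [xb.shape[0] for xb, _ in loader]
print(batch_sizes)  # [32, 32, 32, 4]
```

The last batch holds only 4 samples because `drop_last` defaults to `False`; pass `drop_last=True` to `DataLoader` if every batch must have the same size.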
With the dataset loaded, the next step is to write the classic LeNet5 network:
import torch
from torch import nn, optim
import torch.nn.functional as F
class LeNet5(nn.Module):
    """
    for the CIFAR10 dataset
    """
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv_unit = nn.Sequential(
            # x: [batchsize, 3, 32, 32] => [batchsize, 6, 28, 28]
            nn.Conv2d(in_channel