本篇代码有不清楚的地方,可以参考:
cifar-10+resnet.
这篇除了搭建的CNN不一样,其他地方完全一样。
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
cifar_train = datasets.CIFAR10('cifar',True,transform=transforms.Compose(( #true表示加载的是训练集
transforms.Resize(32,32),
transforms.ToTensor())))
cifar_train_batch = DataLoader(cifar_train,batch_size = 30,shuffle = True)
cifar_test = datasets.CIFAR10('cifar',False,transform=transforms.Compose(( #false表示加载的是测试集
transforms.Resize(32,32),
transforms.ToTensor())))
cifar_test_batch = DataLoader(cifar_test_one,batch_size = 30,shuffle = True)
搭建CNN:
from torch import nn
class lenet5(nn.Module)