# LeNet网络实战 (LeNet in practice: training a LeNet-5 model on CIFAR-10)
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn
from lenet5 import Lenet5
from torch import optim
def main():
batch_size=32
cifar_train=datasets.CIFAR10('cifar',True,transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
]),download=True)
cifar_train=DataLoader(cifar_train,batch_size=batch_size,shuffle=True)
cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
]), download=True)
cifar_test = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)
x,label=iter(cifar_train).next()
print('x:', x.shape, 'label:', label.shape)
model.train()
device = torch.device('cuda')
model = Lenet5().to(device)
# CrossEntropyLoss进行交叉熵运算(包括了softmax计算),判断多分类问题中预测试与真实值的差距
criteon = nn.CrossEntropyLoss().to(d