实验室这两天的网络太差了，数据集下载特别特别特别慢……
网络结构见这篇博客：《PyTorch 创建卷积神经网络》
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn, optim
from lenet5 import Lenet5
# Select the compute device once at import time: the first CUDA GPU when
# PyTorch reports one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def main():
batchsz = 32
cifar_train = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
transforms.Resize((32,32)),
transforms.ToTensor()
]), download=True)
cifar_train = DataLoader(cifar_train,batch_size=batchsz,shuffle=True)
cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
transforms.Resize((32,32)),
transforms.ToTensor()
]), download=True)
cifar_teat = DataLoader(cifar_train,batch_s