"""prepare the data"""
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Preprocessing applied to every image: convert to a tensor, then normalize
# each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training split (train=True); download=True requests the dataset under ./data.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Shuffled mini-batches of 4 images, loaded by 2 worker processes.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

# Test split (train=False), using the same preprocessing as the training split.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
# Same batch size and worker count as the training loader, but not shuffled.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

# Human-readable label names.
# NOTE(review): presumably ordered to match the dataset's integer class
# indices -- confirm against the torchvision CIFAR10 documentation.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
"""prepare the net&
pytorch学习笔记(一)cifar10分类
最新推荐文章于 2023-05-30 17:39:36 发布