# coding: utf-8
# In[1]:
#模块准备
import os

from torch.autograd import Variable
import torch as t
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
# Converter that turns a Tensor into a PIL Image, handy for visual inspection.
show = transforms.ToPILImage()
# In[2]:
# Preprocessing applied to every CIFAR-10 image: convert to a float tensor,
# then normalize each of the three channels with mean 0.5 and std 0.5
# (for ToTensor's [0, 1] output this maps values into [-1, 1]).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
# In[3]:
# Where CIFAR-10 is downloaded/cached. Set the CIFAR10_ROOT environment
# variable to relocate it instead of editing a hard-coded user path.
DATA_ROOT = os.environ.get('CIFAR10_ROOT', '/home/yablon/data/')
BATCH_SIZE = 4
NUM_WORKERS = 2


def _make_cifar10(train):
    """Build one CIFAR-10 split and a DataLoader over it.

    Args:
        train: True for the training split, False for the test split.

    Returns:
        A ``(dataset, loader)`` pair. The dataset uses the module-level
        ``transform``; the loader shuffles only the training split.
    """
    dataset = tv.datasets.CIFAR10(
        root=DATA_ROOT,
        train=train,
        download=True,
        transform=transform,
    )
    loader = t.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        # Shuffle training batches; keep test order fixed for reproducible eval.
        shuffle=train,
        num_workers=NUM_WORKERS,
    )
    return dataset, loader


# Training set
trainset, trainloader = _make_cifar10(train=True)
# Test set
testset, testloader = _make_cifar10(train=False)
# In[4]:
# Human-readable CIFAR-10 class names.
# NOTE(review): presumably position i names integer label i as yielded by
# tv.datasets.CIFAR10 — confirm against the dataset's class ordering.
classes = tuple([
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
])
# In[5]:
(data, la
使用CUDA和pytorch框架下的CIFAR-10分类
最新推荐文章于 2024-04-18 10:32:42 发布