1.问题
\quad
在使用Pytorch的时候,有时候需要在线下载数据集,因为在下载的过程中,封装好的代码,还要进行其他的操作(例如数据类型转换numpy->tensor
),但是有时候因为下载网站在国外,进度条一直显示0%,
\quad
就像这样:
2.解决办法
-
step1.下载数据集到本地
-
step2. 将本地存放CIFAR数据集路径放在浏览器下,回车
-
step3. 修改
class CIFAR10(VisionDataset)
中的url
-
step4. 运行代码
import torch import torchvision LOAD_CIFAR = True DOWNLOAD_CIFAR = True train_data = torchvision.datasets.CIFAR10( root='./cifar10/', train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_CIFAR, )
一切ok
3.显示CIFAR10的图像
import torch
import torchvision
import matplotlib.pyplot as plt
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_CIFAR = False
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
train_data = torchvision.datasets.CIFAR10(
root='./cifar10/', # 保存或者提取位置
train=True, # this is training data
transform=torchvision.transforms.ToTensor(), # 转换 PIL.Image or numpy.ndarray 成
# torch.FloatTensor (C x H x W), 训练的时候 normalize 成 [0.0, 1.0] 区间
download=DOWNLOAD_CIFAR, # 没下载就下载, 下载了就不用再下了
)
# method 1
# dataiter = iter(train_data)
# plt.show()
# for _ in range(len(train_data)):
# images, labels = dataiter.__next__()
# images = images.numpy().transpose(1, 2, 0) # 把channel那一维放到最后
# plt.title(str(classes[labels]))
# plt.imshow(images)
# plt.pause(1)
# as 2 list
# method 2
plt.show()
for images, labels in train_data:
images = images.numpy().transpose(1, 2, 0) # 把channel那一维放到最后
plt.title(str(classes[labels]))
plt.imshow(images)
plt.pause(1)
显示出来的图像