As with MNIST, the built-in download can be slow, so we modify the corresponding torchvision source to load a local copy instead.
Download the data
Go to
http://www.cs.toronto.edu/~kriz/cifar.html
and download the dataset archive.
Download the first entry in the table (CIFAR-100, listed further down, works the same way).
Save it to a fixed location; in my case the path is D:\Software\DataSet\cifar.
Modify the source code
In PyCharm or VS Code, hold Ctrl and click CIFAR10 in the following snippet:
```python
import torchvision

torchvision.datasets.CIFAR10()
```
This jumps to the corresponding source file, just like the MNIST case:
That is, replace the url assignment (shown commented out below) with the first line, pointing it at your own local path:
```python
url = "file:///D:/Software/DataSet/cifar/cifar-10-python.tar.gz"
# url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
```
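Why this patch works: torchvision fetches the archive through Python's urllib, and urllib also understands file:// URLs, so the "download" becomes a local copy. A minimal stdlib-only demonstration of that mechanism (the file names here are made up for illustration):

```python
import os
import tempfile
import urllib.request

# Create a stand-in for the local archive.
tmpdir = tempfile.mkdtemp()
src = os.path.join(tmpdir, 'archive.bin')
with open(src, 'wb') as f:
    f.write(b'fake cifar archive')

# Build a file:// URL from the local path (works for both / and \ separators).
url = 'file:///' + src.replace(os.sep, '/').lstrip('/')

# "Download" it: urlretrieve just copies the local file.
dst = os.path.join(tmpdir, 'copy.bin')
urllib.request.urlretrieve(url, dst)

with open(dst, 'rb') as f:
    print(f.read())
```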
Usage example
```python
import torch
import torchvision
import os
from torch.utils.data import Dataset, DataLoader
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import pickle

DOWNLOAD_CIFAR10 = False
cifar10_root = './cifar10/'
if not os.path.exists(cifar10_root) or not os.listdir(cifar10_root):
    # cifar10 dir is missing or empty, so trigger the (local) download
    DOWNLOAD_CIFAR10 = True

train_data = torchvision.datasets.CIFAR10(
    root=cifar10_root,
    train=True,  # this is the training split
    # Converts a PIL.Image or numpy.ndarray to a torch.FloatTensor of
    # shape (C x H x W) and normalizes it to the range [0.0, 1.0]
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_CIFAR10,
)
train_loader = DataLoader(dataset=train_data, batch_size=100, shuffle=True)

# label index -> class name mapping shipped with the dataset
with open('./cifar10/cifar-10-batches-py/batches.meta', 'rb') as f:
    data = pickle.load(f)
```
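The ToTensor transform above does two things: it rescales uint8 pixel values to [0.0, 1.0] and reorders the axes from H x W x C to C x H x W. A NumPy-only sketch of that conversion (illustrative only; the real implementation lives inside torchvision):

```python
import numpy as np

def to_tensor_like(img_hwc_uint8):
    """Mimic ToTensor: H x W x C uint8 -> C x H x W float32 in [0, 1]."""
    arr = img_hwc_uint8.astype(np.float32) / 255.0  # scale to [0.0, 1.0]
    return np.transpose(arr, (2, 0, 1))             # HWC -> CHW

img = np.zeros((32, 32, 3), dtype=np.uint8)  # a dummy CIFAR-sized image
img[0, 0, 0] = 255
out = to_tensor_like(img)
print(out.shape)  # channels-first, as the network expects
```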
Then there are two different ways to plot the output.
First version:
```python
for step, (x, y) in enumerate(train_loader):
    print(x.shape, y.shape)
    print(y)
    # plt.subplots creates its own figure, so no separate plt.figure is needed
    fig, axs = plt.subplots(nrows=1, ncols=10, figsize=(10, 1.5))
    for i in range(10):
        ax = axs[i]
        ax.axis("off")
        ax.set_title(data['label_names'][y[i]])
        ax.imshow(np.transpose(x[i].numpy(), (1, 2, 0)))  # CHW -> HWC for imshow
    plt.savefig('cifar-10.png')
    plt.show()
    break
```
Second version:
```python
for step, (x, y) in enumerate(train_loader):
    print(x.shape, y.shape)
    print(y)
    fig = plt.figure(figsize=(20, 20))
    for i in range(100):
        ax = plt.subplot(10, 10, i + 1)
        plt.axis("off")
        ax.set_title(data['label_names'][y[i]])
        plt.imshow(np.transpose(x[i].numpy(), (1, 2, 0)))  # CHW -> HWC for imshow
    plt.savefig('cifar-100.png')
    plt.show()
    break
```
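Both versions draw each image in its own subplot. An alternative is to tile the whole batch into a single mosaic array first and call imshow once; this is essentially what torchvision.utils.make_grid (imported as vutils above) does. A NumPy-only sketch, assuming a batch of shape (N, C, H, W) with N = rows * cols:

```python
import numpy as np

def make_mosaic(batch_nchw, rows, cols):
    """Tile an (N, C, H, W) batch into one (C, rows*H, cols*W) image."""
    n, c, h, w = batch_nchw.shape
    assert n == rows * cols
    grid = batch_nchw.reshape(rows, cols, c, h, w)
    # (rows, cols, c, h, w) -> (c, rows, h, cols, w) -> (c, rows*h, cols*w)
    return grid.transpose(2, 0, 3, 1, 4).reshape(c, rows * h, cols * w)

# A random stand-in for a batch of 100 CIFAR images.
batch = np.random.rand(100, 3, 32, 32).astype(np.float32)
mosaic = make_mosaic(batch, 10, 10)
print(mosaic.shape)  # one big image instead of 100 subplots
```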