python datasets 下载_CIFAR数据下载(配上Pytorch上调用)

类似于于MNIST一样,担心代码下载速度不行,因此修改对应的源码。

下载数据

进入到

http://www.cs.toronto.edu/~kriz/cifar.html

下载数据对应的压缩包。

下载第一行的数据(当然cifar100在后面是类似的操作。)

d552b0118b68885666063f98585acacf.png

下载到某个固定的位置,比如我,

对应的地址就是:D:\Software\DataSet\cifar

5080802c0aaae66024530623690bd28d.png

源码修改

在pycharm或者是VScode当中,按住ctrl的情况下,用鼠标点击下面这段代码中的CIFAR10

import torchvisiontorchvision.datasets.CIFAR10()

就可以看到对应的源码了。跟之前的MNIST的操作类似:

也就是将下面的注释部分的那个代码换成第一行的代码,当然地址是写你自己的。

url = "file:///D:/Software/DataSet/cifar/cifar-10-python.tar.gz"# url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"

调用实例

import torchimport torchvisionimport osfrom torch.utils.data import Dataset, DataLoaderimport torchvision.utils as vutilsimport numpy as np import matplotlib.pyplot as pltimport pickleDOWNLOAD_CIFAR10 = Falsecifar10_root = './cifar10/'if not (os.path.exists(cifar10_root)) or not os.listdir(cifar10_root):    # not mnist dir or mnist is empyt dir    DOWNLOAD_CIFAR10 = Truetrain_data = torchvision.datasets.CIFAR10(    root=cifar10_root,    train=True,  # this is training data    transform=torchvision.transforms.ToTensor(),  # Converts a PIL.Image or numpy.ndarray to    # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]    download=DOWNLOAD_CIFAR10,)train_loader = DataLoader(dataset=train_data, batch_size=100, shuffle=True)with open('./cifar10/cifar-10-batches-py/batches.meta', 'rb') as f:    data = pickle.load(f)

然后有两种不同的输出的代码:

第一种的输出:

858576c4d03c8d5f9bf6ead3bee76d60.png

for step, (x, y) in enumerate(train_loader):    print(x.shape, y.shape)    print(y)    fig = plt.figure(figsize=(10, 10))    fig, axs = plt.subplots(nrows=1, ncols=10, figsize=(10, 1.5))    for i in range(10):        ax = axs[i]        ax.axis("off")        ax.set_title(data['label_names'][y[i]])        ax.imshow(np.transpose(x[i].numpy(), (1, 2, 0)))    plt.savefig('cifar-10.png')    plt.show()    break

第二种的输出:

ca41f83739709d07f25ab5dbcf20ed88.png

for step, (x, y) in enumerate(train_loader):    print(x.shape, y.shape)    print(y)    fig = plt.figure(figsize=(20, 20))    for i in range(100):        ax = plt.subplot(10, 10, i+1)        plt.axis("off")        ax.set_title(data['label_names'][y[i]])        plt.imshow(np.transpose(x[i].numpy(), (1, 2, 0)))    plt.savefig('cifar-100.png')    plt.show()    break

564f07d22f3872044a41f4be2c67a9ee.png

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值