Pytorch cifar100数据集的简单理解与用法

本文介绍了如何使用PyTorch的torchvision.datasets.CIFAR100数据集进行图像分类任务,包括数据下载、数据结构分析,以及如何根据需求选择和筛选训练样本。重点讲解了数据集的使用方法和常见应用场景,如增量学习和领域适应。
摘要由CSDN通过智能技术生成

https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR100.html#torchvision.datasets.CIFAR100

torchvision.datasets中提供了一些经典数据集,其中最为常用的是cifar10/100,mnist,在搓增量学习、领域自适应、主动学习等任务时经常需要打交道。这里我们以cifar100为例看一下其基本的用法。

首先,下载训练集与测试集:

from torchvision import datasets

train_dataset = datasets.cifar.CIFAR100(root='cifar100', train=True, transform=None, download=True)
test_dataset = datasets.cifar.CIFAR100(root='cifar100', train=False, transform=None, download=True)

可以看到有四个参数:

  • root:数据集文件的存储路径。
  • train:是否为训练集。True则表示视为训练集,False表示视为测试集。
  • transform:所应用的数据扩充方法。
  • download:是否下载。如果为True且root路径下无相应的数据集文件,则自动从互联网上下载数据集至给定路径。

到这里,严格来讲就算介绍完毕了,因为这里得到的train_datasettest_dataset都属于torch.utils.data.Dataset对象,剩下来的用法和我们自己手工封装数据集的一致。现在,我们着重考察下cifar100数据集的结构。

首先,直接用下标去访问一个数据集对象:

print(train_dataset[0])

输出结果如下:

(<PIL.Image.Image image mode=RGB size=32x32 at 0x25568409670>, 19)

可以看到得到的是一个tuple,第一项为image,第二项为ground truth,即图像所属的分类,使用一个int值表示。在训练计算损失函数的时候,直接使用F.cross_entropy即可,而不需要考虑将int标签转化成独热向量的形式。而对于image,在实际训练中读数据集时transform一项必然带有transforms.ToTensor(),在这种情况下返回的则是一个Tensor向量以便网络训练。

此外,还有另一种访问方法:

print(train_dataset.data[0])
print(train_dataset.targets[0])

输出结果如下:

[[[255 255 255]
  [255 255 255]
  [255 255 255]
  ...
              ]]]
19

此时,标签仍为int不变,而数据返回的形式为ndarray:

print(train_dataset.data[0].shape)

输出结果如下:

(32, 32, 3)

可以看到cifar100图像的尺寸为32×32×3(3表示通道数)。这里我们也顺便将这张图像展示:

train_dataset[0][0].show()

在这里插入图片描述
因为只有1024像素所以非常小。查阅类别表可以发现类-19对应cattle(牛),这也与我们的观察相吻合。

现在我们来看另一个问题,可以发现第一张图像的类为19,那么这就表明,整个数据集50000张训练图像并不是按类别顺序进行划分的,即甚至可以在使用dataloader时不打开shuffle。为了验证这一点,我们直接输出train_dataset.targets中的前10个标签:

print(train_dataset.targets[:10])

结果如下:

[19, 29, 0, 11, 1, 86, 90, 28, 23, 31]

可以发现确实是乱序的。但是在一些奇怪的任务里面,会要求类别是有序的(实际上我们自己做的数据集也是尽量有序的,需要打乱通过dataloader实现即可),那么这里就看一下怎么去弄。具体来说,肯定是从targets中所包含的类别信息入手。首先将其从list转为ndarray方便我们使用numpy去操作:

train_targets = np.array(train_dataset.targets)

那么,比方说我们要取出类别在[10, 19]内的全部样本,就可以把相应的target先取出来。这里先得到一个长度为50000的包含每个元素是否满足条件的列表:

idx = np.logical_and(train_targets >= 10, train_targets < 20)
print(len(idx))
print(idx)

输出:

50000
[ True False False ... False False False]

然后使用np.where()进行广播,获得具体的下标:

idx = np.where(idx)
print(len(idx[0]))
print(idx[0])

可以看到满足要求的target所对应的下标共有5000个,确实是十分之一:

5000
[    0     3    13 ... 49977 49981 49991]

最后将这些下标作为索引取出相应的子数据集即可:

train_data = np.array(train_dataset.data)
selected_data = train_data[idx[0]]
selected_target = train_targets[idx[0]]
  • 11
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值