我们在做训练的时候,我们免不了要读入数据,而针对数据的不同存储方式,我们也有不同的读入方法,从而方便我们将训练数据与其标签一一对应上。
方式一:使用于从一个存放了所有类别数据的文件夹中读取数据。通过重写torch.utils.data.Dataset,构建数据读取方式(自己做处理将数据和标签一一对应上),最后通过迭代器 torch.utils.data.DataLoader 的调用,按照batch_size 分批次读取数据。 如下有两个例子:
- 【kaggle数据集 - dog breed 举例】数据预处理、重写Dataset、DataLoader读取数据
- torch.utils.data.Dataset 和 torch.utils.data.DataLoader 基础使用
方式二:是我们这篇文章要介绍的 torchcvision.datasets.ImageFolder, 用于从已经归好类的文件夹中读取数据,举例如下
数据存储结构:
from torchvision import datasets, transforms
my_trans = transforms.Compose([transforms.RandomResizedCrop(224),
transforms.ToTensor()])
train_data = datasets.ImageFolder('./animals/train', transform=my_trans)
print(train_data.class_to_idx) # 查看分类名称(文件名) 对应的标签数值
print(train_data[0][0].size()) # 第一张图片的尺寸,就是我们 transforms.RandomResizedCrop 设定的裁剪尺寸
for i in range(len(train_data)):
print(train_data[i][1]) # 查看训练数据集中所有图片映射到的标签值
输出: