本篇文章包括:
- torch.utils.data.Dataset
- torch.utils.data.TensorDataset
- 拆分我们的数据集:random_split
- 准备培训:DataLoader
torch.utils.data.Dataset
该模块允许我们通过继承它来创建数据子集。我们需要定义两个函数:
- __len __:返回数据集的长度。
- __getitem __:在给定索引的情况下返回数据集的单个项。
class ImageClasDatasetFromFolder(Dataset): def __init__(self, path): self.path = path cls = sorted(os.listdir(path)) self.classes = dict() for i, c in enumerate(cls): self.classes.update({c: i}) self.data_list = dict() for c in cls: file_names = sorted(os.listdir(os.path.join(path, c))) for file_name in file_names: self.data_list.update({file_name: c}) def __len__(self): return len(self.data_list) def __getitem__(self, idx): file_name = list(self.data_list.keys())[idx] label = list(self.data_list.values())[idx] item = cv2.imread(os.path.join(self.path, label, file_name)) label = self.classes[label] return item, label
上面的示例显示了如何创建图像分类数据集,其中数据集采用文件夹格式。就像所有python类一样,__init__需要一个函数。在定义__len__和__getitem__函数之后,我们可以使用这个类创建一个包含数据集的对象,PyTorch会处理其他所有事情。
使用方法:由于我们定义__len__和__getitem__,所以我们可以使用如何定义每个函数来索引到我们的数据集。例如,len(dataset) 将返回整个数据集的长度,同时dataset[i]返回数据集的第i项。
torch.utils.data.TensorDataset
这个类与基本功能与torch.utils.data.Dataset相同,除了没有子类化,__len__和__getitem__函数已经定义了。要使用这个类,我们必须传递一些张量作为参数,其中每个张量代表数据的一部分。每个样本将通过沿第一维度索引张量来检索。最终结果是我们的数据集是相同的,我们使用len(dataset)和dataset[i]检索我们的数据。为了显示示例,我们创建一些假数据:
data = torch.randn(60000, 28, 28, 3)labels = torch.randint(0, 9, (60000,))
上面创建的每个张量代表假图像以及假标签。由于数据集的每一项都可以沿用张量的第一维度索引,因此我们可以使用这两个张量并将它们传递到torch.utils.data.TensorDataset类中以创建与对象相同的数据集torch.utils.data.Dataset对象。
dataset = TensorDataset(data, labels)
拆分我们的数据集
PyTorch为我们提供了一个可用于分割数据集的函数:torch.utils.data.random_split。我们可以使用此函数将我们创建的数据集对象拆分为两个数据子集,使用下面这行代码:
train_dataset, valid_dataset = random_split(dataset, [train_num, valid_num])
该random_split函数有两个参数:
- dataset:要拆分的数据集。
- lengths:每个子集的不同长度的列表。
在上面的代码中,由于我们希望将数据集拆分为训练集和验证集,因此我们的第二个参数是两个数字的列表,其中每个数字对应于训练和验证子集的长度。请注意,因为我们在列表中有两个值,所以我们在调用此函数时会有两个赋值的对象。
准备训练:DataLoader
使用该torch.utils.data.Dataset时仍然缺少一些东西。具体来说,我们仍然希望:
- 将我们的数据分成批次
- 调整我们的数据
该torch.utils.data.DataLoader类我们做了这个:
trainloader = DataLoader(train_dataset, batch_size = 24, shuffle = True)validloader = DataLoader(valid_dataset, batch_size = 24, shuffle = True)
循环遍历整个数据集:我们可以使用创建的对象循环遍历我们的数据集,一次一批:
for inputs, labels in trainloader: # train using the inputs and labels tensors
最终trainloader和validloader对象是我们在训练时使用的。
最后的说明
在本文中,我介绍了如何在PyTorch中创建数据集。我已经展示了你如何从一个从数据集转变为准备在PyTorch中训练的东西。但是,由于数据集中的数据不适合训练,因此必须对数据集进行多次预处理。