pytorch dataset_有关如何在PyTorch中创建深度学习数据集的教程

本篇文章包括:

  • torch.utils.data.Dataset
  • torch.utils.data.TensorDataset
  • 拆分我们的数据集:random_split
  • 准备培训:DataLoader

torch.utils.data.Dataset

该模块允许我们通过继承它来创建数据子集。我们需要定义两个函数:

  • __len __:返回数据集的长度。
  • __getitem __:在给定索引的情况下返回数据集的单个项。
class ImageClasDatasetFromFolder(Dataset): def __init__(self, path): self.path = path cls = sorted(os.listdir(path)) self.classes = dict() for i, c in enumerate(cls): self.classes.update({c: i}) self.data_list = dict() for c in cls: file_names = sorted(os.listdir(os.path.join(path, c))) for file_name in file_names: self.data_list.update({file_name: c}) def __len__(self): return len(self.data_list) def __getitem__(self, idx): file_name = list(self.data_list.keys())[idx] label = list(self.data_list.values())[idx] item = cv2.imread(os.path.join(self.path, label, file_name)) label = self.classes[label] return item, label
b2664749f08c4ab6afff6797e89a20ea

上面的示例显示了如何创建图像分类数据集,其中数据集采用文件夹格式。就像所有python类一样,__init__需要一个函数。在定义__len__和__getitem__函数之后,我们可以使用这个类创建一个包含数据集的对象,PyTorch会处理其他所有事情。

使用方法:由于我们定义__len__和__getitem__,所以我们可以使用如何定义每个函数来索引到我们的数据集。例如,len(dataset) 将返回整个数据集的长度,同时dataset[i]返回数据集的第i项。

torch.utils.data.TensorDataset

这个类与基本功能与torch.utils.data.Dataset相同,除了没有子类化,__len__和__getitem__函数已经定义了。要使用这个类,我们必须传递一些张量作为参数,其中每个张量代表数据的一部分。每个样本将通过沿第一维度索引张量来检索。最终结果是我们的数据集是相同的,我们使用len(dataset)和dataset[i]检索我们的数据。为了显示示例,我们创建一些假数据:

data = torch.randn(60000, 28, 28, 3)labels = torch.randint(0, 9, (60000,))
447f3a109edd407689a83f1be2c8a69a

上面创建的每个张量代表假图像以及假标签。由于数据集的每一项都可以沿用张量的第一维度索引,因此我们可以使用这两个张量并将它们传递到torch.utils.data.TensorDataset类中以创建与对象相同的数据集torch.utils.data.Dataset对象。

dataset = TensorDataset(data, labels)
05fb3e0c0b1547a2a2952873a470d125

拆分我们的数据集

PyTorch为我们提供了一个可用于分割数据集的函数:torch.utils.data.random_split。我们可以使用此函数将我们创建的数据集对象拆分为两个数据子集,使用下面这行代码:

train_dataset, valid_dataset = random_split(dataset, [train_num, valid_num])
524dd8a68de34fdaa5c36a292714e8c4

该random_split函数有两个参数:

  • dataset:要拆分的数据集。
  • lengths:每个子集的不同长度的列表。

在上面的代码中,由于我们希望将数据集拆分为训练集和验证集,因此我们的第二个参数是两个数字的列表,其中每个数字对应于训练和验证子集的长度。请注意,因为我们在列表中有两个值,所以我们在调用此函数时会有两个赋值的对象。

准备训练:DataLoader

使用该torch.utils.data.Dataset时仍然缺少一些东西。具体来说,我们仍然希望:

  • 将我们的数据分成批次
  • 调整我们的数据

该torch.utils.data.DataLoader类我们做了这个:

trainloader = DataLoader(train_dataset, batch_size = 24, shuffle = True)validloader = DataLoader(valid_dataset, batch_size = 24, shuffle = True)
ec2b6beec5ab4cfa8ed1c5d959d424c5

循环遍历整个数据集:我们可以使用创建的对象循环遍历我们的数据集,一次一批:

for inputs, labels in trainloader: # train using the inputs and labels tensors
0c213fd5877447768839f021d636303c

最终trainloader和validloader对象是我们在训练时使用的。

最后的说明

在本文中,我介绍了如何在PyTorch中创建数据集。我已经展示了你如何从一个从数据集转变为准备在PyTorch中训练的东西。但是,由于数据集中的数据不适合训练,因此必须对数据集进行多次预处理。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值