在PyTorch中构建高效的自定义数据集

学习Dataset类的来龙去脉,使用干净的代码结构,同时最大限度地减少在训练期间管理大量数据的麻烦。

1_2ARG_iUVAzMGgKtTMnGQOg

神经网络训练在数据管理上可能很难做到“大规模”。

PyTorch 最近已经出现在我的圈子里,尽管对Keras和TensorFlow感到满意,但我还是不得不尝试一下。令人惊讶的是,我发现它非常令人耳目一新,非常讨人喜欢,尤其是PyTorch 提供了一个Pythonic API、一个更为固执己见的编程模式和一组很好的内置实用程序函数。我特别喜欢的一项功能是能够轻松地创建一个自定义的Dataset对象,然后可以与内置的DataLoader一起在训练模型时提供数据。

在本文中,我将从头开始研究PyTorchDataset对象,其目的是创建一个用于处理文本文件的数据集,以及探索如何为特定任务优化管道。我们首先通过一个简单示例来了解Dataset实用程序的基础知识,然后逐步完成实际任务。具体地说,我们想创建一个管道,从The Elder Scrolls(TES)系列中获取名称,这些名称的种族和性别属性作为一个one-hot张量。你可以在我的网站上找到这个数据集。

Dataset类的基础知识

Pythorch允许您自由地对“Dataset”类执行任何操作,只要您重写两个子类函数:

-返回数据集大小的函数,以及

-函数的函数从给定索引的数据集中返回一个样本。

数据集的大小有时可能是灰色区域,但它等于整个数据集中的样本数。因此,如果数据集中有10000个单词(或数据点、图像、句子等),则函数“uuLen_uUu”应该返回10000个。

PyTorch使您可以自由地对Dataset类执行任何操作,只要您重写改类中的两个函数即可:

  • __len__ 函数:返回数据集大小
  • __getitem__ 函数:返回对应索引的数据集中的样本

数据集的大小有时难以确定,但它等于整个数据集中的样本数量。因此,如果您的数据集中有10,000个样本(数据点,图像,句子等),则__len__函数应返回10,000。

一个简单示例

首先,创建一个从1到1000所有数字的Dataset来模拟一个简单的数据集。我们将其适当地命名为NumbersDataset

from torch.utils.data import Dataset

class NumbersDataset(Dataset):
    def __init__(self):
        self.samples = list(range(1, 1001))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


if __name__ == '__main__':
    dataset = NumbersDataset()
    print(len(dataset))
    print(dataset[100])
    print(dataset[122:361])

1_M0Qw7bS02uzqEH5Q-tAR1w

很简单,对吧?首先,当我们初始化NumbersDataset时,我们立即创建一个名为samples的列表,该列表将存储1到1000之间的所有数字。列表的名称是任意的,因此请随意使用您喜欢的名称。需要重写的函数是不用我说明的(我希望!),并且对在构造函数中创建的列表进行操作。如果运行该python文件,将看到1000、101和122到361之间的值,它们分别指的是数据集的长度,数据集中索引为100的数据以及索引为121到361之间的数据集切片。

扩展数据集

让我们扩展此数据集,以便它可以存储lowhigh之间的所有整数。

from torch.utils.data import Dataset

class NumbersDataset(Dataset):
    def __init__(self, low, high):
        self.samples = list(range(low, high))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


if __name__ == '__main__':
    dataset = NumbersDataset(2821, 8295)
    print(len(dataset))
    print(dataset[100])
    print(dataset[122:361])

1_M4NsNPSdxokejaxrftnrlg

运行上面代码应在控制台打印5474、2921和2943到3181之间的数字。通过编写构造函数,我们现在可以将数据集的lowhigh设置为我们的想要的内容。这个简单的更改显示了我们可以从PyTorch的Dataset类获得的各种好处。例如,我们可以生成多个不同的数据集并使用这些值,而不必像在NumPy中那样,考虑编写新的类或创建许多难以理解的矩阵。

从文件读取数据

让我们来进一步扩展Dataset类的功能。PyTorch与Python标准库的接口设计得非常优美,这意味着您不必担心集成功能。在这里,我们将

  • 创建一个全新的使用Python I/O和一些静态文件的Dataset
  • 收集TES角色名称(我的网站上有可用的数据集),这些角色名称分为种族文件夹和性别文件,以填充samples列表
  • 通过在samples列表中存储一个元组而不只是名称本身来跟踪每个名称的种族和性别。

TES名称数据集具有以下目录结构:

.
|-- Al
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值