学习Dataset类的来龙去脉,使用干净的代码结构,同时最大限度地减少在训练期间管理大量数据的麻烦。
神经网络训练在数据管理上可能很难做到“大规模”。
PyTorch 最近已经出现在我的圈子里,尽管对Keras和TensorFlow感到满意,但我还是不得不尝试一下。令人惊讶的是,我发现它非常令人耳目一新,非常讨人喜欢,尤其是PyTorch 提供了一个Pythonic API、一个更为固执己见的编程模式和一组很好的内置实用程序函数。我特别喜欢的一项功能是能够轻松地创建一个自定义的Dataset
对象,然后可以与内置的DataLoader
一起在训练模型时提供数据。
在本文中,我将从头开始研究PyTorchDataset
对象,其目的是创建一个用于处理文本文件的数据集,以及探索如何为特定任务优化管道。我们首先通过一个简单示例来了解Dataset
实用程序的基础知识,然后逐步完成实际任务。具体地说,我们想创建一个管道,从The Elder Scrolls(TES)系列中获取名称,这些名称的种族和性别属性作为一个one-hot张量。你可以在我的网站上找到这个数据集。
Dataset类的基础知识
Pythorch允许您自由地对“Dataset”类执行任何操作,只要您重写两个子类函数:
-返回数据集大小的函数,以及
-函数的函数从给定索引的数据集中返回一个样本。
数据集的大小有时可能是灰色区域,但它等于整个数据集中的样本数。因此,如果数据集中有10000个单词(或数据点、图像、句子等),则函数“uuLen_uUu”应该返回10000个。
PyTorch使您可以自由地对Dataset
类执行任何操作,只要您重写改类中的两个函数即可:
__len__
函数:返回数据集大小__getitem__
函数:返回对应索引的数据集中的样本
数据集的大小有时难以确定,但它等于整个数据集中的样本数量。因此,如果您的数据集中有10,000个样本(数据点,图像,句子等),则__len__
函数应返回10,000。
一个简单示例
首先,创建一个从1到1000所有数字的Dataset
来模拟一个简单的数据集。我们将其适当地命名为NumbersDataset
。
from torch.utils.data import Dataset
class NumbersDataset(Dataset):
def __init__(self):
self.samples = list(range(1, 1001))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
return self.samples[idx]
if __name__ == '__main__':
dataset = NumbersDataset()
print(len(dataset))
print(dataset[100])
print(dataset[122:361])
很简单,对吧?首先,当我们初始化NumbersDataset
时,我们立即创建一个名为samples
的列表,该列表将存储1到1000之间的所有数字。列表的名称是任意的,因此请随意使用您喜欢的名称。需要重写的函数是不用我说明的(我希望!),并且对在构造函数中创建的列表进行操作。如果运行该python文件,将看到1000、101和122到361之间的值,它们分别指的是数据集的长度,数据集中索引为100的数据以及索引为121到361之间的数据集切片。
扩展数据集
让我们扩展此数据集,以便它可以存储low
和high
之间的所有整数。
from torch.utils.data import Dataset
class NumbersDataset(Dataset):
def __init__(self, low, high):
self.samples = list(range(low, high))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
return self.samples[idx]
if __name__ == '__main__':
dataset = NumbersDataset(2821, 8295)
print(len(dataset))
print(dataset[100])
print(dataset[122:361])
运行上面代码应在控制台打印5474、2921和2943到3181之间的数字。通过编写构造函数,我们现在可以将数据集的low
和high
设置为我们的想要的内容。这个简单的更改显示了我们可以从PyTorch的Dataset
类获得的各种好处。例如,我们可以生成多个不同的数据集并使用这些值,而不必像在NumPy中那样,考虑编写新的类或创建许多难以理解的矩阵。
从文件读取数据
让我们来进一步扩展Dataset
类的功能。PyTorch与Python标准库的接口设计得非常优美,这意味着您不必担心集成功能。在这里,我们将
- 创建一个全新的使用Python I/O和一些静态文件的
Dataset
类 - 收集TES角色名称(我的网站上有可用的数据集),这些角色名称分为种族文件夹和性别文件,以填充
samples
列表 - 通过在
samples
列表中存储一个元组而不只是名称本身来跟踪每个名称的种族和性别。
TES名称数据集具有以下目录结构:
.
|-- Al