python生成10000个样本数据集_在PyTorch中构建高效的自定义数据集

本文介绍了如何在PyTorch中创建自定义Dataset对象,以处理大规模数据,包括从文件读取、数据预处理、独热编码等。通过示例展示了如何创建NumbersDataset和TESNamesDataset,详细解释了Dataset类的__len__和__getitem__方法,并演示了DataLoader如何加载和处理数据。最后讨论了数据集的拆分和处理不同长度样本的方法,强调了PyTorch数据处理的灵活性和高效性。
摘要由CSDN通过智能技术生成

学习Dataset类的来龙去脉,使用干净的代码结构,同时最大限度地减少在训练期间管理大量数据的麻烦。

1981858-20200714120348709-246096049.png

神经网络训练在数据管理上可能很难做到“大规模”。

PyTorch 最近已经出现在我的圈子里,尽管对Keras和TensorFlow感到满意,但我还是不得不尝试一下。令人惊讶的是,我发现它非常令人耳目一新,非常讨人喜欢,尤其是PyTorch 提供了一个Pythonic API、一个更为固执己见的编程模式和一组很好的内置实用程序函数。我特别喜欢的一项功能是能够轻松地创建一个自定义的Dataset对象,然后可以与内置的DataLoader一起在训练模型时提供数据。

在本文中,我将从头开始研究PyTorchDataset对象,其目的是创建一个用于处理文本文件的数据集,以及探索如何为特定任务优化管道。我们首先通过一个简单示例来了解Dataset实用程序的基础知识,然后逐步完成实际任务。具体地说,我们想创建一个管道,从The Elder Scrolls(TES)系列中获取名称,这些名称的种族和性别属性作为一个one-hot张量。你可以在我的网站上找到这个数据集。

Dataset类的基础知识

Pythorch允许您自由地对“Dataset”类执行任何操作,只要您重写两个子类函数:

-返回数据集大小的函数,以及

-函数的函数从给定索引的数据集中返回一个样本。

数据集的大小有时可能是灰色区域,但它等于整个数据集中的样本数。因此,如果数据集中有10000个单词(或数据点、图像、句子等),则函数“uuLen_uUu”应该返回10000个。

PyTorch使您可以自由地对Dataset类执行任何操作,只要您重写改类中的两个函数即可:

__len__ 函数:返回数据集大小

__getitem__ 函数:返回对应索引的数据集中的样本

数据集的大小有时难以确定,但它等于整个数据集中的样本数量。因此,如果您的数据集中有10,000个样本(数据点,图像,句子等),则__len__函数应返回10,000。

一个简单示例

首先,创建一个从1到1000所有数字的Dataset来模拟一个简单的数据集。我们将其适当地命名为NumbersDataset。

from torch.utils.data import Dataset

class NumbersDataset(Dataset):

def __init__(self):

self.samples = list(range(1, 1001))

def __len__(self):

return len(self.samples)

def __getitem__(self, idx):

return self.samples[idx]

if __name__ == '__main__':

dataset = NumbersDataset()

print(len(dataset))

print(dataset[100])

print(dataset[122:361])

1981858-20200714120349086-1718249959.png

很简单,对吧?首先,当我们初始化NumbersDataset时,我们立即创建一个名为samples的列表,该列表将存储1到1000之间的所有数字。列表的名称是任意的,因此请随意使用您喜欢的名称。需要重写的函数是不用我说明的(我希望!),并且对在构造函数中创建的列表进行操作。如果运行该python文件,将看到1000、101和122到361之间的值,它们分别指的是数据集的长度,数据集中索引为100的数据以及索引为121到361之间的数据集切片。

扩展数据集

让我们扩展此数据集,以便它可以存储low和high之间的所有整数。

from torch.utils.data import Dataset

class NumbersDataset(Dataset):

def __init__(self, low, high):

self.samples = list(range(low, high))

def __len__(self):

return len(self.samples)

def __getitem__(self, idx):

return self.samples[idx]

if __name__ == '__main__':

dataset = NumbersDataset(2821, 8295)

print(len(dataset))

print(dataset[100])

print(dataset[122:361])

1981858-20200714120349395-1068803086.png

运行上面代码应在控制台打印5474、2921和2943到3181之间的数字。通过编写构造函数,我们现在可以将数据集的low和high设置为我们的想要的内容。这个简单的更改显示了我们可以从PyTorch的Dataset类获得的各种好处。例如,我们可以生成多个不同的数据集并使用这些值,而不必像在NumPy中那样,考虑编写新的类或创建许多难以理解的矩阵。

从文件读取数据

让我们来进一步扩展Dataset类的功能。PyTorch与Python标准库的接口设计得非常优美,这意味着您不必担心集成功能。在这里,我们将

创建一个全新的使用Python I/O和一些静态文件的Dataset类

收集TES角色名

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值