PyTorch编程基础07--处理数据集的接口与使用

1.Dataset基类

在PyTorch中提供了数据集的基类torch.utils.data.Dataset,继承这个基类,我们能够非常快速的实现对数据的加载。我们需要在自定义的数据集类中继承Dataset类,同时还需要实现两个方法。

  1. __len__方法,能够实现通过全局的len()方法获取其中的元素个数
  2. __getitem__方法,能够通过传入索引的方式获取数值,例如通过dataset[i]获取其中的第i条数据

2.用DataLoader类实现自定义数据集

在PyTorch中,用torch.utils.data.DataLoader类可以构建带有批次的数据集。

2.1DataLoader类的定义

class DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0,
collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None)

具体参数解读如下:

  • dataset:待加载的数据
  • batch_size:每批次加载的样本数量,默认是1
  • shuffle:是否把样本的顺序打乱。默认是False,表示不打乱样本的顺序
  • sampler:接收一个采样器对象,用于按照指定的样本提取策略从数据集中提取样本。如果指定,则忽略shuffle参数
  • num_workers:设置加载数据的额外进程数量。默认是0,表示不额外启动进程来加载数据,直接使用主进程对数据进行加载
  • collate_fn:接收一个自定义函数。当该参数不为None时,系统会先从数据集中取出数据,然后将数据传入collate_fn中,由collate_fn参数所指向的函数对数据进行二次加工。collate_fn常用于测试和训练场景中对同一个数据集进行数据提取
  • pin_memory:在数据返回前,是否将数据复制到CUDA内存中。默认值为False
  • drop_last:是否丢弃最后数据,默认值是False,表示不丢弃。在样本总数不能被batch size整除的情况下,如果该值为True,则丢弃最后一个满足一个批次数量的数据;如果该值为False,则将最后不足一个批次数量的数据返回
  • timeout:读取数据的超时时间,默认值为0。当超时设置时间还没读到数据是,系统就会报错
  • worker_init_fn:每个子进程的初始化函数,在加载数据之前运行
  • multiprocessing_context:多进程处理的配置参数

3.DataLoader类中的多采样器子类

DataLoader类是一个非常强大的数据集处理类。它几乎可以覆盖数据集的任何使用场景,在PyTorch程序中也非常常用。其中与DataLoader类配套的还有采样器sampler类,该类又派生了多个采样器子类,同时支持自定义采样器类。其中,内置的采样器子类有如下几种

  • SequentialSampler:按照原有的样本顺序进行采样
  • RandomSampler:按照随机顺序进行采样,可以设置是否重复采样
  • SubsetRandomSampler:按照指定的集合或索引列表进行随机顺序采样
  • WeightedRandomSampler:按照指定的概率进行随机顺序采样
  • BatchSampler:按照指定的批次索引进行采样

4.Torchtext工具与内置数据集

4.1Torchtext的内部结构

Torchtext对数据的处理可以概括为Field、Dataset和迭代器这3部分

  • Field:要如何处理某个字段
  • Dataset:定义数据源信息
  • 迭代器:返回模型所需要的处理后的数据。主要分为以下3种
    • Iterator:标准迭代器
    • BucketIerator:相比于标准迭代器,它会将类似长度的样本当作一批来处理。由于在文本处理中经常需要将“每一批样本长度”补齐为“当前批次中最长序列的长度”,因此当样本长度差别较大时,使用BucketIterator可以带来填充效率的提高。除此之外,还可以在Field中通过fix_length参数来对样本进行截断补齐操作
    • BPTTIterator:基于BPTT(基于时间的反向传播算法)的迭代器,一般用于语言模型中。

4.2安装Torchtext库

pip install torchtext

4.3查看Torchtext库的内置数据集

在安装好Torchtext库后,可以在如下的路径中查看Torchtext库的内置数据集

本地pip安装包路径\Lib\site-packages\torchtext\datases\__init__.py

4.4安装Torchtext库的调用模块

在使用Torchtext库过程中,如果要间接使用其他的文本处理库,则还需要额外下载,例如,使用字段处理库的代码如下:

from torchtext import data
TEXT = data.Field(tokenize='spacy')

在调用Torchtext库的data.Field()函数时,可以向tokenize参数传入“revtok”、“subword”、“spacy”、“moses”字符串,表示分别使用revtok、NLTK、en模块的SpaCy库、sacremoses库进行字段处理,这些库都需要单独安装。


4.5Torchtext库的内置预训练词向量

Torchtext库中内置了若干个预训练词向量,可以在模型中直接用来对本地的权重进行初始化

charngram.100d、fasttext.en.300d、fasttext.simple.300d、glove.42B.300d、glove.840B.300d、glove.twitter.27B.25d、glove.twitter.27B.50d、glove.twitter.27B.100d、glove.twitter.27B.200d、glove.6B.50d、glove.6B.100d、glove.6B.200d、glove.6B.300d

这些词向量,前部分的名称表明其在训练时所用的模型,后部分都是“数字+d”的形式,代表将词映射成词向量的维度。这种本来就带有语义的词向量,可以大大加快模型的训练速度。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值