Background:
python; pytorch; import torch; import torchtext.data; import torchtext.datasets;
Source code URL: https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py
An excerpt of the code follows:
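Each batch comes from the training-set iterator (train_iter). batch.src.data extracts the underlying raw tensor, which is why the result is wrapped in Variable again;
.t() transposes the matrix (before the transpose, batch.src has shape [len, batch]);
.contiguous() returns a contiguous tensor containing the same data as the original tensor.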
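A minimal sketch of what `.t()` and `.contiguous()` do, assuming a dummy `[len, batch]` tensor as a stand-in for `batch.src`. `.t()` only swaps strides, so the result shares storage and is not contiguous; `.view()` on it fails until `.contiguous()` copies the data into row-major order.

```python
import torch

seq_len, batch_size = 5, 3
# Stand-in for batch.src: token ids laid out as [len, batch] (torchtext's default, batch_first=False)
src_len_first = torch.arange(seq_len * batch_size).view(seq_len, batch_size)

transposed = src_len_first.t()      # [batch, len]; a view with swapped strides, no copy
print(transposed.size())            # torch.Size([3, 5])
print(transposed.is_contiguous())   # False

try:
    transposed.view(-1)             # view() needs a memory layout compatible with the new shape
except RuntimeError as err:
    print("view failed:", type(err).__name__)

src = transposed.contiguous()       # copies the data into a new row-major tensor
print(src.is_contiguous())          # True
print(src.view(-1)[:5].tolist())    # first sequence in the batch: [0, 3, 6, 9, 12]
```

`.reshape()` would copy automatically when needed, but calling `.contiguous()` explicitly makes the copy visible and guarantees that later `.view()` calls inside the model work.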
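from torch.autograd import Variable  # legacy wrapper; since PyTorch 0.4 plain tensors track gradients

for i, batch in enumerate(train_iter, start=1):  # enumerate yields (step, batch) pairs
    src = Variable(batch.src.data.t().contiguous())  # print(src.size())  # [batch, len]
    trg = batch.trg                                   # print(trg.size())  # [len, batch]
    optimizer.zero_grad()
    loss = model(src, trg)
    loss.backward()
    optimizer.step()  # apply the computed gradients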
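A quick check of the loop above, assuming a hypothetical `Batch` namedtuple in place of torchtext's batch objects. `enumerate(..., start=1)` yields `(step, batch)` pairs, so the loop variable has to be unpacked; and because `Variable` and `Tensor` were merged in PyTorch 0.4, `batch.src.t().contiguous()` already gives an autograd-ready tensor without the `.data`/`Variable` round trip.

```python
from collections import namedtuple

import torch

# Hypothetical stand-in for a torchtext batch: both fields are [len, batch]
Batch = namedtuple("Batch", ["src", "trg"])
fake_iter = [
    Batch(src=torch.zeros(7, 2, dtype=torch.long), trg=torch.zeros(9, 2, dtype=torch.long)),
    Batch(src=torch.zeros(4, 2, dtype=torch.long), trg=torch.zeros(6, 2, dtype=torch.long)),
]

shapes = []
for step, batch in enumerate(fake_iter, start=1):  # enumerate yields (index, item) pairs
    src = batch.src.t().contiguous()  # [batch, len]; no Variable wrapper needed in PyTorch >= 0.4
    trg = batch.trg                   # left as [len, batch]
    shapes.append((step, tuple(src.size()), tuple(trg.size())))

print(shapes)  # [(1, (2, 7), (9, 2)), (2, (2, 4), (6, 2))]
```

Note that src and trg end up in different layouts, so the model is expected to handle [batch, len] on the encoder side and [len, batch] on the decoder side.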
where train_iter is built as follows:
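from torchtext import data, datasets  # legacy API (torchtext < 0.9; later moved to torchtext.legacy)

BOS, EOS = '<s>', '</s>'  # assumed special tokens; the excerpt does not define them

# Use torchtext's data and datasets modules to build an iterator that yields one batch at a time
def mt_iterator(opt, train=True):
    DE = data.Field(init_token=BOS, eos_token=EOS, lower=True)  # source side (Chinese file)
    EN = data.Field(init_token=BOS, eos_token=EOS, lower=True)  # target side (English file)
    train_data = datasets.TranslationDataset(
        path=opt.data, exts=('cn_train_pw.txt', 'en_train_pw.txt'), fields=(DE, EN))
    DE.build_vocab(train_data)  # vocabularies must exist before batches can be numericalized
    EN.build_vocab(train_data)
    train_iter = data.BucketIterator(  # device=-1 means CPU in old torchtext versions
        train_data, batch_size=opt.batch_size, device=0 if opt.cuda else -1)
    train_iter.repeat = False  # stop after one pass over the data instead of cycling forever
    return train_iter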
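BucketIterator groups examples of similar length into the same batch so that less padding is needed. Below is a simplified pure-Python sketch of that idea, not torchtext's implementation (which also shuffles within buckets and relies on the dataset's `sort_key`, assumed here to be sequence length). `pad` and `bucket_batches` are hypothetical helpers: they sort by length, cut the sorted list into batches, pad each batch to its own maximum length, and transpose it to [len, batch].

```python
from typing import List

PAD = "<pad>"


def pad(batch: List[List[str]]) -> List[List[str]]:
    """Pad every sequence in the batch to the batch's longest length."""
    max_len = max(len(seq) for seq in batch)
    return [seq + [PAD] * (max_len - len(seq)) for seq in batch]


def bucket_batches(examples: List[List[str]], batch_size: int) -> List[List[List[str]]]:
    """Sort by length, split into batches, pad, and transpose to [len, batch]."""
    ordered = sorted(examples, key=len)  # similar lengths end up next to each other
    batches = []
    for start in range(0, len(ordered), batch_size):
        padded = pad(ordered[start:start + batch_size])
        batches.append([list(col) for col in zip(*padded)])  # [batch, len] -> [len, batch]
    return batches


examples = [s.split() for s in ["a b c d e", "a", "a b c d", "a b"]]
batches = bucket_batches(examples, batch_size=2)
for b in batches:
    print(len(b), "x", len(b[0]))  # time steps x batch size: "2 x 2", then "5 x 2"
```

With this ordering only two pad tokens are needed; batching the examples in their original order ([5, 1] and [4, 2] tokens) would need six.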
data.Field source (docstring):
"""Defines a datatype together with instructions for converting to Tensor.
Field class models common text processing datatypes that can be represented
by tensors. It holds a Vocab object that defines the set of possible values
for elements of the field and their corresponding numerical representations.
The Field object also holds other parameters relating to how a datatype
should be numericalized, such as a tokenization method and the kind of
Tensor that should be produced.
If a Field is shared between two columns in a dataset (e.g., question and
answer in a QA dataset), then they will have a shared vocabulary.
Attributes:
sequential: Whether the datatype represents sequential data. If False, no tokenization is applied. Default: True.
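use_vocab: Whether to use a Vocab object. If False, the data in this field should already be numerical. Default: True.
init_token: A token that will be prepended to every example using this field, or None for no initial token. Default: None.
eos_token: A token that will be appended to every example using this field, or None for no end-of-sentence token. Default: None.
fix_length: A fixed length that all examples using this field will be padded to, or None for flexible sequence lengths. Default: None.
tensor_type: The torch.Tensor class that represents a batch of examples of this kind of data. Default: torch.LongTensor.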
preprocessing: The Pipeline that will be applied to examples using this field after tokenizing but before numericalizing. Many
Datasets replace this attribute with a custom preprocessor. Default: None.
postprocessing: A Pipeline that will be applied to examples using this field after numericalizing but before the numbers are turned
into a Tensor. The pipeline function takes the batch as a list, the field's Vocab, and train (a bool). Default: None.
lower: Whether to lowercase the text in this field. Default: False.
tokenize: The function used to tokenize strings using this field into
sequential examples. If "spacy", the SpaCy English tokenizer is used. Default: str.split.
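include_lengths: Whether to return a tuple of a padded minibatch and a list containing the lengths of each example, or just a padded minibatch. Default: False.
batch_first: Whether to produce tensors with the batch dimension first (dimension 0). Default: False.
pad_token: The string token used as padding. Default: "<pad>".
unk_token: The string token used to represent out-of-vocabulary (OOV) words. Default: "<unk>".
pad_first: Whether to put the padding at the beginning of each sequence instead of the end. Default: False.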
"""