PyTorch / torchtext code: batch.src.data.t().contiguous()

Background:

Python, PyTorch; import torch; import torchtext.data; import torchtext.datasets

Source code URL: https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py

An excerpt of the relevant code:

The train_dataset iterator yields one batch at a time; batch.src.data pulls out the underlying tensor (which is why it is then re-wrapped in a Variable),

.t() transposes the matrix (before the transpose, batch.src has shape [len, batch]),

.contiguous() returns a contiguous tensor containing the same data as this tensor.

    for i, batch in enumerate(train_iter, start=1):
        src = Variable(batch.src.data.t().contiguous())  # print(src.size())  -> [batch, len]
        trg = batch.trg                                  # print(trg.size())  -> [len, batch]
        optimizer.zero_grad()
        loss = model(src, trg)
        loss.backward()
        optimizer.step()                                 # apply the gradient update
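
A minimal standalone check of what .t() and .contiguous() do (the tensor sizes here are made up, mirroring the [len, batch] layout above):

    import torch

    x = torch.zeros(7, 32)                 # stand-in for batch.src.data, shape [len, batch]
    y = x.t()                              # transposed view, shape [batch, len]
    print(y.size(), y.is_contiguous())     # torch.Size([32, 7]) False
    z = y.contiguous()                     # copies the data into contiguous memory
    print(z.is_contiguous())               # True
    # z.view(-1) works; y.view(-1) would raise an error because y is not contiguous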

Here is where train_iter comes from:

# Build an iterator with torchtext's data and datasets modules; it returns the next batch on each iteration

def mt_iterator(opt, train=True):
    DE = data.Field(init_token=EOS, eos_token=EOS, lower=True)
    EN = data.Field(init_token=EOS, eos_token=EOS, lower=True)
    train_data = datasets.TranslationDataset(path=opt.data, exts=('cn_train_pw.txt', 'en_train_pw.txt'), fields=(DE, EN))
    train_iter = data.BucketIterator(train_data, batch_size=opt.batch_size, device=0 if opt.cuda else -1)
    train_iter.repeat = False
    return train_iter
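
One thing this helper does not show is vocabulary construction: Field.build_vocab has to run before BucketIterator can map tokens to indices, and the author's full code presumably does this elsewhere. A hedged sketch of one possible arrangement, with an assumed EOS value (note that in recent torchtext releases this API moved to torchtext.legacy and was later removed):

    from torchtext import data, datasets

    EOS = '<eos>'   # assumed value; the original code defines EOS elsewhere

    def mt_iterator(opt, train=True):
        DE = data.Field(init_token=EOS, eos_token=EOS, lower=True)
        EN = data.Field(init_token=EOS, eos_token=EOS, lower=True)
        train_data = datasets.TranslationDataset(
            path=opt.data, exts=('cn_train_pw.txt', 'en_train_pw.txt'), fields=(DE, EN))
        # build_vocab must be called before iterating, otherwise numericalization fails
        DE.build_vocab(train_data)
        EN.build_vocab(train_data)
        train_iter = data.BucketIterator(
            train_data, batch_size=opt.batch_size, device=0 if opt.cuda else -1)
        train_iter.repeat = False
        return train_iter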


The data.Field docstring (from the source):

"""Defines a datatype together with instructions for converting to Tensor.
    Field class models common text processing datatypes that can be represented
    by tensors.  It holds a Vocab object that defines the set of possible values
    for elements of the field and their corresponding numerical representations.
    The Field object also holds other parameters relating to how a datatype
    should be numericalized, such as a tokenization method and the kind of
    Tensor that should be produced.
    If a Field is shared between two columns in a dataset (e.g., question and
    answer in a QA dataset), then they will have a shared vocabulary.
    Attributes:
        sequential: Whether the datatype represents sequential data. If False,
            no tokenization is applied. Default: True.
        use_vocab: Whether to use a Vocab object. If False, the data in this
            field should already be numerical. Default: True.
        init_token: A token that will be prepended to every example using this
            field, or None for no initial token. Default: None.
        eos_token: A token that will be appended to every example using this
            field, or None for no end-of-sentence token. Default: None.
        fix_length: A fixed length that all examples using this field will be
            padded to, or None for flexible sequence lengths. Default: None.
        tensor_type: The torch.Tensor class that represents a batch of examples
            of this kind of data. Default: torch.LongTensor.
        preprocessing: The Pipeline that will be applied to examples using this
            field after tokenizing but before numericalizing. Many Datasets
            replace this attribute with a custom preprocessor. Default: None.
        postprocessing: A Pipeline that will be applied to examples using this
            field after numericalizing but before the numbers are turned into a
            Tensor. The pipeline function takes the batch as a list, the field's
            Vocab, and train (a bool). Default: None.
        lower: Whether to lowercase the text in this field. Default: False.
        tokenize: The function used to tokenize strings using this field into
            sequential examples. If "spacy", the SpaCy English tokenizer is
            used. Default: str.split.
        include_lengths: Whether to return a tuple of a padded minibatch and
            a list containing the lengths of each example, or just a padded
            minibatch. Default: False.
        batch_first: Whether to produce tensors with the batch dimension first.
            Default: False.
        pad_token: The string token used as padding. Default: "<pad>".
        unk_token: The string token used to represent OOV words. Default: "<unk>".
        pad_first: Do the padding of the sequence at the beginning. Default: False.
    """
