原文链接:
本教程展示了如何从 torchtext 中的旧 API 迁移到 0.9.0 版本中的新 API。这里,我们以 IMDB 数据集为例进行情感分析。torchtext 中的旧 API 和新 API 都可以预处理文本输入并准备数据,以便通过以下步骤训练/验证模型:
Train/validate/test split:生成训练/验证/测试数据集(如果有的话
Tokenization:将生硬的文本字符串拆分成单词表
Vocab:定义从tokens到indexs(索引)的 "合约"
Numericalize:将一列tokens转换为相应的索引
Batch:生成成批数据样本,并在必要时添加填充
需要注意的是,所有传统功能仍然可用,只是在 torchtext.legacy 中,而不是 torchtext 中。
Step 1: Create a dataset object
首先,我们创建一个用于情感分析的数据集。一个数据样本包含一个标签和一个文本字符串。
Legacy
在legacy代码中,Field类用于数据处理,包括标记器和编号。要查看数据集,用户需要首先设置TEXT/LABEL字段。
import torchtext
import torch
from torchtext.legacy import data
from torchtext.legacy import datasets
TEXT = data.Field()
LABEL = data.LabelField(dtype = torch.long)
legacy_train, legacy_test = datasets.IMDB.splits(TEXT, LABEL) # datasets here refers to torchtext.legacy.datasets
您可以通过查看 Dataset.examples 打印原始数据。整个文本数据以一列标记的形式存储。
legacy_examples = legacy_train.examples
print(legacy_examples[0].text, legacy_examples[0].label)
New API
新的数据集API直接返回训练/测试数据集,而不需要预处理信息。每个split都是一个迭代器,逐行生成原始文本和标签。
from torchtext.datasets import IMDB
train_iter, test_iter = IMDB(split=('train', 'test'))
要打印原始数据,可以在 IterableDataset 上调用 next() 函数。
Step 2 Build the data processing pipeline
Legacy
Field类默认使用的tokenizer是内置的 python split() 函数。用户通过调用 data.get_tokenizer() 选择标记符,并将其添加到字段构造函数中。对于序列模型,通常会附加 <BOS>(句首)和 <EOS>(句末)标记,因此需要在 Field 类中定义特殊标记。
TEXT = data.Field(tokenize=data.get_tokenizer('basic_english'),
init_token='<SOS>', eos_token='<EOS>', lower=True)
LABEL = data.LabelField(dtype = torch.long)
legacy_train, legacy_test = datasets.IMDB.splits(TEXT, LABEL) # datasets here refers to torchtext.legacy.datasets
现在,您可以从存储在预定义字段对象 TEXT 中的文本文件中创建词汇表。您必须将数据集传递给 build_vocab 函数,从而在 Field 对象中创建词汇表。字段对象会根据特定的数据分割建立词汇表(TEXT.vocab)。
TEXT.build_vocab(legacy_train)
LABEL.build_vocab(legacy_train)
使用词汇表对象可以做的事情
词汇表的总长度
String2Index (stoi)和 Index2String (itos)
包含出现 N 次以上的单词的特定用途词汇表
legacy_vocab = TEXT.vocab
print("The length of the legacy vocab is", len(legacy_vocab))
legacy_stoi = legacy_vocab.stoi
print("The index of 'example' is", legacy_stoi['example'])
legacy_itos = legacy_vocab.itos
print("The token at index 686 is", legacy_itos[686])
# Set up the mim_freq value in the Vocab class
TEXT.build_vocab(legacy_train, min_freq=10)
legacy_vocab2 = TEXT.vocab
print("The length of the legacy vocab is", len(legacy_vocab2))
New API
用户可以通过 data.get_tokenizer() 函数直接访问不同类型的 tokenizers。
from torchtext.data.utils import get_tokenizer
tokenizer = get_tokenizer('basic_english')
为了获得更大的灵活性,用户可以直接使用 Vocab 类建立词汇表。例如,参数 min_freq 用于设置词汇的截止频率。像 <BOS> 和 <EOS> 这样的特殊字符可以分配给 Vocab 类构造函数中的特殊符号。
from collections import Counter
from torchtext.vocab import Vocab
train_iter = IMDB(split='train')
counter = Counter()
for (label, line) in train_iter:
counter.update(tokenizer(line))
vocab = Vocab(counter, min_freq=10, specials=('<unk>', '<BOS>', '<EOS>', '<PAD>'))
print("The length of the new vocab is", len(vocab))
new_stoi = vocab.stoi
print("The index of '<BOS>' is", new_stoi['<BOS>'])
new_itos = vocab.itos
print("The token at index 2 is", new_itos[2])
text_transform和label_transform都是可调用对象,例如这里的lambda函数,用于处理来自数据集迭代器的原始文本和标签数据。用户可以在text_transform中的句子中添加特殊符号<BOS>和<EOS>。
text_transform = lambda x: [vocab['<BOS>']] + [vocab[token] for token in tokenizer(x)] + [vocab['<EOS>']]
label_transform = lambda x: 1 if x == 'pos' else 0
# Print out the output of text_transform
print("input to the text_transform:", "here is an example")
print("output of the text_transform:", text_transform("here is an example"))
Step 3: Generate batch iterator
为了高效地训练模型,建议建立一个迭代器来批量生成数据。
Legacy
传统的 Iterator 类用于批处理数据集,并发送到目标设备(如 CPU 或 GPU)。
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
legacy_train, legacy_test = datasets.IMDB.splits(TEXT, LABEL) # datasets here refers to torchtext.legacy.datasets
legacy_train_iterator, legacy_test_iterator = data.Iterator.splits(
(legacy_train, legacy_test), batch_size=8, device = device)
在 NLP 工作流程中,定义一个迭代器并将长度相近的文本放在一起也很常见。torchtext 库中的传统 BucketIterator 类最大程度地减少了所需的填充量。
from torchtext.legacy.data import BucketIterator
legacy_train, legacy_test = datasets.IMDB.splits(TEXT, LABEL)
legacy_train_bucketiterator, legacy_test_bucketiterator = data.BucketIterator.splits(
(legacy_train, legacy_test),
sort_key=lambda x: len(x.text),
batch_size=8, device = device)
New API
torch.utils.data.DataLoader 用于生成批处理数据。用户可以通过在 DataLoader 中定义一个带有 collate_fn 参数的函数来定制数据批处理。这里,在 collate_batch func 中,我们处理原始文本数据并添加填充,以动态匹配批次中最长的句子。
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
def collate_batch(batch):
label_list, text_list = [], []
for (_label, _text) in batch:
label_list.append(label_transform(_label))
processed_text = torch.tensor(text_transform(_text))
text_list.append(processed_text)
return torch.tensor(label_list), pad_sequence(text_list, padding_value=3.0)
train_iter = IMDB(split='train')
train_dataloader = DataLoader(list(train_iter), batch_size=8, shuffle=True,
collate_fn=collate_batch)
要将长度相近的文本分组,就像传统的 BucketIterator 类中介绍的那样,首先,我们随机创建多个 "池",每个池的大小为 batch_size * 100。然后,我们按长度对单个池中的样本进行排序。这个想法可以通过 PyTorch Dataloader 的 batch_sampler 参数简洁地实现。batch_sampler 接受 "Sampler "或 Iterable 对象,这些对象会产生下一批样本的索引。在下面的代码中,我们实现了一个生成器,它能生成相应批次数据长度相似的批次索引。
import random
train_iter = IMDB(split='train')
train_list = list(train_iter)
batch_size = 8 # A batch size of 8
def batch_sampler():
indices = [(i, len(tokenizer(s[1]))) for i, s in enumerate(train_list)]
random.shuffle(indices)
pooled_indices = []
# create pool of indices with similar lengths
for i in range(0, len(indices), batch_size * 100):
pooled_indices.extend(sorted(indices[i:i + batch_size * 100], key=lambda x: x[1]))
pooled_indices = [x[0] for x in pooled_indices]
# yield indices for current batch
for i in range(0, len(pooled_indices), batch_size):
yield pooled_indices[i:i + batch_size]
bucket_dataloader = DataLoader(train_list, batch_sampler=batch_sampler(),
collate_fn=collate_batch)
print(next(iter(bucket_dataloader)))
Step 4: Iterate batch to train a model
在训练和验证模型的过程中,迭代批数据在传统的和新的应用程序接口几乎是一样的。
Legacy
传统的批处理迭代器可以使用 next() 方法进行迭代或执行。
next(iter(legacy_train_iterator))
New API
批处理迭代器可以使用 next() 方法迭代或执行。
# for idx, (label, text) in enumerate(train_dataloader):
# model(item)
# Or
next(iter(train_dataloader))