# Preprocessing for Multi30k with spaCy 3.3.0 and torchtext 0.13.1
import torch
import spacy
from torchtext.datasets import Multi30k
from collections import Counter
from torchtext.vocab import vocab
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
# Tokenization functions
def tokenize_en(text):
    """Tokenize English text with spaCy, lowercase each token, and reverse
    the token order (source-reversal trick from early seq2seq work)."""
    tokens = [tok.text.lower() for tok in spacy_en.tokenizer(text)]
    tokens.reverse()
    return tokens
def tokenize_de(text):
    """Tokenize German text with spaCy and lowercase each token."""
    spacy_tokens = spacy_de.tokenizer(text)
    return [token.text.lower() for token in spacy_tokens]
# Load the spaCy language models used by the tokenizers above.
# Requires: python -m spacy download en_core_web_sm de_core_news_sm
spacy_en=spacy.load('en_core_web_sm')
spacy_de=spacy.load('de_core_news_sm')
# Build token-frequency counters over the training split, from which the
# vocabularies are constructed below.
counter_en = Counter()
counter_de = Counter()
train_iter, test_iter = Multi30k(split=('train', 'test'))
# Multi30k yields (German, English) sentence pairs with the default
# language_pair=('de', 'en'). The original code drove the iterator manually
# with next()/StopIteration; a plain for loop does the same thing. The
# torchtext 0.13 datapipe is re-iterable, so train_iter can be consumed
# again by the DataLoader at the bottom of the file.
for de_text, en_text in train_iter:
    counter_en.update(tokenize_en(en_text))
    counter_de.update(tokenize_de(de_text))
# Build the two vocabularies; tokens must occur at least twice to be kept.
# NOTE(review): the names look swapped — `vocab_de` is built from the
# *English* counter and `vocab_en` from the *German* one. Downstream code
# (en_transfor / de_transfor / collate_batch) indexes them consistently
# with this swap, so renaming would have to be done across the whole file
# at once; left as-is here.
_SPECIALS = ('<unk>', '<BOS>', '<EOS>', '<PAD>')
vocab_de = vocab(counter_en, min_freq=2, specials=_SPECIALS)
vocab_en = vocab(counter_de, min_freq=2, specials=_SPECIALS)
# Map out-of-vocabulary tokens to '<unk>' instead of raising.
vocab_de.set_default_index(vocab_de['<unk>'])
vocab_en.set_default_index(vocab_en['<unk>'])
# text -> index sequence, wrapped with <BOS>/<EOS> markers.
# (vocab_de actually holds the English vocabulary and vocab_en the German
# one — the names are swapped consistently throughout this file.)
def en_transfor(text):
    """Convert an English sentence to a list of vocab indices with
    <BOS>/<EOS> boundary markers."""
    return [vocab_de['<BOS>']] + [vocab_de[token] for token in tokenize_en(text)] + [vocab_de['<EOS>']]

def de_transfor(text):
    """Convert a German sentence to a list of vocab indices with
    <BOS>/<EOS> boundary markers.

    Bug fix: the original called tokenize_en here, running German text
    through the English tokenizer (which also reverses the tokens), while
    counter_de / vocab_en were built with tokenize_de — so lookups would
    frequently miss the vocabulary and fall back to <unk>.
    """
    return [vocab_en['<BOS>']] + [vocab_en[token] for token in tokenize_de(text)] + [vocab_en['<EOS>']]
# Batch post-processing for the DataLoader.
def collate_batch(batch):
    """Collate raw (German-src, English-trg) text pairs into two padded
    index tensors.

    Returns (src, trg), each of shape (max_len, batch_size) as produced
    by pad_sequence with its default batch_first=False.
    """
    # src indices come from vocab_en (see de_transfor) and trg indices
    # from vocab_de (see en_transfor), so pad each side with the matching
    # vocabulary's <PAD>. The original swapped the two lookups — harmless
    # only because both vocabs place <PAD> at the same special index.
    src_pad = vocab_en['<PAD>']
    trg_pad = vocab_de['<PAD>']
    src_list, trg_list = [], []
    for src, trg in batch:
        src_list.append(torch.tensor(de_transfor(src)))
        trg_list.append(torch.tensor(en_transfor(trg)))
    # Pad every sequence in the batch to the common max length.
    src_batch = pad_sequence(src_list, padding_value=src_pad)
    trg_batch = pad_sequence(trg_list, padding_value=trg_pad)
    return src_batch, trg_batch
if __name__ == '__main__':
    # Smoke test: wrap the training pipe in a DataLoader and show one batch.
    loader = DataLoader(train_iter, batch_size=8, collate_fn=collate_batch)
    first_batch = next(iter(loader))
    print(first_batch)
# Reference: https://github.com/pytorch/text/blob/master/examples/legacy_tutorial/migration_tutorial.ipynb