torchtext 的详细用法请参考上一期:Torchtext 0.12+新版API学习与使用示例(1)
构造 embedding 的思路也很简单:
- 把语料训练成 torchtext 对应的 vocab
- 然后对于输入的句子,进行如下转换:文本 -> vocab id -> embedding(这里借助 nn.Embedding 进行转换)
构造 DataLoader 之前要先构造 DataSet。DataSet 与 DataLoader 的基础内容请参考:使用Pytorch DataLoader快捷封装训练数据、测试数据的X与Y
示例代码
from torchtext.vocab import vocab
from collections import Counter, OrderedDict
from torch.utils.data import Dataset, DataLoader
from torchtext.transforms import VocabTransform
import torch.nn as nn
import torch
class TextDataSet(Dataset):
    """A minimal Dataset built on the new (0.12+) torchtext API.

    Each item is the ``(word_embedding, sentence)`` pair for one sentence:
    the raw text is split on spaces, mapped to vocab ids via
    ``VocabTransform``, then looked up in an ``nn.Embedding`` table.
    """

    def __init__(self, text_list, word_hidden=6):
        """
        :param text_list: all corpus sentences (list of space-separated strings)
        :param word_hidden: dimensionality of each word vector
        """
        # Flatten the per-sentence token lists ([[xx, xx], [xx, xx], ...])
        # into one flat list ([xx, xx, xx, ...]).
        total_word_list = []
        for sentence in text_list:
            total_word_list += sentence.split(" ")
        counter = Counter(total_word_list)  # word frequencies
        # vocab() wants an ordered mapping: [(word, count), ...] by frequency.
        sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        ordered_dict = OrderedDict(sorted_by_freq_tuples)
        # Build the vocab; `specials` holds special tokens and may be empty.
        my_vocab = vocab(ordered_dict, specials=["<UNK>", "<SEP>"])
        vocab_transform = VocabTransform(my_vocab)
        # Assemble the Dataset state.
        self.text_list = text_list          # raw sentences
        self.vocab = my_vocab
        self.vocab_transform = vocab_transform
        self._len = len(text_list)          # number of sentences
        # NOTE(review): the embedding table lives inside the Dataset, so it is
        # not registered with any model and will not be trained — fine for a demo.
        self.embedding = nn.Embedding(len(my_vocab), word_hidden)

    def __getitem__(self, id_index):
        """Return ``(word_embedding, sentence)`` for the sentence at `id_index`."""
        sentence = self.text_list[id_index]
        word_ids = self.vocab_transform(sentence.split(" "))
        # torch.tensor(..., dtype=torch.long) instead of the original
        # torch.Tensor(word_ids).long(): torch.Tensor first allocates a float
        # tensor and then casts, which is wasteful and error-prone.
        word_embedding = self.embedding(torch.tensor(word_ids, dtype=torch.long))
        return word_embedding, sentence

    def __len__(self):
        """Number of sentences in the corpus."""
        return self._len
def main():
    """Demo: build a TextDataSet from a toy corpus and iterate it via a DataLoader."""
    # Pretend this tiny list is the entire training corpus.
    corpus = [
        "nlp is natural language processing strives",
        "nlp build machines that understand",
        "nlp model respond to text or voice data and respond with text",
    ]
    dataset = TextDataSet(corpus)                 # wrap corpus in a DataSet
    loader = DataLoader(dataset, batch_size=1)    # then wrap the DataSet in a DataLoader
    for word_embedding, sentence in loader:
        print("====================================")
        print("原句是:", sentence)
        print("对应的Embedding:", word_embedding)


if __name__ == '__main__':
    main()