The title is a bit long, but the requirement itself is clear: some of the simplest PyTorch language-model examples do not ship with a way to limit the vocabulary size, yet this is a very common need, so here is a minimal way to add it.
The example given here can be run directly:
https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/02-intermediate/language_model
As you can see, however, its original data_utils.py does not provide a way to limit the vocabulary size. Assuming we want to cap the vocabulary at 5000 words, the following code adds that on top of the original:
import torch
import os


class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __len__(self):
        return len(self.word2idx)


class Corpus(object):
    def __init__(self):
        self.dictionary = Dictionary()

    def get_data(self, path, batch_size=20, max_vocab_size=5000):
        # First pass: count the frequency of every word in the corpus.
        raw_vocab = {}
        special_words = ['<unk>']
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    raw_vocab[word] = raw_vocab.get(word, 0) + 1

        # Sort words by descending frequency; prepend <unk> only when the
        # corpus does not already contain it.
        if '<unk>' in raw_vocab:
            vocab = sorted(raw_vocab, key=lambda x: -raw_vocab[x])
        else:
            vocab = special_words + sorted(raw_vocab, key=lambda x: -raw_vocab[x])
        print('Original vocabulary size is %d' % len(vocab))

        # Keep only the max_vocab_size most frequent words; converting to a
        # set makes the per-token membership tests below O(1) instead of O(V).
        if len(vocab) > max_vocab_size:
            vocab = vocab[:max_vocab_size]
        vocab = set(vocab)

        # Second pass: add the kept words to the dictionary; every
        # out-of-vocabulary word is mapped to <unk>.
        with open(path, 'r') as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    if word in vocab:
                        self.dictionary.add_word(word)
                    else:
                        self.dictionary.add_word('<unk>')
        print('The generated vocabulary size is %d' % len(self.dictionary))

        # Third pass: tokenize the file content into a flat tensor of ids.
        ids = torch.LongTensor(tokens)
        token = 0
        with open(path, 'r') as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    if word in vocab:
                        ids[token] = self.dictionary.word2idx[word]
                    else:
                        ids[token] = self.dictionary.word2idx['<unk>']
                    token += 1

        # Trim the tail so the data splits evenly into batch_size rows.
        num_batches = ids.size(0) // batch_size
        ids = ids[:num_batches * batch_size]
        return ids.view(batch_size, -1)
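
For reference, a minimal usage sketch follows (the path 'data/train.txt' and the sizes below are illustrative; point them at the data file from the tutorial repo above):

corpus = Corpus()
# 'data/train.txt' is an assumed path; substitute the tutorial's data file.
ids = corpus.get_data('data/train.txt', batch_size=20, max_vocab_size=5000)
print(ids.size())              # torch.Size([20, tokens // 20])
print(len(corpus.dictionary))  # capped at roughly max_vocab_size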
The only thing to watch out for is that the original corpus may already contain <unk>; we add <unk> ourselves only when it is absent. Everything else is straightforward, and together with the code and data from the URL above you can test it directly.
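
If you want to see that <unk> handling in isolation, here is a tiny self-contained check (the toy word counts are made up):

special_words = ['<unk>']
for raw_vocab in ({'the': 10, 'cat': 2}, {'the': 10, '<unk>': 3}):
    vocab = sorted(raw_vocab, key=lambda x: -raw_vocab[x])
    if '<unk>' not in raw_vocab:
        vocab = special_words + vocab
    print(vocab)
# ['<unk>', 'the', 'cat']  <- prepended because the corpus lacked it
# ['the', '<unk>']         <- already present, so not duplicated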