A Simple Way to Limit the Vocabulary Size (e.g., to 5000) in a PyTorch LSTM Language Model

The title is long, but the requirement is clear: some of the simplest PyTorch language-model projects do not provide a way to limit the vocabulary size, even though this is a very common need. Here is a summary of the simplest way to add it.

The example used here can be run directly:

https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/02-intermediate/language_model

However, its original data_utils.py file does not provide a way to limit the vocabulary size. Assuming we want to cap the vocabulary at 5000, the following code implements this on top of the original:

import torch


class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __len__(self):
        return len(self.word2idx)


class Corpus(object):
    def __init__(self):
        self.dictionary = Dictionary()

    def get_data(self, path, batch_size=20, max_vocab_size=5000):
        # First pass: count word frequencies.
        raw_vocab = {}
        special_words = ['<unk>']

        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    if word in raw_vocab:
                        raw_vocab[word] += 1
                    else:
                        raw_vocab[word] = 1

        # Sort by descending frequency; prepend '<unk>' only if the
        # corpus does not already contain it.
        if '<unk>' in raw_vocab:
            vocab = sorted(raw_vocab, key=lambda x: -raw_vocab[x])
        else:
            vocab = special_words + sorted(raw_vocab, key=lambda x: -raw_vocab[x])
        print('Original Vocabulary Size is %d' % len(vocab))
        if len(vocab) > max_vocab_size:
            vocab = vocab[:max_vocab_size]
        vocab = set(vocab)  # set membership is O(1); a list would be O(n) per lookup

        # Second pass: add the kept words to the dictionary; every
        # out-of-vocabulary word maps to '<unk>'.
        with open(path, 'r') as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    if word in vocab:
                        self.dictionary.add_word(word)
                    else:
                        self.dictionary.add_word('<unk>')

        print('The Generated Vocabulary Size is %d' % len(self.dictionary))

        # Third pass: tokenize the file content.
        ids = torch.LongTensor(tokens)
        token = 0
        with open(path, 'r') as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    if word in vocab:
                        ids[token] = self.dictionary.word2idx[word]
                    else:
                        ids[token] = self.dictionary.word2idx['<unk>']
                    token += 1

        # Trim to a multiple of batch_size and reshape into batch_size rows.
        num_batches = ids.size(0) // batch_size
        ids = ids[:num_batches * batch_size]
        return ids.view(batch_size, -1)
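The trim-and-reshape step at the end can be illustrated without torch. The sketch below uses a plain Python list in place of a `LongTensor`, and `batchify` is a hypothetical helper name, not part of the repository:

```python
# Illustration (with a plain list instead of a torch.LongTensor) of the
# final step: cut the token stream to a multiple of batch_size, then
# view it as batch_size rows of equal length.
def batchify(ids, batch_size):
    num_batches = len(ids) // batch_size      # keep only complete batches
    ids = ids[:num_batches * batch_size]      # drop the ragged tail
    row_len = num_batches                     # mirrors ids.view(batch_size, -1)
    return [ids[i * row_len:(i + 1) * row_len] for i in range(batch_size)]

# 10 tokens with batch_size 3: the last token is dropped, giving 3 rows of 3.
rows = batchify(list(range(10)), 3)
print(rows)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```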

The only subtlety is that the original corpus may already contain <unk>; only when it does not do we need to add <unk> ourselves. Everything else is straightforward: combined with the code and data from the repository above, it can be tested directly.
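The count-sort-truncate logic, including that <unk> check, can also be written more compactly with `collections.Counter`. The sketch below is an alternative, not the repository's code, and `build_vocab` is a hypothetical helper name:

```python
from collections import Counter

def build_vocab(tokens, max_vocab_size=5000):
    # Count frequencies over the whole token stream.
    counts = Counter(tokens)
    # Reserve a slot for '<unk>' only if the corpus lacks it, mirroring
    # the check in get_data: prepending '<unk>' leaves room for only
    # max_vocab_size - 1 real words.
    if '<unk>' in counts:
        vocab = [w for w, _ in counts.most_common(max_vocab_size)]
    else:
        vocab = ['<unk>'] + [w for w, _ in counts.most_common(max_vocab_size - 1)]
    return set(vocab)

tokens = ['the', 'cat', 'sat', 'the', 'mat', 'the', '<eos>']
print(build_vocab(tokens, max_vocab_size=3))
```

`Counter.most_common(n)` already returns the n highest-frequency items in descending order, so the explicit `sorted(..., key=lambda x: -raw_vocab[x])` and slice become one call.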
