参考https://zhuanlan.zhihu.com/p/39918126
1、首先,准备一个待处理的文本数据集(此处为 JSON 文件),并指定数据集文件和词汇表文件的路径。
# Locations of the caption dataset (input) and the pickled vocabulary (output).
# The directory is assembled with os.path.join so the separator matches the
# host OS; a hard-coded backslash only resolves to the data folder on Windows.
json_dir = os.path.join(os.curdir, 'data')
json_file = 'IUdata_trainval.json'
vocab_dir = os.path.join(os.curdir, 'data')
vocab_file = 'IUdata_vocab.pkl'
json_path = os.path.join(json_dir, json_file)
vocab_path = os.path.join(vocab_dir, vocab_file)
2、将文本数据集中的低频词替换为未知词标记<unk>。我们可以自己指定一个频率阈值:出现次数没有超过该阈值的词被视为低频词,统一设置为<unk>。
# Frequency cutoff passed to vocab_build: a word is kept only when its count
# is strictly greater than this value, so 1 discards words seen exactly once.
threshold = 1
# Forwarded to JsonReader; when true, indexing the reader returns the value
# stored under each key of the loaded JSON object.
generate_pkl = True
3、添加特殊词(<pad>、<start>、<end>、<unk>),统计词频并构建词汇表
# Count word frequencies in the dataset, build the vocabulary and pickle it.
vocab_build(
    json_file=json_path,
    threshold=threshold,
    vocab_path=vocab_path,
    generate_pkl=generate_pkl,
)
import os
import pickle
import json
#import jsonlines
from collections import Counter
class JsonReader(object):
    """Index-addressable view over a JSON object loaded from a file.

    The file is parsed once in the constructor and must contain a JSON
    object (a dict after loading). Entries are addressed by their position
    in the dict's key order. No ``__iter__`` is defined, so ``for x in
    reader`` goes through ``__getitem__`` with 0, 1, 2, ... and stops at
    the first ``IndexError``.

    Args:
        json_file: Path of the JSON file to load.
        generate_pkl: When true, ``reader[i]`` returns the value stored
            under the i-th key; otherwise it returns the i-th
            ``(key, value)`` pair.
    """

    def __init__(self, json_file, generate_pkl=False):
        self.data = self.__read_json(json_file)
        self.gen_pkl = generate_pkl
        # Build both lookup tables unconditionally. They used to be created
        # only when generate_pkl was true, although `items` is read only when
        # it is false, so the default mode died with AttributeError.
        self.keys = list(self.data.keys())
        self.items = list(self.data.items())

    def __read_json(self, filename):
        """Parse ``filename`` as JSON and return the decoded object."""
        # JSON text is UTF-8; do not rely on the platform's default encoding.
        with open(filename, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return data

    def __getitem__(self, item):
        """Return entry ``item``; raises IndexError when out of range."""
        if self.gen_pkl:
            # Value only, looked up through the positional key list.
            return self.data[self.keys[item]]
        # (key, value) pair.
        return self.items[item]

    def __len__(self):
        """Return the number of top-level entries in the loaded object."""
        return len(self.data)
class Vocabulary(object):
    """Two-way mapping between words and consecutive integer ids.

    Ids are assigned in insertion order starting at 0. Four reserved tokens
    are registered by the constructor, so they always hold ids 0 to 3.
    """

    def __init__(self):
        self.word2idx = {}
        self.id2word = {}
        self.idx = 0  # next id to hand out
        # Reserved tokens: padding, sentence start, sentence end, unknown word.
        for token in ('<pad>', '<start>', '<end>', '<unk>'):
            self.add_word(token)

    def add_word(self, word):
        """Register ``word`` under the next free id; known words are skipped."""
        if word in self.word2idx:
            return
        self.word2idx[word] = self.idx
        self.id2word[self.idx] = word
        self.idx += 1

    def get_word_by_id(self, id):
        """Return the word stored under ``id``; raises KeyError if unknown."""
        return self.id2word[id]

    def __call__(self, word):
        """Return the id of ``word``, or the id of ``'<unk>'`` if it is absent."""
        key = word if word in self.word2idx else '<unk>'
        return self.word2idx[key]

    def __len__(self):
        """Return the number of distinct words, reserved tokens included."""
        return len(self.word2idx)
def vocab_build(json_file, threshold, vocab_path, generate_pkl=False):
    """Build a word vocabulary from a JSON caption file and pickle it.

    Every entry yielded by ``JsonReader`` is iterated and each element is
    treated as a text string. If the dataset is not stored as JSON, or the
    JSON is laid out differently, the reading code must be adapted.

    Args:
        json_file: Path of the JSON dataset.
        threshold: A word is kept only if its count is strictly greater
            than this value.
        vocab_path: Destination path of the pickled ``Vocabulary``.
        generate_pkl: Forwarded unchanged to ``JsonReader``.
    """
    reader = JsonReader(json_file, generate_pkl)
    frequencies = Counter()
    for entry in reader:
        for caption in entry:
            # Tokens are split on single spaces, so pad '.' and ',' with
            # spaces to keep the punctuation as tokens of their own.
            spaced = caption.replace('.', ' . ').replace(',', ' , ')
            # English corpus: lower-case everything before counting.
            frequencies.update(spaced.lower().split(' '))

    # The reserved tokens are registered by the Vocabulary constructor.
    vocab = Vocabulary()
    for token, count in frequencies.items():
        # Drop words that do not occur often enough.
        if not (count > threshold):
            continue
        # Consecutive spaces leave empty strings behind after the split.
        if token == '':
            continue
        vocab.add_word(token)

    with open(vocab_path, 'wb') as handle:
        pickle.dump(vocab, handle)
    print("Finish!")
4、将词汇表保存为 .pkl 文件