import json
import time
from collections import OrderedDict
from gensim.models import KeyedVectors
from annoy import AnnoyIndex
# --- Build phase: read the 8.8M-word Tencent embedding txt and produce two
# artifacts: a word -> row-index JSON and an Annoy ANN index. ---
t1 = time.time()
print('start load tc')
tc_wv_model = KeyedVectors.load_word2vec_format(
    '../data/Tencent_AILab_ChineseEmbedding/Tencent_AILab_ChineseEmbedding.txt', binary=False)
print('load cost:{:.3f}'.format(time.time() - t1))

# Map each word to its row index, preserving the order of the source file.
t2 = time.time()
word_index = OrderedDict(
    (word, idx) for idx, word in enumerate(tc_wv_model.key_to_index.keys()))
print('OrderedDict cost:{:.3f}'.format(time.time() - t2))

# Persist the word -> index map. ensure_ascii=False keeps the Chinese words
# human-readable (and the file smaller), so write explicitly as UTF-8 to
# match the UTF-8 read below.
t3 = time.time()
with open('tc_word_index.json', 'w', encoding='utf-8') as fp:
    json.dump(word_index, fp, ensure_ascii=False)
print('save tc_word_index cost:{:.3f}'.format(time.time() - t3))

# The Tencent vectors are 200-dimensional; build an angular-metric index.
t4 = time.time()
tc_index = AnnoyIndex(200, metric='angular')
for i, key in enumerate(tc_wv_model.key_to_index.keys()):
    tc_index.add_item(i, tc_wv_model[key])
tc_index.build(10)
print('build tc_index tree cost:{:.3f}'.format(time.time() - t4))

# Save the Annoy index to disk.
t5 = time.time()
tc_index.save('tc_index_build10.index')
print('save tc_index tree cost:{:.3f}'.format(time.time() - t5))

# --- Query phase: reload both artifacts and run a sample nearest-neighbour
# lookup to verify them. ---
t5 = time.time()
with open('tc_word_index.json', encoding='utf-8') as fp:
    word_index = json.load(fp)
print('load tc_word_index.json cost:{:.3f}'.format(time.time() - t5))

t6 = time.time()
tc_index = AnnoyIndex(200, metric='angular')
tc_index.load('tc_index_build10.index')
print('load tc_index_build10.index cost:{:.3f}'.format(time.time() - t6))

# Reverse lookup table: row index -> word.
t7 = time.time()
reverse_word_index = {idx: word for word, idx in word_index.items()}
print('get reverse_word_index cost:{:.3f}'.format(time.time() - t7))

# get_nns_by_item returns a list of the indices of the 10 vectors nearest
# to the given word's vector.
t8 = time.time()
for item in tc_index.get_nns_by_item(word_index[u'优惠'], 10):
    print(reverse_word_index[item])  # map each index back to its word
print('get_nns_by_item cost:{:.3f}'.format(time.time() - t8))
注:该代码段参考某大佬的博客,但是链接暂时找不到了,后面找到了会补充下,不好意思了。
使用相关依赖:
gensim==4.1.2
annoy==1.17.0
tqdm==4.62.3
上述代码,会把882万词向量,从txt文件中读取出来,处理后生成两个结果文件:
tc_word_index.json tc_index_build10.index
import json
import logging
import time

from annoy import AnnoyIndex

# The original `import logger` / `import AnnoyIndex` could not resolve:
# AnnoyIndex must come from the annoy package, and `logger` must be created.
logger = logging.getLogger(__name__)
class TencentAIChiEmbedding(object):
    """Synonym lookup over the pre-built Tencent Chinese embedding artifacts.

    Loads the word -> row-index JSON and the Annoy index produced by the
    build script, and answers nearest-neighbour ("similar word") queries.
    Each loader logs and swallows failures, leaving the attribute as None.
    """

    def __init__(self, word_index_path, tc_index_path):
        self._word_index = self.load_word_index(word_index_path)
        self._tc_index = self.load_tc_index(tc_index_path)
        self._reverse_word_index = self.gen_reverse_word_index()

    def load_word_index(self, word_index_path):
        """Load the word -> row-index mapping from a UTF-8 JSON file.

        Returns the dict, or None if loading fails (the error is logged).
        """
        word_index = None
        try:
            st = time.time()
            with open(word_index_path, encoding='utf-8') as fp:
                word_index = json.load(fp)
            logger.info('load {} cost:{:.3f}'.format(word_index_path, time.time() - st))
        except Exception as e:
            logger.error('load_word_index error:')
            logger.exception(e)
        return word_index

    def load_tc_index(self, tc_index_path):
        """Load the Annoy index from disk.

        The index must match build-time parameters: 200 dimensions,
        angular metric. Returns the AnnoyIndex, or None on failure.
        """
        tc_index = None
        try:
            st = time.time()
            tc_index = AnnoyIndex(200, metric='angular')
            tc_index.load(tc_index_path)  # 'tc_index_build10.index'
            logger.info('load {} cost:{:.3f}'.format(tc_index_path, time.time() - st))
        except Exception as e:
            logger.error('load_tc_index error:')
            logger.exception(e)
        return tc_index

    def gen_reverse_word_index(self):
        """Build the reverse (row index -> word) lookup table.

        Returns the dict, or None on failure (e.g. the word index was
        never loaded, making self._word_index None).
        """
        reverse_word_index = None
        try:
            st = time.time()
            # Dict comprehension instead of dict([(v, k) for ...]) — no
            # intermediate list is materialized.
            reverse_word_index = {idx: word for word, idx in self._word_index.items()}
            logger.info('get reverse_word_index cost:{:.3f}'.format(time.time() - st))
        except Exception as e:
            logger.error('gen_reverse_word_index error:')
            logger.exception(e)
        return reverse_word_index

    def get_simi_words(self, keyword, topn=10):
        """Return the `topn` words nearest to `keyword` in the Annoy index.

        Returns an empty list when `keyword` is out of vocabulary or any
        lookup step fails (the error is logged).
        """
        simi_words = []
        try:
            st = time.time()
            for item in self._tc_index.get_nns_by_item(self._word_index[keyword], topn):
                simi_words.append(self._reverse_word_index[item])  # index -> word
            logger.info('get_nns_by_item cost:{:.3f}. keyword: {}; simi_words: {}'.format(time.time() - st, keyword, simi_words))
        except Exception as e:
            logger.error('get_simi_words error:')
            logger.exception(e)
        return simi_words
if __name__ == '__main__':
    # Smoke test: load both artifacts, then query one sample keyword.
    logger.info('Initializing tencent word vec...')
    st = time.time()
    embedding = TencentAIChiEmbedding(
        '../../data/Tencent_AILab_ChineseEmbedding/tc_word_index.json',
        '../../data/Tencent_AILab_ChineseEmbedding/tc_index_build10.index',
    )
    logger.info('Initialize tencent word vec done. cost: {:.3f}'.format(time.time() - st))
    logger.info(embedding.get_simi_words('免单'))
上面的 TencentAIChiEmbedding 类是在前面生成的 2 个结果文件的基础上构建的一个生成同义词的工具类,稍微修改下文件目录,应该可以直接使用。