本文参照原文 https://zhuanlan.zhihu.com/p/28643244 操作后发现一些问题,在此对完整操作过程做记录。
实验环境: Windows 10 pro (64位)
所用代码为 Python 3。
用到的python包有:gensim, jieba
1.下载中文维基百科数据与预处理
首先下载数据。下载地址:https://dumps.wikimedia.org/zhwiki/latest/zhwiki-latest-pages-articles.xml.bz2
首先将xml的wiki数据转换为text格式,通过下面这个脚本(process_wiki.py)实现:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Pan Yang (panyangnlp@gmail.com)
# Copyrigh 2017
from __future__ import print_function
import logging
import os.path
import six
import sys
import io
from gensim.corpora import WikiCorpus
sys.stdout = io.TextIOWrapper(sys.stdout.buffer,encoding='utf8') #改变标准输出的默认编码
if __name__ == '__main__':
program = os.path.basename(sys.argv[0])
logger = logging.getLogger(program)
logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
logging.root.setLevel(level=logging.INFO)
logger.info("running %s" % ' '.join(sys.argv))
# check and process input arguments
if len(sys.argv) != 3:
print("Using: python process_wiki.py enwiki.xxx.xml.bz2 wiki.en.text")
sys.exit(1)
inp, outp = sys.argv[1:3]
space = " "
i = 0
output = open(outp, 'w', encoding="utf-8")
wiki = WikiCorpus(inp, lemmatize=False, dictionary={})
for text in wiki.get_texts():
if six.PY3:
output.write(bytes(' '.join(text), 'utf-8').decode('utf-8') + '\n')
# ###another method###
# output.write(
# space.join(map(lambda x:x.decode("utf-8"), text)) + '\n')
else:
output.write(space.join(text) + "\n")
i = i + 1
if (i % 10000 == 0):
logger.info("Saved " + str(i) + " articles")
output.close()
logger.info("Finished Saved " + str(i) + " articles")
执行: python process_wiki.py zhwiki-latest-pages-articles.xml.bz2 wiki.zh.text
安装opencc(参照:https://www.twblogs.net/a/5b8bfc4d2b717718832f9739/zh-cn),然后将wiki.zh.text中的繁体字转化为简体字:
opencc -i wiki.zh.text -o wiki.zh.text.sim -c D:\opencc-1.0.4\share\opencc\t2s.json #这里-c后面写t2s.json的路径
用jieba分词
运行fenci.py得到分词后的结果文本。
##!/usr/bin/env python
## coding=utf-8
"""Segment the simplified-Chinese wiki text with jieba.

Reads wiki.zh.text.sim line by line, cuts each line into words, and writes
the space-joined tokens (UTF-8) to wiki.zh.text.jian.seg.utf-8.
"""
import jieba
import sys
import importlib

filePath = 'wiki.zh.text.sim'
fileSegWordDonePath = 'wiki.zh.text.jian.seg.utf-8'

# read the file by line
fileTrainRead = []
#fileTestRead = []
with open(filePath, encoding="utf-8") as fileTrainRaw:
    for line in fileTrainRaw:
        fileTrainRead.append(line)

# define this function to print a list with Chinese
def PrintListChinese(items):
    # Renamed parameter: the original name `list` shadowed the builtin.
    for item in items:
        print(item)

# segment word with jieba
fileTrainSeg = []
for i, line in enumerate(fileTrainRead):
    # Removed importlib.reload(sys) from this loop: it was a Python 2
    # sys.setdefaultencoding hack and is a slow no-op under Python 3.
    # NOTE(review): line[9:-11] drops a fixed 9-char prefix and 11-char
    # suffix (looks like leftover <doc ...> markup stripping from another
    # tutorial) — confirm it matches the actual opencc output, otherwise
    # real text is silently discarded.
    fileTrainSeg.append([' '.join(jieba.cut(line[9:-11], cut_all=False))])
    if i % 100 == 0:
        print(i)
        PrintListChinese(fileTrainSeg[i])

# to test the segment result
PrintListChinese(fileTrainSeg[10])

# save the result
with open(fileSegWordDonePath, 'wb') as fW:
    for seg in fileTrainSeg:
        fW.write(seg[0].encode('utf-8'))
        fW.write('\n'.encode('utf-8'))
2.使用gensim训练Word2Vec
执行:python train_word2vec_model.py wiki.zh.text.jian.seg.utf-8 wiki.zh.text.model wiki.zh.text.vector
脚本train_word2vec_model.py如下:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Pan Yang (panyangnlp@gmail.com)
# Copyright 2017
from __future__ import print_function
import logging
import os
import sys
import multiprocessing
from gensim.models import Word2Vec
from gensim.models.word2vec import LineSentence
if __name__ == '__main__':
program = os.path.basename(sys.argv[0])
logger = logging.getLogger(program)
logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
logging.root.setLevel(level=logging.INFO)
logger.info("running %s" % ' '.join(sys.argv))
# check and process input arguments
if len(sys.argv) < 4:
print("Useing: python train_word2vec_model.py input_text "
"output_gensim_model output_word_vector")
sys.exit(1)
inp, outp1, outp2 = sys.argv[1:4]
model = Word2Vec(LineSentence(inp), size=200, window=5, min_count=5,
workers=multiprocessing.cpu_count())
model.save(outp1)
model.wv.save_word2vec_format(outp2, binary=False)
训练好的中文维基百科 word2vec 模型 “wiki.zh.text.vector” 的效果:
In [1]: import gensim
In [2]: model = gensim.models.Word2Vec.load("wiki.zh.text.model")
In [3]: model.most_similar(u"足球")
Out[3]:
[(u'\u8054\u8d5b', 0.6553816199302673),
(u'\u7532\u7ea7', 0.6530429720878601),
(u'\u7bee\u7403', 0.5967546701431274),
(u'\u4ff1\u4e50\u90e8', 0.5872289538383484),
(u'\u4e59\u7ea7', 0.5840631723403931),
(u'\u8db3\u7403\u961f', 0.5560152530670166),
(u'\u4e9a\u8db3\u8054', 0.5308005809783936),
(u'allsvenskan', 0.5249762535095215),
(u'\u4ee3\u8868\u961f', 0.5214947462081909),
(u'\u7532\u7ec4', 0.5177896022796631)]
In [4]: result = model.most_similar(u"足球")
In [5]: for e in result:
   ....:     print(e[0], e[1])
....:
联赛 0.65538161993
甲级 0.653042972088
篮球 0.596754670143
俱乐部 0.587228953838
乙级 0.58406317234
足球队 0.556015253067
亚足联 0.530800580978
allsvenskan 0.52497625351
代表队 0.521494746208
甲组 0.51778960228
训练得到的 gensim Word2Vec 模型可用作预训练(pre-trained)的 Embedding 层,例如:
'''set network structure'''
# Load the pre-trained wiki word2vec model; used only as an embedding lookup below.
embedding = gensim.models.Word2Vec.load("wiki.zh.text.model")
encoder_decoder = Sequential()
# NOTE(review): AttentionSeq2Seq comes from the third-party seq2seq package,
# presumably imported elsewhere; the input feature size is hidden_dim, so the
# word2vec vector dimension must equal hidden_dim — confirm with the caller.
encoder_decoder.add(AttentionSeq2Seq(output_dim=output_dim, hidden_dim=hidden_dim, output_length=tar_maxlen, input_shape=(input_maxlen, hidden_dim)))
encoder_decoder.add(Activation('softmax'))
encoder_decoder.compile(loss='categorical_crossentropy', optimizer="RMSprop")
'''train model'''
# Split the training data into train_batch chunks of train_list_length samples each.
train_batch = 100
train_list_length=int(len(tar_list)/train_batch)
for epoch in range(0,800):
    # First train_batch-1 equal-sized chunks.
    for i in range(0,train_batch-1):
        # One-hot targets and embedded inputs for chunk i.
        tars_train=one_hot(tar_list[train_list_length*i:train_list_length*(i+1)], word_to_idx, idx_to_word, tar_maxlen, vocab_size)
        word = get_Embedding(embedding,inputs_train[train_list_length*i:train_list_length*(i+1)],input_maxlen)
        encoder_decoder.fit(word, tars_train, batch_size=10, epochs=1, verbose=1)
    # Last chunk takes everything remaining, so samples left over when
    # len(tar_list) is not divisible by train_batch are not dropped.
    tars_train=one_hot(tar_list[train_list_length*(train_batch-1):], word_to_idx, idx_to_word, tar_maxlen, vocab_size)
    word = get_Embedding(embedding,inputs_train[train_list_length*(train_batch-1):],input_maxlen)
    encoder_decoder.fit(word, tars_train, batch_size=10, epochs=1, verbose=1)
    print(epoch)
print('Train complete...')
用于得到样本Embedding vector的函数如下:
def get_Embedding(embedding, inputs_train, input_maxlen):
    """Look up the embedding vector for every word of every sentence.

    Parameters
    ----------
    embedding : a word -> vector mapping (e.g. a loaded gensim Word2Vec model,
        or any dict-like object supporting ``embedding[word]``).
    inputs_train : iterable of sentences, each an iterable of word tokens.
    input_maxlen : unused here; kept for backward compatibility with callers.

    Returns
    -------
    numpy.ndarray of shape (n_sentences, sentence_len, vector_dim). Note the
    array is ragged (dtype=object) unless all sentences have the same length.
    """
    embedding_matrix = []
    for sentence in inputs_train:
        embedding_sentence = []
        for word in sentence:
            try:
                embedding_sentence.append(embedding[word])
            # Narrowed from a bare `except Exception`, which also hid real
            # errors (typos, wrong types). A missing word raises KeyError for
            # both dicts and gensim models.
            except KeyError:
                # Out-of-vocabulary fallback.
                # NOTE(review): "unkown" is (mis)spelled as in the training
                # data; it must itself be in the vocabulary or this raises.
                embedding_sentence.append(embedding["unkown"])
        embedding_matrix.append(embedding_sentence)
    return np.array(embedding_matrix)