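I searched online for a long time without finding a working example of incremental training with Word2Vec. After some experimenting I got it working, and I am sharing the code here for reference.
Environment:
OS: Windows 10
Python: 3.5.4
Core code: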
# coding: utf-8
# Author: chen chong
# DateTime: 2018/11/11 15:11
import time
import os
import multiprocessing
from gensim.models import Word2Vec
from CCApp.Augur.db_service.db_service import DBService
from CCApp.Augur.cut_words_service.cut_words_service import CutWordsService
import warnings
warnings.filterwarnings(action='ignore', category=UserWarning, module='gensim')


class ModelTrainService:
    def __init__(self):
        os_name = os.name
        if 'nt' != os_name:
            prefix_path = '/home/chong/Tmp/'
        else:
            prefix_path = 'D:/workspace/gensim_folder/'
        self.model_file_path = os.path.join(prefix_path, 'augur.model')
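
    def _train(self, sentences):
        """
        Initial and incremental training.
        :param sentences: list of tokenized sentences (a list of lists), e.g. [['小程序', '程序员'], ['代码']]
        :return:
        """
        # Note: size= and model.iter are the gensim 3.x names; gensim 4 renames them to vector_size= and model.epochs.
        if os.path.isfile(self.model_file_path):  # model file exists: incremental training
            self.model = Word2Vec.load(self.model_file_path)
            # Add the new words to the existing vocabulary, then continue training on the new sentences.
            self.model.build_vocab(sentences, update=True)
            self.model.train(sentences, total_examples=len(sentences), epochs=self.model.iter)
        else:  # no model file yet: first training run
            # Passing sentences to the constructor already builds the vocabulary and trains,
            # so no separate train() call is needed on this branch.
            self.model = Word2Vec(sentences, size=100, window=10, min_count=1, workers=multiprocessing.cpu_count())
        self.model.save(self.model_file_path)  # overwrite the previous model file once training finishes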

    def work(self):
        db_service = DBService()
        while True:
            try:
                ids, title_paragraphs = db_service.get_un_train_news()
                if not ids:
                    # Nothing new to train on: wait 30 minutes, then poll again.
                    for i in range(30):
                        time.sleep(60)
                    continue
                sentences = list()
                for title, content in title_paragraphs:
                    if title:
                        title_words = CutWordsService.cut(title)
                        if title_words:
                            sentences.append(title_words)
                    if content:
                        content_words = CutWordsService.cut(content)
                        if content_words:
                            sentences.append(content_words)
                if sentences:
                    self._train(sentences)
                db_service.delete_has_train(ids)
                print('>>>>> delete doc type by ids:', ids)
            except Exception as e:
                import datetime
                import logging
                from logging.handlers import SMTPHandler
                logger = logging.getLogger()
                # Send a failure notice by email. The password is redacted here; supply your own SMTP credentials.
                smtp_handler = SMTPHandler('smtp.global-mail.cn', 'cc.chen@maxpr.com.cn', 'cc.chen@maxpr.com.cn', 'Machine learning model training failure notice', ('cc.chen@maxpr.com.cn', '<password>'))
                logger.addHandler(smtp_handler)
                # %m is the month; %M (minutes) in the date part was a bug.
                now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                logger.error('datetime:{0} \n Exception:\n{1}'.format(now, e))
                break


if '__main__' == __name__:
    service = ModelTrainService()
    service.work()