train

# -- encoding:utf-8 --

import os
import tensorflow as tf

from nets.w2vnet import CBOWNetwork, SkipGramNetwork
from utils.data_utils import DataManager

# parameters
# =====================================
# 模型训练数据参数
tf.flags.DEFINE_string("data_path", "./data/train.cbow.data", "训练数据所在的磁盘路径!!")
tf.flags.DEFINE_string("dictionary_path", "./data/dictionary.json", "词典数据所在的磁盘路径!!")

# =====================================
# 网络结构的参数
tf.flags.DEFINE_string("network_name", "w2v", "网络结构名称!!")
tf.flags.DEFINE_integer("embedding_size", 128, "Embedding的维度大小!!")

# =====================================
# Word2Vec的参数
tf.flags.DEFINE_string("structure", "cbow", "Word2Vec的结构!!")
tf.flags.DEFINE_integer("window", 4, "窗口大小!!")
tf.flags.DEFINE_boolean("cbow_mean", True, "CBOW结构中,合并上下文数据的时候,是否计算均值!!")

# =====================================
# 训练参数
tf.flags.DEFINE_integer("max_epoch", 10, "最大迭代的Epoch的次数!!")
tf.flags.DEFINE_integer("batch_size", 1000, "批次大小!!")
tf.flags.DEFINE_integer("num_sampled", 100, "负采样的类别数目!!")
tf.flags.DEFINE_string("optimizer_name", "adam", "优化器名称!!")
tf.flags.DEFINE_float("learning_rate", 0.001, "学习率!!")
tf.flags.DEFINE_float("regularization", 0.00001, "L2 Loss惩罚项系数!!")

# ====================================
# 模型持久化参数
tf.flags.DEFINE_string("checkpoint_dir", "./running/model", "模型持久化文件路径!!")
tf.flags.DEFINE_integer("checkpoint_per_batch", 100, "给定模型持久化的间隔批次大小!!")

# ====================================
# 模型可视化参数
tf.flags.DEFINE_string("summary_dir", "./running/graph", "模型可视化数据存储路径!!")

FLAGS = tf.flags.FLAGS


def main(_):
    # 0. 模型参数校验
    if not os.path.exists(FLAGS.data_path):
        raise Exception("数据文件夹不存在,请检查参数!!!")
    if not os.path.exists(FLAGS.dictionary_path):
        raise Exception("词典数据文件夹不存在,请检查参数!!!")
    assert FLAGS.structure in ['cbow', 'skipgram'], "仅支持cbow和skipgram这两个Word2Vec结构,请检查参数!!"
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.summary_dir):
        os.makedirs(FLAGS.summary_dir)

    # 一、训练数据加载
    tf.logging.info("开始训练数据加载....")
    train_data_manager = DataManager(
        data_path=FLAGS.data_path,  # 数据所在的磁盘路径
        dictionary_path=FLAGS.dictionary_path,  # 词典所在的磁盘路径
        window=FLAGS.window,  # 窗口大小
        structure=FLAGS.structure,  # Word2Vec的结构
        batch_size=FLAGS.batch_size,  # 批次大小
        encoding='utf-8-sig',
        shuffle=True
    )

    # 二、网络的构建&训练
    with tf.Graph().as_default():
        # 一、网络结构的构建
        if FLAGS.structure == 'cbow':
            tf.logging.info("开始构建CBOW的模型结构.....")
            model = CBOWNetwork(
                name=FLAGS.network_name,  # 网络名称
                num_sampled=FLAGS.num_sampled,  # 负采样的类别数目
                window=FLAGS.window,  # 窗口大小
                vocab_size=train_data_manager.word_size,  # 词汇数目
                embedding_size=FLAGS.embedding_size,  # embedding的维度大小
                is_mean=FLAGS.cbow_mean,  # cbow结构中,对于上下文单词的特征信息,如何合并(True:均值;False:和)
                regularization=FLAGS.regularization,  # 惩罚性系数
                optimizer_name=FLAGS.optimizer_name,  # 优化器
                learning_rate=FLAGS.learning_rate,  # 学习率
                checkpoint_dir=FLAGS.checkpoint_dir  # 模型持久化路径
            )
        else:
            tf.logging.info("开始构建SkipGram的模型结构.....")
            model = SkipGramNetwork(
                name=FLAGS.network_name,  # 网络名称
                num_sampled=FLAGS.num_sampled,  # 负采样的类别数目
                window=FLAGS.window,  # 窗口大小
                vocab_size=train_data_manager.word_size,  # 词汇数目
                embedding_size=FLAGS.embedding_size,  # embedding的维度大小
                regularization=FLAGS.regularization,  # 惩罚性系数
                optimizer_name=FLAGS.optimizer_name,  # 优化器
                learning_rate=FLAGS.learning_rate,  # 学习率
                checkpoint_dir=FLAGS.checkpoint_dir  # 模型持久化路径
            )
        # 1.1. 前向网络的构建
        tf.logging.info("开始构建前向网络结构....")
        model.interface()
        # 1.2 损失函数获取
        tf.logging.info("开始构建损失函数....")
        loss = model.losses()
        # 1.3 基于损失函数构建优化器以及训练对象
        tf.logging.info("开始构建优化器以及训练对象....")
        _, train_op = model.optimizer(loss=loss)
        # 1.4 构建可视化的相关信息
        summary_op = tf.summary.merge_all()
        writer = tf.summary.FileWriter(logdir=FLAGS.summary_dir, graph=tf.get_default_graph())

        # 二、模型训练
        with tf.Session() as sess:
            # 2.1. 模型参数初始化
            tf.logging.info("开始进行模型参数初始化....")
            model.restore(session=sess)

            # 2.3 迭代训练
            for epoch in range(FLAGS.max_epoch):
                for batch_x, batch_y in train_data_manager:
                    _, _loss, _step, _summary = sess.run([train_op, loss, model.global_step, summary_op], feed_dict={
                        model.input_x: batch_x,
                        model.target: batch_y
                    })
                    print("Epoch:{}, Step:{}, Loss:{:g}".format(epoch + 1, _step, _loss))
                    writer.add_summary(summary=_summary, global_step=_step)

                    # 进行判断&模型持久化
                    if _step % FLAGS.checkpoint_per_batch == 0:
                        model.save(session=sess)

            # 所有数据执行完
            model.save(session=sess)
            writer.close()


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)  # 日志级别定义
    tf.app.run()  # 运行

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值