训练中途周末被人强制断电,仅保存了训练到一半的模型,用来做一波测试,效果好于预期,当然,现在还是个智障。预测代码如下,注释已经很详尽了,做个备注,用于学习。
# coding: utf-8
"""
@ File: Seq2Seq_inference.py
@ Brief: 翻译模型预测程序,使用训练完成的 checkpoint 数据
"""
import tensorflow as tf
import codecs
import sys
# 设置参数
# 读取 checkpoint 的路径,9000 表示是训练程序在第 9000 步保存的 checkpoint
CHECKPOINT_PATH = "./seq2seq_ckpt-6200"
# 模型参数,必须与训练时的模型参数保持一致
HIDDEN_SIZE = 1024 # LSTM 的隐藏层规模
NUM_LAYERS = 2 # 深层循环神经网络中 LSTM 结构的层数
SRC_VOCAB_SIZE = 10000 # 源语言词汇表大小
TRG_VOCAB_SIZE = 4000 # 目标语言词汇表大小
SHARE_EMB_AND_SOFTMAX = True # 在 Softmax 层和词向量层之间共享参数
# 词汇表文件
SRC_VOCAB = './en.vocab'
TRG_VOCAB = './zh.vocab'
# 词汇表中 <SOS> 和 <eos> 的 ID。在解码过程中需要用 <SOS> 作为第一步的输入,
# 并将检查是否是 <eos>,因此需要知道这两个符号的 ID
SOS_ID = 1
EOS_ID = 2
# 定义 NMTModel 类来描述模型
class NMTModel(object):
# 在模型的初始化函数中定义模型要用到的变量
def __init__(self):
# 定义编码器和解码器所使用的 LSTM 结构
self.enc_cell = tf.nn.rnn_cell.MultiRNNCell(
[tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
for _ in range(NUM_LAYERS)]
)
self.dec_cell = tf.nn.rnn_cell.MultiRNNCell(
[tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
for _ in range(NUM_LAYERS)]
)
# 为源语言和目标语言分别定义词向量
self.src_embedding = tf.get_variable(
"src_emb", [SRC_VOCAB_SIZE
最低0.47元/天 解锁文章
979

被折叠的 条评论
为什么被折叠?



