TensorFlow Seq2Seq模型样例:中英文翻译

训练中途周末被人强制断电,仅保存了训练到一半的模型,用来做一波测试,效果好于预期,当然,现在还是个智障。预测代码如下,注释已经很详尽了,做个备注,用于学习。

# coding: utf-8
"""
@ File: Seq2Seq_inference.py
@ Brief: 翻译模型预测程序,使用训练完成的 checkpoint 数据
"""

import tensorflow as tf
import codecs
import sys

# 设置参数
# 读取 checkpoint 的路径,9000 表示是训练程序在第 9000 步保存的 checkpoint
CHECKPOINT_PATH = "./seq2seq_ckpt-6200"

# 模型参数,必须与训练时的模型参数保持一致
HIDDEN_SIZE = 1024              # LSTM 的隐藏层规模
NUM_LAYERS = 2                  # 深层循环神经网络中 LSTM 结构的层数
SRC_VOCAB_SIZE = 10000          # 源语言词汇表大小
TRG_VOCAB_SIZE = 4000           # 目标语言词汇表大小
SHARE_EMB_AND_SOFTMAX = True    # 在 Softmax 层和词向量层之间共享参数

# 词汇表文件
SRC_VOCAB = './en.vocab'
TRG_VOCAB = './zh.vocab'

# 词汇表中 <SOS> 和 <eos> 的 ID。在解码过程中需要用 <SOS> 作为第一步的输入,
# 并将检查是否是 <eos>,因此需要知道这两个符号的 ID
SOS_ID = 1
EOS_ID = 2



# 定义 NMTModel 类来描述模型
class NMTModel(object):
    # 在模型的初始化函数中定义模型要用到的变量
    def __init__(self):
        # 定义编码器和解码器所使用的 LSTM 结构
        self.enc_cell = tf.nn.rnn_cell.MultiRNNCell(
            [tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
            for _ in range(NUM_LAYERS)]
        )
        self.dec_cell = tf.nn.rnn_cell.MultiRNNCell(
            [tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
             for _ in range(NUM_LAYERS)]
        )

        # 为源语言和目标语言分别定义词向量
        self.src_embedding = tf.get_variable(
            "src_emb", [SRC_VOCAB_SIZE
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值