序列到序列的RNN语言机器翻译模型的tensorflow代码解析。
0x00.前言
(原文发表在博客,欢迎访问
这份代码最开始在基于RNN的语言模型与机器翻译NMT看到。本着溯本求源的心态,我搜了一下代码,找到了Brok-Bucholtz/P4-Beta/language_translation/dlnd_language_translation.ipynb
。
(更完整版在Language Translation
因为工作中需要,所以要对其代码及神经网络结构有所了解。但是其中涉及了很多seq2seq函数,我没有接触过。所以这里对其进行一次代码分析。
Ps:这份代码版本在貌似在1.0,而tensorflow到目前已经到1.3了。最重要的是,其中很多函数已经发生了变化。但是,作为学习,我们依旧可以从中学到很多东西。最新代码可以参考从Encoder到Decoder实现Seq2Seq模型 - 知乎专栏
0x01.代码解析
1.数据预处理
定义了两个函数text_to_ids
解析文本,sentence_to_ids
将一句话中的词转换为对应的id。
2.输入
定义函数model_inputs
来创建tf.placeholder
。
3.编码
RNN的编码层:
def encoding_layer(rnn_inputs, rnn_size, num_layers, keep_prob):
"""
Create encoding layer
:param rnn_inputs: Inputs for the RNN
:param rnn_size: RNN Size
:param num_layers: Number of layers
:param keep_prob: Dropout keep probability
:return: RNN state
"""
# TODO: Implement Function
# Encoder
# 首先创建多层lstm
enc_cell = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(rnn_size)] * num_layers)
# 加dropout层
enc_cell = tf.contrib.rnn.DropoutWrapper(enc_cell, output_keep_prob=keep_prob)
# 动态rnn
rnn, state = tf.nn.dynamic_rnn(enc_cell, rnn_inputs, dtype=tf.float32)
return state
4.解码-训练层
下面介绍几个其中用到的函数:
tf.contrib.seq2seq.simple_decoder_fn_train(encoder_state, name=None)
在dynamic_rnn_decoder
中使用的序列到序列模型的简单解码器。simple_decoder_fn_train
是序列到序列模型的简单训练函数。当dynamic_rnn_decoder
处于训练模式时应该使用它。tf.contrib.seq2seq.dynamic_rnn_decoder(cell, decoder_fn, inputs=None, sequence_length=None, parallel_iterations=None, swap_memory=False, time_major=False, scope=None, name=None)
由RNNCell和解码器功能规定的序列到序列模型的动态RNN(drnn)解码器。有训练和推理两种模式。
def decoding_layer_train(encoder_state, dec_cell, dec_embed_input, sequence_length, decoding_scope,
output_fn, keep_prob):
"""
Create a decoding layer for training
:param encoder_state: Encoder State
:param dec_cell: Decoder RNN Cell
:param dec_embed_input: Decoder embedded input
:param sequence_length: Sequence Length
:param decoding_scope: TenorFlow Variable Scope for decoding
:param output_fn: Function to apply the output layer
:pa