Reposted from the blog: http://blog.csdn.net/thriving_fcl/article/details/74165062
The related code is below (I added some comments of my own):
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.contrib.seq2seq import *
from tensorflow.python.layers.core import Dense
class Seq2SeqModel(object):
    def __init__(self, rnn_size, layer_size, encoder_vocab_size,
                 decoder_vocab_size, embedding_dim, grad_clip, is_inference=False):
        # embedding_dim is the word-embedding dimension; encoder_vocab_size is the size of the
        # source vocabulary and decoder_vocab_size the size of the target vocabulary
        # define inputs
        self.input_x = tf.placeholder(tf.int32, shape=[None, None], name='input_ids')
        # define embedding layer
        with tf.variable_scope('embedding'):
            encoder_embedding = tf.Variable(tf.truncated_normal(shape=[encoder_vocab_size, embedding_dim], stddev=0.1),
                                            name='encoder_embedding')
            decoder_embedding = tf.Variable(tf.truncated_normal(shape=[decoder_vocab_size, embedding_dim], stddev=0.1),
                                            name='decoder_embedding')
        # define encoder
        with tf.variable_scope('encoder'):
            encoder = self._get_simple_lstm(rnn_size, layer_size)

        # look up the embeddings of the input ids
        with tf.device('/cpu:0'):
            input_x_embedded = tf.nn.embedding_lookup(encoder_embedding, self.input_x)

        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder, input_x_embedded, dtype=tf.float32)
        # define helper for decoder
        if is_inference:
            # inference phase
            self.start_tokens = tf.placeholder(tf.int32, shape=[None], name='start_tokens')
            self.end_token = tf.placeholder(tf.int32, name='end_token')
            # decoder_embedding is the embedding matrix of the target vocabulary
            # GreedyEmbeddingHelper(embedding, start_tokens, end_token)
            helper = GreedyEmbeddingHelper(decoder_embedding, self.start_tokens, self.end_token)
        else:
            # training phase
            self.target_ids = tf.placeholder(tf.int32, shape=[None, None], name='target_ids')
            self.decoder_seq_length = tf.placeholder(tf.int32, shape=[None], name='batch_seq_length')
            # look up the embeddings of the target ids
            with tf.device('/cpu:0'):
                target_embeddeds = tf.nn.embedding_lookup(decoder_embedding, self.target_ids)
            # TrainingHelper(inputs, sequence_length, time_major=False, name=None)
            # inputs: the embedded decoder inputs (embedded_input in the decoder diagram);
            #     with time_major=False, inputs has shape [batch_size, sequence_length, embedding_size];
            #     with time_major=True, it has shape [sequence_length, batch_size, embedding_size]
            # sequence_length: the length of each sequence in the current batch
            #     (self._batch_size = array_ops.size(sequence_length))
            helper = TrainingHelper(target_embeddeds, self.decoder_seq_length)
        with tf.variable_scope('decoder'):
            # Dense(units, activation=None, use_bias=True), where units is the dimension of the output space
            fc_layer = Dense(decoder_vocab_size)
            decoder_cell = self._get_simple_lstm(rnn_size, layer_size)
            # BasicDecoder(cell, helper, initial_state, output_layer=None)
            # cell: here a multi-layer LSTM instance, defined just like the encoder
            # helper: any Helper instance
            # initial_state: the encoder's final state
            # output_layer: the Dense_Layer in the decoder diagram
            decoder = BasicDecoder(decoder_cell, helper, encoder_state, fc_layer)

        # BasicDecoder wraps all the functionality a decoder needs; depending on the Helper
        # instance it behaves differently: during training the outputs are not fed back as
        # inputs, while during inference each output is fed back as the next input.
        # run dynamic decoding
        # tf.contrib.seq2seq.dynamic_decode(decoder, output_time_major=False), where decoder is a decoder instance
        # returns: (final_outputs, final_state, final_sequence_lengths)
        logits, final_state, final_sequence_lengths = dynamic_decode(decoder)
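        # Note: dynamic_decode also accepts a maximum_iterations argument
        # (dynamic_decode(decoder, maximum_iterations=...)), which is useful at
        # inference time to bound the output length if the end token is never produced.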
        if not is_inference:
            # training: build the loss and the parameter-update op
            # flatten the target ids to one dimension
            targets = tf.reshape(self.target_ids, [-1])
            logits_flat = tf.reshape(logits.rnn_output, [-1, decoder_vocab_size])
            print('shape logits_flat:{}'.format(logits_flat.shape))
            print('shape logits:{}'.format(logits.rnn_output.shape))
            self.cost = tf.losses.sparse_softmax_cross_entropy(targets, logits_flat)

            # define train op
            tvars = tf.trainable_variables()
            grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), grad_clip)
            optimizer = tf.train.AdamOptimizer(1e-3)
            self.train_op = optimizer.apply_gradients(zip(grads, tvars))
        else:
            # inference: return the output probabilities
            # dynamic_decode returns a BasicDecoderOutput, so softmax is applied to its rnn_output field
            self.prob = tf.nn.softmax(logits.rnn_output)
    # build a stacked (multi-layer) LSTM cell
    def _get_simple_lstm(self, rnn_size, layer_size):
        # rnn_size: the number of units in each LSTM cell
        # layer_size: the number of stacked LSTM layers
        lstm_layers = [tf.contrib.rnn.LSTMCell(rnn_size) for _ in range(layer_size)]
        return tf.contrib.rnn.MultiRNNCell(lstm_layers)
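
To make the class easier to try out, here is a minimal usage sketch. The hyperparameter values, the random id batches, and the <GO>/<EOS> token ids (1 and 2) are all invented for illustration; in practice you would restore trained weights with tf.train.Saver before decoding.

# Hypothetical usage sketch -- all hyperparameters and data below are made up.
import numpy as np

# Training graph: build the model and run one update step on a fake batch.
train_model = Seq2SeqModel(rnn_size=128, layer_size=2, encoder_vocab_size=5000,
                           decoder_vocab_size=5000, embedding_dim=100, grad_clip=5.0,
                           is_inference=False)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {
        train_model.input_x: np.random.randint(0, 5000, size=(2, 4)),     # 2 source sentences of length 4
        train_model.target_ids: np.random.randint(0, 5000, size=(2, 5)),  # 2 target sentences of length 5
        train_model.decoder_seq_length: [5, 5],
    }
    _, cost = sess.run([train_model.train_op, train_model.cost], feed_dict=feed)
    print('cost: {}'.format(cost))

# Inference graph: rebuild the model in a fresh graph so variable names do not clash.
tf.reset_default_graph()
infer_model = Seq2SeqModel(rnn_size=128, layer_size=2, encoder_vocab_size=5000,
                           decoder_vocab_size=5000, embedding_dim=100, grad_clip=5.0,
                           is_inference=True)
with tf.Session() as sess:
    # In practice restore the trained weights here instead of re-initializing;
    # with untrained weights greedy decoding may run for many steps before
    # emitting the end token (see the maximum_iterations note above).
    sess.run(tf.global_variables_initializer())
    prob = sess.run(infer_model.prob, feed_dict={
        infer_model.input_x: np.random.randint(0, 5000, size=(1, 4)),
        infer_model.start_tokens: [1],  # hypothetical <GO> id
        infer_model.end_token: 2,       # hypothetical <EOS> id
    })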