1. Helper Classes
1.1 Overview
The Helpers currently provided by Seq2Seq are mainly TrainingHelper and GreedyEmbeddingHelper.
- TrainingHelper: as the name suggests, used during the training phase; it implements teacher forcing.
- GreedyEmbeddingHelper: used during the inference (test) phase.
1.2 Signatures
TrainingHelper
tf.contrib.seq2seq.TrainingHelper(inputs,          # (batch_size, seq_len, dim)
                                  sequence_length, # (batch_size, )
                                  time_major=False,
                                  name=None)
GreedyEmbeddingHelper
tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding, # used for embedding_lookup, matches the vocab
                                         start_tokens,
                                         end_token)
Parameter meanings:
- embedding: used for embedding_lookup; corresponds to the vocab.
- start_tokens: [batch_size, ]
- end_token: int32 scalar. Note that this is a scalar, not a vector.
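Conceptually, GreedyEmbeddingHelper drives a loop like the following. This is a minimal NumPy sketch, not the contrib implementation: `step_fn`, the identity embeddings, and the toy weight matrix are made-up stand-ins for the wrapped cell plus output layer.

```python
import numpy as np

def greedy_decode(step_fn, embedding, start_token, end_token, max_steps=10):
    """Mimic GreedyEmbeddingHelper: embed the previous prediction,
    run one cell step, and feed the argmax back in as the next input."""
    token = start_token
    outputs = []
    for _ in range(max_steps):
        logits = step_fn(embedding[token])  # one decoder step + output layer
        token = int(np.argmax(logits))      # greedy choice at this step
        if token == end_token:              # stop once end_token is emitted
            break
        outputs.append(token)
    return outputs

# Toy setup: identity embeddings and a "cell" that always prefers token t+1.
embedding = np.eye(4)
weights = np.roll(np.eye(4), 1, axis=1)  # row t has its 1 at column (t+1) % 4
decoded = greedy_decode(lambda x: x @ weights, embedding,
                        start_token=0, end_token=3)
```

Starting from token 0 the loop emits 1, then 2, then hits the end token 3 and stops.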
1.3 Basic Usage
# [go] tokens
tokens_go = tf.ones([config.batch_size], dtype=tf.int32) * w2i_target["_GO"]
# Embedding matrix
decoder_embedding = tf.Variable(tf.random_uniform([config.target_vocab_size, config.embedding_dim]), dtype=tf.float32, name='decoder_embedding')
# Build the decoder cell
decoder_cell = tf.nn.rnn_cell.GRUCell(config.hidden_dim)
# Create the helper
if useTeacherForcing:
    # Teacher forcing for the training phase: shift targets right and prepend _GO.
    decoder_inputs = tf.concat([tf.reshape(tokens_go, [-1, 1]), self.seq_targets[:, :-1]], 1)
    helper = tf.contrib.seq2seq.TrainingHelper(tf.nn.embedding_lookup(decoder_embedding, decoder_inputs), self.seq_targets_length)
else:
    # Inference phase
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(decoder_embedding, tokens_go, w2i_target["_EOS"])
# Wrap the cell with BasicDecoder.
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, encoder_state, output_layer=tf.layers.Dense(config.target_vocab_size))
# Run dynamic_decode to get the results.
decoder_outputs, decoder_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=tf.reduce_max(self.seq_targets_length))
2. Using Attention
2.1 Basic Usage
encoder_state = ....  # the final state of the encoder.
decoder_cell = tf.nn.rnn_cell.GRUCell(config.hidden_dim)
if useAttention:
    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=config.hidden_dim, memory=encoder_outputs, memory_sequence_length=self.seq_inputs_length)
    decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism)
    # Initialize the state: zero_state() first, then clone() with the encoder state.
    decoder_initial_state = decoder_cell.zero_state(batch_size=config.batch_size, dtype=tf.float32)
    decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)
else:
    decoder_initial_state = encoder_state
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, decoder_initial_state, output_layer=tf.layers.Dense(config.target_vocab_size))
decoder_outputs, decoder_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=tf.reduce_max(self.seq_targets_length))
2.2 Implementing Attention
def attn(self, hidden, encoder_outputs):
    # hidden: B * D
    # encoder_outputs: B * S * D
    attn_weights = tf.matmul(encoder_outputs, tf.expand_dims(hidden, 2))
    # attn_weights: B * S * 1
    attn_weights = tf.nn.softmax(attn_weights, axis=1)
    # squeeze only the last axis, so a batch of size 1 is not squeezed away
    context = tf.squeeze(tf.matmul(tf.transpose(encoder_outputs, [0, 2, 1]), attn_weights), axis=2)
    # context: B * D
    return context
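The same dot-product attention can be written in NumPy to sanity-check the shapes. This is an illustrative sketch with random toy data, not part of the model code:

```python
import numpy as np

def attn_np(hidden, encoder_outputs):
    """NumPy version of the dot-product attention above, for shape checking."""
    # hidden: (B, D), encoder_outputs: (B, S, D)
    scores = encoder_outputs @ hidden[:, :, None]           # (B, S, 1)
    scores = scores - scores.max(axis=1, keepdims=True)     # stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)  # (B, S, 1)
    # weighted sum of encoder outputs -> context vector
    context = (encoder_outputs.transpose(0, 2, 1) @ weights)[:, :, 0]  # (B, D)
    return weights[:, :, 0], context

B, S, D = 2, 4, 3
rng = np.random.default_rng(0)
weights, context = attn_np(rng.normal(size=(B, D)), rng.normal(size=(B, S, D)))
```

The attention weights for each example sum to 1 over the source positions, and the context has the same dimensionality as the hidden state.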
3. Beam Search
3.1 Overview
Beam search is only used at test (inference) time.
With beam search, at each step the decoder keeps the top-K partial sequences. Each of the K candidates is fed in as the input for the next step, producing K*L scored continuations (L is the vocabulary size), from which the top K are kept as the next step's beams.
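One step of this expand-and-prune loop can be sketched in plain Python. The scores and token ids below are made up for illustration; beam scores are additive log-probabilities:

```python
import heapq
import math

def beam_step(beams, step_log_probs, beam_width):
    """Expand every beam over the whole vocab, then keep the top-K.

    beams: list of (score, token_list)
    step_log_probs: step_log_probs[i][v] = log P(v | beams[i])
    """
    candidates = []
    for i, (score, tokens) in enumerate(beams):
        for v, lp in enumerate(step_log_probs[i]):
            candidates.append((score + lp, tokens + [v]))  # K * L candidates
    return heapq.nlargest(beam_width, candidates)          # prune back to K

beams = [(math.log(0.6), [1]), (math.log(0.4), [2])]
step_log_probs = [
    [math.log(0.1), math.log(0.2), math.log(0.7)],  # continuations of beam 0
    [math.log(0.5), math.log(0.3), math.log(0.2)],  # continuations of beam 1
]
new_beams = beam_step(beams, step_log_probs, beam_width=2)
```

Here the surviving beams are [1, 2] (probability 0.6 * 0.7 = 0.42) and [2, 0] (0.4 * 0.5 = 0.2); note that both can descend from different parents.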
3.2 Signature
tf.contrib.seq2seq.BeamSearchDecoder(
    cell, embedding, start_tokens, end_token,
    initial_state, beam_width, output_layer=None,
    length_penalty_weight=0.0, coverage_penalty_weight=0.0, reorder_tensor_arrays=True)
Parameter meanings:
- cell: an RNNCell instance.
- embedding: used for embedding_lookup.
- start_tokens: [batch_size, ]
- end_token: int32 scalar.
- length_penalty_weight: weight of the length penalty term; 0.0 disables it.
- coverage_penalty_weight: penalizes repetition.
- reorder_tensor_arrays: if True, the cell state is reordered according to the beam search results.
When AttentionWrapper is used, setting coverage_penalty_weight is recommended: it encourages the decoder to cover all of the inputs.
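As a rough sketch of how the length penalty reshapes beam scores: the contrib implementation follows the GNMT-style formula score = log_prob / ((5 + len) / 6) ** weight (treat this formula and the numbers below as illustrative, not as the exact library code):

```python
def length_penalized_score(log_prob, length, penalty_weight):
    """GNMT-style length penalty: score = log_prob / ((5 + len) / 6) ** w."""
    if penalty_weight == 0.0:
        return log_prob  # penalty disabled, raw log-probability wins
    return log_prob / (((5.0 + length) / 6.0) ** penalty_weight)

# A longer hypothesis with a worse raw log-prob can win once the
# penalty normalizes scores by length.
short = length_penalized_score(-4.0, length=4, penalty_weight=1.0)
long_ = length_penalized_score(-5.0, length=10, penalty_weight=1.0)
```

With weight 1.0 the 10-token hypothesis scores -2.0 against the 4-token hypothesis's -2.67, so length normalization flips the ranking.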
3.3 Basic Usage
The following code uses AttentionWrapper together with beam search.
with tf.variable_scope('decoder'):
    beam_width = 5
    memory = encoder_outputs
    if mode != 'infer':
        bs = batch_size
    else:
        # For beam search, tile memory, encoder state and lengths beam_width times.
        memory = tf.contrib.seq2seq.tile_batch(memory, beam_width)
        encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, beam_width)
        bs = batch_size * beam_width
        input_seq_lens = tf.contrib.seq2seq.tile_batch(input_seq_lens, beam_width)
    # Attention mechanism
    attention = tf.contrib.seq2seq.LuongAttention(hidden_size, memory, input_seq_lens, scale=True)
    rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
    rnn_cell = tf.contrib.seq2seq.AttentionWrapper(rnn_cell, attention, hidden_size, name='attention')
    decoder_initial_state = rnn_cell.zero_state(bs, tf.float32).clone(cell_state=encoder_state)
    # Output layer
    with tf.variable_scope('projected'):
        output_layer = tf.layers.Dense(len(word2id_en), use_bias=False, kernel_initializer=k_initializer)
    if mode == 'infer':
        start = tf.ones([batch_size], dtype=tf.int32) * w2i_target["_GO"]  # note: batch_size here, not bs.
        decoder = tf.contrib.seq2seq.BeamSearchDecoder(rnn_cell, decoder_embedding, start, word2id['</s>'], decoder_initial_state, beam_width, output_layer)
        outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=2 * tf.reduce_max(input_seq_lens))
        sample_ids = outputs.predicted_ids  # beam predictions, shape [batch_size, max_time, beam_width]
    else:
        # TrainingHelper expects embedded inputs, not the embedding matrix itself;
        # decoder_inputs are the shifted target token ids, as in section 1.3.
        helper = tf.contrib.seq2seq.TrainingHelper(tf.nn.embedding_lookup(decoder_embedding, decoder_inputs), tgt_seq_lens)
        decoder = tf.contrib.seq2seq.BasicDecoder(rnn_cell, helper, decoder_initial_state, output_layer)
        outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=tf.reduce_max(tgt_seq_lens))
The following code uses beam search only:
tokens_go = tf.ones([config.batch_size], dtype=tf.int32) * w2i_target["_GO"]
decoder_cell = tf.nn.rnn_cell.GRUCell(config.hidden_dim)
if useBeamSearch > 1:
    decoder_initial_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=useBeamSearch)
    decoder = tf.contrib.seq2seq.BeamSearchDecoder(decoder_cell, decoder_embedding, tokens_go, w2i_target["_EOS"], decoder_initial_state, beam_width=useBeamSearch, output_layer=tf.layers.Dense(config.target_vocab_size))
else:
    decoder_initial_state = encoder_state
    decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, decoder_initial_state, output_layer=tf.layers.Dense(config.target_vocab_size))
decoder_outputs, decoder_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=tf.reduce_max(self.seq_targets_length))
3.4 Summary
When using beam search alone, decoder_initial_state must be tiled K (beam_width) times.
When combining it with AttentionWrapper, encoder_outputs and input_sequence_len must also be expanded with tile_batch.
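The effect of tile_batch is just per-example repetition along the batch axis. In NumPy terms (a toy tensor, assuming the contrib behavior of repeating each example beam_width times contiguously):

```python
import numpy as np

def tile_batch_np(t, multiplier):
    """NumPy equivalent of tf.contrib.seq2seq.tile_batch:
    repeat each batch entry `multiplier` times along axis 0."""
    return np.repeat(t, multiplier, axis=0)

encoder_state = np.array([[1.0, 2.0],   # example 0
                          [3.0, 4.0]])  # example 1
tiled = tile_batch_np(encoder_state, 3)  # shape (6, 2): each row repeated 3x
```

Each of the beam_width beams of an example then starts from an identical copy of that example's encoder state.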
4. Decode
4.1 Basic Usage
dynamic_decode()
returns:
(final_outputs, final_state, final_sequence_lengths)
final_outputs has two fields:
final_outputs.sample_id: the predicted token ids (the argmax of rnn_output), used as the model's predictions.
final_outputs.rnn_output: the output-layer logits (before softmax), used to compute the loss, e.g. with tf.contrib.seq2seq.sequence_loss.
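To make the logits-vs-probabilities point concrete, a masked sequence cross-entropy over rnn_output-style logits can be sketched in NumPy. The toy logits and lengths are made up; on the graph side, tf.contrib.seq2seq.sequence_loss plays this role:

```python
import numpy as np

def masked_sequence_loss(logits, targets, lengths):
    """Mean cross-entropy over the valid (unpadded) time steps.

    logits:  (B, T, V) unnormalized scores, like final_outputs.rnn_output
    targets: (B, T)    gold token ids
    lengths: (B,)      valid length of each sequence
    """
    B, T, V = logits.shape
    # log-softmax over the vocab axis (numerically stable)
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # negative log-likelihood of each gold token
    nll = -log_probs[np.arange(B)[:, None], np.arange(T)[None, :], targets]
    # mask out padded positions beyond each sequence's length
    mask = (np.arange(T)[None, :] < lengths[:, None]).astype(float)
    return (nll * mask).sum() / mask.sum()

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, 5))
targets = rng.integers(0, 5, size=(2, 3))
loss = masked_sequence_loss(logits, targets, np.array([3, 2]))
```

Passing already-softmaxed probabilities into a loss that applies its own softmax is a common bug; the loss expects raw logits, exactly as rnn_output provides them.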