TensorFlow Seq2Seq Usage

1. Helper Classes

1.1 Overview

  The Seq2Seq module currently provides two main Helpers: TrainingHelper and GreedyEmbeddingHelper.

  • TrainingHelper: as the name suggests, used during training; it implements teacher forcing.
  • GreedyEmbeddingHelper: used during inference.
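
Teacher forcing feeds the ground-truth token at each step, while greedy decoding feeds back the model's own previous prediction. The difference can be sketched in plain Python (no TensorFlow; `fake_decoder_step` is a made-up stand-in for one decoder step, not a real API):

```python
def fake_decoder_step(prev_token):
    # Hypothetical stand-in for one decoder step: always predicts prev_token + 1.
    return prev_token + 1

GO = 0
targets = [3, 5, 2]  # ground-truth target sequence

# Teacher forcing (TrainingHelper): feed the ground truth at every step,
# i.e. inputs are [GO] + targets shifted right by one.
teacher_inputs = [GO] + targets[:-1]
teacher_outputs = [fake_decoder_step(t) for t in teacher_inputs]

# Greedy decoding (GreedyEmbeddingHelper): feed the model's own last prediction.
greedy_inputs, prev = [], GO
for _ in range(len(targets)):
    greedy_inputs.append(prev)
    prev = fake_decoder_step(prev)
```

Note that the teacher-forced inputs depend only on the targets, while the greedy inputs depend on the model's predictions.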

1.2 Parameters

TrainingHelper

tf.contrib.seq2seq.TrainingHelper(inputs,           # (batch_size, seq_len, dim)
                                  sequence_length,  # (batch_size,)
                                  time_major=False,
                                  name=None)

GreedyEmbeddingHelper

tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding,     # used for embedding_lookup; matches the vocab
                                         start_tokens,
                                         end_token)

Parameter meanings:

  • embedding: the embedding matrix used for embedding_lookup; corresponds to the vocab.
  • start_tokens: int32 vector of shape [batch_size].
  • end_token: int32 scalar (note: a scalar, not a vector).

1.3 Basic Usage

# [GO] tokens
tokens_go = tf.ones([config.batch_size], dtype=tf.int32) * w2i_target["_GO"]
# Embedding matrix
decoder_embedding = tf.Variable(tf.random_uniform([config.target_vocab_size, config.embedding_dim]), dtype=tf.float32, name='decoder_embedding')

# Build the decoder cell
decoder_cell = tf.nn.rnn_cell.GRUCell(config.hidden_dim)

# Create the helper
if useTeacherForcing:
	# Training: teacher forcing, inputs are [GO] + targets shifted right by one
	decoder_inputs = tf.concat([tf.reshape(tokens_go, [-1, 1]), self.seq_targets[:, :-1]], 1)
	helper = tf.contrib.seq2seq.TrainingHelper(tf.nn.embedding_lookup(decoder_embedding, decoder_inputs), self.seq_targets_length)
else:
	# Inference
	helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(decoder_embedding, tokens_go, w2i_target["_EOS"])

# Wrap the cell with BasicDecoder
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, encoder_state, output_layer=tf.layers.Dense(config.target_vocab_size))
# Run decoding with dynamic_decode
decoder_outputs, decoder_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=tf.reduce_max(self.seq_targets_length))

2. Using Attention

2.1 Basic Usage

encoder_state = ....      # final state of the encoder
decoder_cell = tf.nn.rnn_cell.GRUCell(config.hidden_dim)

if useAttention:
	attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=config.hidden_dim, memory=encoder_outputs, memory_sequence_length=self.seq_inputs_length)
	decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism)
	# Initialize the state: zero_state() first, then clone()
	decoder_initial_state = decoder_cell.zero_state(batch_size=config.batch_size, dtype=tf.float32)
	decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)
else:
	# Without attention, the encoder state is used directly
	decoder_initial_state = encoder_state

decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, decoder_initial_state, output_layer=tf.layers.Dense(config.target_vocab_size))
decoder_outputs, decoder_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=tf.reduce_max(self.seq_targets_length))

2.2 Implementing Attention

def attn(self, hidden, encoder_outputs):
	# hidden: B * D
	# encoder_outputs: B * S * D
	attn_weights = tf.matmul(encoder_outputs, tf.expand_dims(hidden, 2))
	# attn_weights: B * S * 1
	attn_weights = tf.nn.softmax(attn_weights, axis=1)
	# specify axis=2 so a batch of size 1 is not squeezed away as well
	context = tf.squeeze(tf.matmul(tf.transpose(encoder_outputs, [0, 2, 1]), attn_weights), axis=2)
	# context: B * D
	return context
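
The same computation can be sanity-checked in NumPy (`attn_np` is my own name for this sketch, not part of the code above):

```python
import numpy as np

def attn_np(hidden, encoder_outputs):
    # hidden: (B, D); encoder_outputs: (B, S, D)
    scores = np.matmul(encoder_outputs, hidden[:, :, None])           # (B, S, 1)
    scores -= scores.max(axis=1, keepdims=True)                       # stable softmax
    weights = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
    context = np.matmul(encoder_outputs.transpose(0, 2, 1), weights)  # (B, D, 1)
    return context.squeeze(axis=2)                                    # (B, D)

B, S, D = 2, 4, 3
rng = np.random.default_rng(0)
context = attn_np(rng.standard_normal((B, D)), rng.standard_normal((B, S, D)))
```

The dot-product scores are softmax-normalized over the source axis S, and the context is the weighted sum of the encoder outputs, giving one (B, D) vector per batch entry.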

3. Beam Search

3.1 Overview

  Beam Search is used only at test time.
  With Beam Search, at each step the decoder keeps the top-K candidates as that step's output; each one is fed as input to the next step, producing K*L scored continuations (L being the vocabulary size), from which the top K are again selected as the next step's output.
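
The top-K-of-K*L selection can be sketched with a toy pure-Python beam search (the log-probability table is made up and fixed per step, purely for illustration):

```python
import heapq

log_probs = {0: -0.1, 1: -0.7, 2: -1.6}   # toy next-token log-probs, L = 3 tokens
K = 2                                      # beam width

beams = [([], 0.0)]                        # (token sequence, cumulative log-prob)
for _ in range(3):                         # decode 3 steps
    candidates = []
    for seq, score in beams:               # expand every beam: up to K * L candidates
        for tok, lp in log_probs.items():
            candidates.append((seq + [tok], score + lp))
    beams = heapq.nlargest(K, candidates, key=lambda c: c[1])  # keep top K
```

At every step the K surviving hypotheses are each extended by all L tokens, and only the K highest-scoring extensions survive, exactly the procedure described above.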

3.2 Parameters

tf.contrib.seq2seq.BeamSearchDecoder(
    cell, embedding, start_tokens, end_token,
    initial_state, beam_width, output_layer=None,
    length_penalty_weight=0.0, coverage_penalty_weight=0.0, reorder_tensor_arrays=True)

Parameter meanings:

  • cell: an RNNCell instance.
  • embedding: the embedding matrix used for embedding_lookup.
  • start_tokens: int32 vector of shape [batch_size].
  • end_token: int32 scalar.
  • length_penalty_weight: penalizes long sequences.
  • coverage_penalty_weight: penalizes repetition.
  • reorder_tensor_arrays: if True, the cell state is reordered according to the beam search.

  When using an AttentionWrapper, setting coverage_penalty_weight is recommended: it encourages the decoder to cover all of the inputs.

3.3 Basic Usage

  The following code uses both AttentionWrapper and BeamSearch.

with tf.variable_scope('decoder'):
	beam_width = 5
	memory = encoder_outputs

	if mode != 'infer':
		bs = batch_size
	else:
		# For beam search, tile memory, encoder_state and lengths beam_width times
		memory = tf.contrib.seq2seq.tile_batch(memory, beam_width)
		encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, beam_width)
		input_seq_lens = tf.contrib.seq2seq.tile_batch(input_seq_lens, beam_width)
		bs = batch_size * beam_width

	# Attention mechanism
	attention = tf.contrib.seq2seq.LuongAttention(hidden_size, memory, input_seq_lens, scale=True)
	rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
	rnn_cell = tf.contrib.seq2seq.AttentionWrapper(rnn_cell, attention, hidden_size, name='attention')
	decoder_initial_state = rnn_cell.zero_state(bs, tf.float32).clone(cell_state=encoder_state)

	# Output projection layer
	with tf.variable_scope('projected'):
		output_layer = tf.layers.Dense(len(word2id_en), use_bias=False, kernel_initializer=k_initializer)

	if mode == 'infer':
		start = tf.ones([batch_size], dtype=tf.int32) * w2i_target["_GO"]   # note: batch_size here, not bs
		decoder = tf.contrib.seq2seq.BeamSearchDecoder(rnn_cell, decoder_embedding, start, word2id['</s>'], decoder_initial_state, beam_width, output_layer)
		outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=2 * tf.reduce_max(input_seq_lens))
		sample_ids = outputs.predicted_ids  # shape: [batch_size, time, beam_width]
	else:
		# decoder_inputs: GO-prefixed target ids, built as in section 1.3
		helper = tf.contrib.seq2seq.TrainingHelper(tf.nn.embedding_lookup(decoder_embedding, decoder_inputs), tgt_seq_lens)
		decoder = tf.contrib.seq2seq.BasicDecoder(rnn_cell, helper, decoder_initial_state, output_layer)
		outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=tf.reduce_max(tgt_seq_lens))
		

  The following code uses only Beam Search:

tokens_go = tf.ones([config.batch_size], dtype=tf.int32) * w2i_target["_GO"]
decoder_cell = tf.nn.rnn_cell.GRUCell(config.hidden_dim)

if useBeamSearch > 1:
	decoder_initial_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=useBeamSearch)
	decoder = tf.contrib.seq2seq.BeamSearchDecoder(decoder_cell, decoder_embedding, tokens_go, w2i_target["_EOS"], decoder_initial_state, beam_width=useBeamSearch, output_layer=tf.layers.Dense(config.target_vocab_size))
else:
	decoder_initial_state = encoder_state
	decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, decoder_initial_state, output_layer=tf.layers.Dense(config.target_vocab_size))

decoder_outputs, decoder_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=tf.reduce_max(self.seq_targets_length))

3.4 Summary

If only Beam Search is used, decoder_initial_state must be tiled beam_width times with tile_batch.
When combined with AttentionWrapper, encoder_outputs and input_sequence_len must also be expanded with tile_batch.
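
What tile_batch does can be reproduced in NumPy: each batch entry is repeated beam_width times contiguously ([a, b] becomes [a, a, a, b, b, b] for beam_width=3), which is the layout BeamSearchDecoder expects:

```python
import numpy as np

beam_width = 3
encoder_state = np.array([[1.0, 2.0],     # batch entry a
                          [3.0, 4.0]])    # batch entry b
# NumPy equivalent of tf.contrib.seq2seq.tile_batch(encoder_state, beam_width)
tiled = np.repeat(encoder_state, beam_width, axis=0)  # (batch * beam_width, dim)
```

Note this contiguous repeat differs from a plain tile along axis 0, which would interleave the entries as [a, b, a, b, a, b].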

4. Decode

4.1 Basic Usage

dynamic_decode()

returns:

(final_outputs, final_state, final_sequence_lengths)

final_outputs has two parts:
final_outputs.sample_id: the predicted token ids, i.e. the model's predicted output.
final_outputs.rnn_output: the per-step logits (before softmax), used to compute the loss.
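
The usual loss on rnn_output is per-step softmax cross-entropy, masked so padding positions do not contribute. A NumPy sketch of that computation (random stand-in data; in TensorFlow this corresponds to sparse softmax cross-entropy plus tf.sequence_mask):

```python
import numpy as np

B, T, V = 2, 3, 4
rng = np.random.default_rng(0)
logits = rng.standard_normal((B, T, V))        # plays the role of rnn_output
targets = rng.integers(0, V, size=(B, T))      # gold token ids
lengths = np.array([3, 2])                     # true (unpadded) lengths

# log-softmax over the vocab axis
log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
# negative log-likelihood of each gold token: (B, T)
nll = -np.take_along_axis(log_probs, targets[..., None], axis=2).squeeze(2)
# mask out padding positions (same idea as tf.sequence_mask)
mask = (np.arange(T)[None, :] < lengths[:, None]).astype(float)
loss = (nll * mask).sum() / mask.sum()
```

Dividing by mask.sum() averages over real (non-padding) positions only, so short sequences are not penalized for their padding.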