这篇文章就简单从源码的角度上分析一下tf.contrib.seq2seq下提供的API,首先来讲这个文件夹下面的几个文件和函数上篇文章中都已经提到而且介绍了他们之间的关系和如何使用,如果对源码不感兴趣就不用看下去了~~
BasicDecoder和dynamic_decode
为了简单起见,从decode的入口dynamic_deocde函数开始分析:
dynamic_decode(
decoder,
output_time_major=False,
impute_finished=False,
maximum_iterations=None,
parallel_iterations=32,
swap_memory=False,
scope=None
)
decoder: BasicDecoder、BeamSearchDecoder或者自己定义的decoder类对象
output_time_major: 见RNN,为真时step*batch_size*...,为假时batch_size*step*...
impute_finished: Boolean,为真时会拷贝最后一个时刻的状态并将输出置零,程序运行更稳定,使最终状态和输出具有正确的值,在反向传播时忽略最后一个完成步。但是会降低程序运行速度。
maximum_iterations: 最大解码步数,一般训练设置为decoder_inputs_length,预测时设置一个想要的最大序列长度即可。程序会在产生<eos>或者到达最大步数处停止。
其实简单来讲dynamic_decode就是先执行decoder的初始化函数,对解码时刻的state等变量进行初始化,然后循环执行decoder的step函数进行多轮解码。如果让我写可能就一个for循环,但是源码里面比较复杂,因为会涉及到很多条件判断等,以保证程序正常运行和报错。所以我们直接来看主体程序部分,也是一个control_flow_ops.while_loop循环,正好借机了解一下这个函数的使用方法:
while_loop(cond, body, loop_vars, shape_invariants=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)
cond是循环的条件,body是循环执行的主体,这两个都是函数。loop_vars是要用到的变量,cond和body的参数相同且都是loop_vars。但一般cond只用到个别参数用来判断循环是否结束,大部分参数都是body中用到。parallel_iterations是并行执行循环的个数。看下面cond函数其实就是看finished变量是否已经全部变为0,而body函数也就是执行了decoder.step(time, inputs, state)
这句代码之后一系列的赋值和判断。
def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
finished, unused_sequence_lengths):
return math_ops.logical_not(math_ops.reduce_all(finished))
def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
#======1,调用step函数得到下一时刻的输出、状态、并得到下一时刻输入(由helper得到)和是否完成变量decoder_finished
(next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(time, inputs, state)
#======2,根据decoder_finished和time是否已经大于maximum_iterations综合判断解码是否结束
next_finished = math_ops.logical_or(decoder_finished, finished)
if maximum_iterations is not None:
next_finished = math_ops.logical_or(
next_finished, time + 1 >= maximum_iterations)
next_sequence_lengths = array_ops.where(
math_ops.logical_and(math_ops.logical_not(finished), next_finished),
array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
sequence_lengths)
nest.assert_same_structure(state, decoder_state)
nest.assert_same_structure(outputs_ta, next_outputs)
nest.assert_same_structure(inputs, next_inputs),
##======3,如果设置了impute_finished为真,在程序结束时将next_outputs置为零,不让其进行反向传播。并对decoder_state进行拷贝得到下一时刻状态。所以这里如果设置为true,会浪费一些时间
if impute_finished:
emit = nest.map_structure(lambda out, zero: array_ops.where(finished, zero, out), next_outputs, zero_outputs)
else:
emit = next_outputs
# Copy through states past finish
def _maybe_copy_state(new, cur):
# TensorArrays and scalar states get passed through.
if isinstance(cur, tensor_array_ops.TensorArray):
pass_through = True
else:
new.set_shape(cur.shape)
pass_through = (new.shape.ndims == 0)
return new if pass_through else array_ops.where(finished, cur, new)
if impute_finished:
next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
else:
next_state = decoder_state
#=====4,返回结果。
outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out), outputs_ta, emit)
return (time + 1, outputs_ta, next_state, next_inputs, next_finished, next_sequence_lengths)
#调用上面定义的cond和body进行循环解码
res = control_flow_ops.while_loop(condition, body,
loop_vars=[initial_time, initial_outputs_ta, initial_state, initial_inputs, initial_finished, initial_sequence_lengths, ],
parallel_iterations=parallel_iterations, swap_memory=swap_memory)
看完上面代码,就会想知道decoder.step()函数究竟做了哪些工作。其实你可以把它理解为RNNCell.cell滚动了一次。只不过考虑到解码,会在此基础上添加一些诸如使用helper得到输出答案,并将其转换为下一时刻输入等操作。如下所示:
def step(self, time, inputs, state, name=None):
with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
cell_outputs, cell_state = self._cell(inputs, state)
if self._output_layer is not None:
#如果设置了output层,将cell的输出进行映射
cell_outputs = self._output_layer(cell_outputs)
#根据输出结果&#x