前言
tf.contrib.seq2seq.dynamic_decode源码分析本文衔接上文。
首先tf.contrib.seq2seq.dynamic_decode主要作用是接收一个Decoder类,然后依据Encoder进行解码,实现序列的生成(映射)。其中,这个函数主要的一个思想是一步一步地调用Decoder的step函数(该函数接收当前的输入和隐层状态会生成下一个词),实现最后的一句话的生成。该函数类似tf.nn.dynamic_rnn。
该函数用到的Decoder类就是今天所要解析的类。
源码解析
class BasicDecoder(decoder.Decoder):
"""Basic sampling decoder."""
def __init__(self, cell, helper, initial_state, output_layer=None):
"""Initialize BasicDecoder.
Args:
cell: RNN实例
helper: Helper类,用于训练和推理
initial_state: 初始状态
output_layer: 输出层
Raises:
TypeError: 如果`cell`, `helper` or `output_layer`没有正确的类型
"""
rnn_cell_impl.assert_like_rnncell("cell", cell)
if not isinstance(helper, helper_py.Helper):
raise TypeError("helper must be a Helper, received: %s" % type(helper))
if (output_layer is not None
and not isinstance(output_layer, layers_base.Layer)):
raise TypeError(
"output_layer must be a Layer, received: %s" % type(output_layer))
self._cell = cell
self._helper = helper
self._initial_state = initial_state
self._output_layer = output_layer
BasicDecoder是继承于Decoder类,这个类只是个抽象类,定义了几个抽象方法。首先这个cell, helper, initial_state, output_layer这几个参数,cell一般就是个RNN(及其衍生类,比如LSTM)实例,initial_state一般是用Encoder的最后一个隐层状态,也就是标准Seq2seq的做法,output_layer是输出层,很自然。
那helper是啥可能有些抽象。这里简单的说就是文本生成分为两个阶段,一个是训练,一个是推理。那么我们希望得到训练的output(输出层的输出),推理的采样样本。而这里也是采用了一个策略模式(设计模式的内容,不懂的可以看看),把Helper分为两大类,一种是TrainingHelper,一种是InferenceHelper。
@property
def batch_size(self):
return self._helper.batch_size
构造getter。
def _rnn_output_size(self):
size = self._cell.output_size
if self._output_layer is None:
return size
else:
# To use layer's compute_output_shape, we need to convert the
# RNNCell's output_size entries into shapes with an unknown
# batch size. We then pass this through the layer's
# compute_output_shape and read off all but the first (batch)
# dimensions to get the output size of the rnn with the layer
# applied to the top.
output_shape_with_unknown_batch = nest.map_structure(
lambda s: tensor_shape.TensorShape([None]).concatenate(s),
size)
layer_output_shape = self._output_layer.compute_output_shape(
output_shape_with_unknown_batch)
return nest.map_structure(lambda s: s[1:], layer_output_shape)
这里也是判断是否给于输出层,如果有的话,返回全连接层的之后的输出大小。
@property
def output_size(self):
# Return the cell output and the id
return BasicDecoderOutput(
rnn_output=self._rnn_output_size(),
sample_id=self._helper.sample_ids_shape)
def step(self, time, inputs, state, name=None):
"""Perform a decoding step.
Args:
time: scalar `int32` tensor.
inputs: A (structure of) input tensors.
state: A (structure of) state tensors and TensorArrays.
name: Name scope for any created operations.
Returns:
`(outputs, next_state, next_inputs, finished)`.
"""
with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
cell_outputs, cell_state = self._cell(inputs, state)
if self._output_layer is not None:
cell_outputs = self._output_layer(cell_outputs)
sample_ids = self._helper.sample(
time=time, outputs=cell_outputs, state=cell_state)
(finished, next_inputs, next_state) = self._helper.next_inputs(
time=time,
outputs=cell_outputs,
state=cell_state,
sample_ids=sample_ids)
outputs = BasicDecoderOutput(cell_outputs, sample_ids)
return (outputs, next_state, next_inputs, finished)
最重要的一步,这步其实类似LSTM的call方法,接受前一个隐层状态,当前的输入,返回一个输出状态,下一个状态,下一个输出,和finished。
总结
Decoder类有点类似RNN的call方法,接受前一个隐含状态以及当前时刻的输入返回当前的隐含状态和输出。