TensorFlow 源码解读
点击获取 BasicDecoder 源码
该类的作用:采样产生下一个输入
源码中总共实现了两个类:
- BasicDecoderOutput
- BasicDecoder
其中,BasicDecoderOutPut 类在源码中没有实现。
BasicDecoder 类继承了 decoder 中的 Decoder 类,下面看看 BasicDecoder 类中的代码。
首先是一个初始化函数:
def __init__(self, cell, helper, initial_state, output_layer=None):
"""初始化 BasicDecoder.
参数:
cell: 一个 `RNNCell` 实例.
helper: 一个 `Helper` 实例.
initial_state: 一个 (可能组成一个tulpe)tensors 和 TensorArrays.
RNNCell 的初始状态.
output_layer: (可选) 一个 `tf.layers.Layer` 实例, 例如:`tf.layers.Dense`. 应用于RNN 输出层之前的可选层,用于存储结果或者采样.
Raises:
TypeError: 如果 `cell`, `helper` 或 `output_layer` 的类型不正确.
"""
if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
if not isinstance(helper, helper_py.Helper):
raise TypeError("helper must be a Helper, received: %s" % type(helper))
if (output_layer is not None
and not isinstance(output_layer, layers_base.Layer)):
raise TypeError(
"output_layer must be a Layer, received: %s" % type(output_layer))
self._cell = cell
self._helper = helper
self._initial_state = initial_state
self._output_layer = output_layer
该函数是 BasicDecoder 类的初始化类,初始化参数有 cell, helper, initial_state, output_layer,具体的含义已经在代码中解释。之后,连续的3个if语句,判断传进来的参数类型是否正确。确保类型正确后,将其赋值给内部的参数。
@property
def batch_size(self):
return self._helper.batch_size
def _rnn_output_size(self):
size = self._cell.output_size
if self._output_layer is None:
return size
else:
# 为了使用 layer 中的 compute_output_shape 函数, 我们需要将 RNNCell 的 output_size 记录装化为一个未知的大小
# 之后,我们将它传入 layer 中的 compute_output_shape 函数,读出除第一维外的所有数据,得到在上面应用了 layer 的 rnn 输出大小(output_size)
output_shape_with_unknown_batch = nest.map_structure(
lambda s: tensor_shape.TensorShape([None]).concatenate(s),
size)
layer_output_shape = self._output_layer._compute_output_shape( # pylint: disable=protected-access
output_shape_with_unknown_batch)
return nest.map_structure(lambda s: s[1:], layer_output_shape)
@property
def output_size(self):
# Return the cell output and the id
return BasicDecoderOutput(
rnn_output=self._rnn_output_size(),
sample_id=self._helper.sample_ids_shape)
@property
def output_dtype(self):
# Assume the dtype of the cell is the output_size structure
# containing the input_state's first component's dtype.
# Return that structure and the sample_ids_dtype from the helper.
dtype = nest.flatten(self._initial_state)[0].dtype
return BasicDecoderOutput(
nest.map_structure(lambda _: dtype, self._rnn_output_size()),
self._helper.sample_ids_dtype)
采用装饰器 @property 的作用是可以通过调用属性的方法来实现代用函数,例如:
decoder_layer = tf.contrib.seq2seq.BasicDecoder(...)
通过 decoder_layer.batch_size
就可以调用函数
def initialize(self, name=None):
"""初始化 decoder.
参数:
name: 对于任何已穿件操做的命名
返回:
`(finished, first_inputs, initial_state)`.
"""
return self._helper.initialize() + (self._initial_state,)
def step(self, time, inputs, state, name=None):
"""实现 decoding 步.
参数:
time: scalar `int32` tensor.
inputs: A (structure of) input tensors.
state: A (structure of) state tensors and TensorArrays.
name: Name scope for any created operations.
Returns:
`(outputs, next_state, next_inputs, finished)`.
"""
with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
cell_outputs, cell_state = self._cell(inputs, state)
# 如果存在output_layer 就调用它,并更新对应输出
if self._output_layer is not None:
cell_outputs = self._output_layer(cell_outputs)
# 通过helper得到sample_ids,并利用sample_ids 进一步得到下一步的输入
sample_ids = self._helper.sample(
time=time, outputs=cell_outputs, state=cell_state)
(finished, next_inputs, next_state) = self._helper.next_inputs(
time=time,
outputs=cell_outputs,
state=cell_state,
sample_ids=sample_ids)
outputs = BasicDecoderOutput(cell_outputs, sample_ids)
首先,调用 _cell,利用传入的cell计算得到相应的输出和状态,如果 output_layer 不为 None ,则进一步计算 并更新 cell_outputs。然后通过掉用 helper 的 sample 函数得到 sample_ids,最后,调用 helper 的 next_inputs 函数完成相关的计算。