TensorFlow 源码解读之BasicDecoder

TensorFlow 源码解读

点击获取 BasicDecoder 源码


该类的作用:采样产生下一个输入

源码中总共实现了两个类:
- BasicDecoderOutput
- BasicDecoder
其中,BasicDecoderOutPut 类在源码中没有实现。
BasicDecoder 类继承了 decoder 中的 Decoder 类,下面看看 BasicDecoder 类中的代码。
首先是一个初始化函数:

  def __init__(self, cell, helper, initial_state, output_layer=None):
    """初始化 BasicDecoder.
    参数:
      cell: 一个 `RNNCell` 实例.
      helper: 一个 `Helper` 实例.
      initial_state: 一个 (可能组成一个tulpe)tensors 和 TensorArrays.
        RNNCell 的初始状态.
      output_layer: (可选) 一个 `tf.layers.Layer` 实例, 例如:`tf.layers.Dense`. 应用于RNN 输出层之前的可选层,用于存储结果或者采样.
    Raises:
      TypeError: 如果 `cell`, `helper` 或 `output_layer` 的类型不正确.
    """
    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
      raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
    if not isinstance(helper, helper_py.Helper):
      raise TypeError("helper must be a Helper, received: %s" % type(helper))
    if (output_layer is not None
        and not isinstance(output_layer, layers_base.Layer)):
      raise TypeError(
          "output_layer must be a Layer, received: %s" % type(output_layer))
    self._cell = cell
    self._helper = helper
    self._initial_state = initial_state
    self._output_layer = output_layer

该函数是 BasicDecoder 类的初始化类,初始化参数有 cell, helper, initial_state, output_layer,具体的含义已经在代码中解释。之后,连续的3个if语句,判断传进来的参数类型是否正确。确保类型正确后,将其赋值给内部的参数。


  @property
  def batch_size(self):
    return self._helper.batch_size

  def _rnn_output_size(self):
    size = self._cell.output_size
    if self._output_layer is None:
      return size
    else:
      # 为了使用 layer 中的 compute_output_shape 函数, 我们需要将 RNNCell 的 output_size 记录装化为一个未知的大小
      # 之后,我们将它传入 layer 中的 compute_output_shape 函数,读出除第一维外的所有数据,得到在上面应用了 layer 的 rnn 输出大小(output_size)
      output_shape_with_unknown_batch = nest.map_structure(
          lambda s: tensor_shape.TensorShape([None]).concatenate(s),
          size)
      layer_output_shape = self._output_layer._compute_output_shape(  # pylint: disable=protected-access
          output_shape_with_unknown_batch)
      return nest.map_structure(lambda s: s[1:], layer_output_shape)

  @property
  def output_size(self):
    # Return the cell output and the id
    return BasicDecoderOutput(
        rnn_output=self._rnn_output_size(),
        sample_id=self._helper.sample_ids_shape)

  @property
  def output_dtype(self):
    # Assume the dtype of the cell is the output_size structure
    # containing the input_state's first component's dtype.
    # Return that structure and the sample_ids_dtype from the helper.
    dtype = nest.flatten(self._initial_state)[0].dtype
    return BasicDecoderOutput(
        nest.map_structure(lambda _: dtype, self._rnn_output_size()),
        self._helper.sample_ids_dtype)

采用装饰器 @property 的作用是可以通过调用属性的方法来实现代用函数,例如:
decoder_layer = tf.contrib.seq2seq.BasicDecoder(...)
通过 decoder_layer.batch_size就可以调用函数


def initialize(self, name=None):
    """初始化 decoder.
    参数:
      name: 对于任何已穿件操做的命名
    返回:
      `(finished, first_inputs, initial_state)`.
    """
    return self._helper.initialize() + (self._initial_state,)

  def step(self, time, inputs, state, name=None):
    """实现 decoding 步.
    参数:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.
    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
      cell_outputs, cell_state = self._cell(inputs, state)
      # 如果存在output_layer 就调用它,并更新对应输出
      if self._output_layer is not None:
        cell_outputs = self._output_layer(cell_outputs)
        # 通过helper得到sample_ids,并利用sample_ids 进一步得到下一步的输入
      sample_ids = self._helper.sample(
          time=time, outputs=cell_outputs, state=cell_state)
      (finished, next_inputs, next_state) = self._helper.next_inputs(
          time=time,
          outputs=cell_outputs,
          state=cell_state,
          sample_ids=sample_ids)
    outputs = BasicDecoderOutput(cell_outputs, sample_ids)

首先,调用 _cell,利用传入的cell计算得到相应的输出和状态,如果 output_layer 不为 None ,则进一步计算 并更新 cell_outputs。然后通过掉用 helper 的 sample 函数得到 sample_ids,最后,调用 helper 的 next_inputs 函数完成相关的计算。


  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值